├── images
│   ├── cnn.png
│   ├── logo.png
│   ├── model.png
│   ├── logo128.png
│   ├── mannequin.png
│   ├── rdm2RDspects.gif
│   ├── spectrograms.png
│   └── real_synth_skeleton.gif
├── ifxaion
│   ├── img
│   │   └── aion.png
│   └── README.md
├── EUSIPCO2022
│   ├── images
│   │   ├── soli.gif
│   │   └── spectrograms.gif
│   └── README.md
├── _config.yml
├── configurations
│   └── radar_configs.csv
├── assets
│   └── css
│       └── style.scss
├── .gitignore
├── utils
│   ├── radar.py
│   ├── skeletons.py
│   ├── provider.py
│   ├── synthesize.py
│   ├── visualization.py
│   ├── slicer.py
│   └── preprocess.py
├── CITATION.cff
├── networks
│   ├── img_transf.py
│   └── perceptual.py
├── README.md
├── visualize
│   └── notebook.ipynb
└── main.py
/images/cnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fraunhoferhhi/racpit/HEAD/images/cnn.png
--------------------------------------------------------------------------------
/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fraunhoferhhi/racpit/HEAD/images/logo.png
--------------------------------------------------------------------------------
/images/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fraunhoferhhi/racpit/HEAD/images/model.png
--------------------------------------------------------------------------------
/images/logo128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fraunhoferhhi/racpit/HEAD/images/logo128.png
--------------------------------------------------------------------------------
/ifxaion/img/aion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fraunhoferhhi/racpit/HEAD/ifxaion/img/aion.png
--------------------------------------------------------------------------------
/images/mannequin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fraunhoferhhi/racpit/HEAD/images/mannequin.png
--------------------------------------------------------------------------------
/images/rdm2RDspects.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fraunhoferhhi/racpit/HEAD/images/rdm2RDspects.gif
--------------------------------------------------------------------------------
/images/spectrograms.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fraunhoferhhi/racpit/HEAD/images/spectrograms.png
--------------------------------------------------------------------------------
/EUSIPCO2022/images/soli.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fraunhoferhhi/racpit/HEAD/EUSIPCO2022/images/soli.gif
--------------------------------------------------------------------------------
/images/real_synth_skeleton.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fraunhoferhhi/racpit/HEAD/images/real_synth_skeleton.gif
--------------------------------------------------------------------------------
/EUSIPCO2022/images/spectrograms.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fraunhoferhhi/racpit/HEAD/EUSIPCO2022/images/spectrograms.gif
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | remote_theme: pages-themes/cayman@v0.2.0
2 | plugins:
3 | - jekyll-remote-theme
4 | title: Machine Learning for Radar
5 | description: Additional material
6 |
--------------------------------------------------------------------------------
/configurations/radar_configs.csv:
--------------------------------------------------------------------------------
1 | name,I,II,III,IV
2 | ChirpsPerFrame,64,64,64,128
3 | SamplesPerChirp,256,256,256,256
4 | AdcSamplerate,2000,1500,1500,1500
5 | FramePeriod,50000,32000,50000,50000
6 | LowerFrequency,59000,58000,58000,58000
7 | UpperFrequency,61000,62000,62000,62000
8 |
--------------------------------------------------------------------------------
/ifxaion/README.md:
--------------------------------------------------------------------------------
1 | # Infineon Data Acquisition 
2 |
3 | Infineon's Data Acquisition (**ifxaion**) radar module is proprietary and cannot be publicly disclosed.
4 | For inquiries regarding this module, please contact
5 | [Lorenzo Servadei](https://linkedin.com/in/lorenzo-servadei-32140937).
6 |
--------------------------------------------------------------------------------
/assets/css/style.scss:
--------------------------------------------------------------------------------
1 | ---
2 | ---
3 |
4 | // Breakpoints
5 | $large-breakpoint: 64em !default;
6 | $medium-breakpoint: 42em !default;
7 |
8 | // Headers
9 | $header-heading-color: #fff !default;
10 | $header-bg-color: #aec16e;
11 | $header-bg-color-secondary: #88b7a7;
12 |
13 | // Text
14 | $section-headings-color: #db0036;
15 | $body-text-color: #606c71 !default;
16 | $body-link-color: #1e6bb8 !default;
17 | $blockquote-text-color: #819198 !default;
18 |
19 | // Code
20 | $code-bg-color: #f3f6fa !default;
21 | $code-text-color: #567482 !default;
22 |
23 | // Borders
24 | $border-color: #dce6f0 !default;
25 | $table-border-color: #e9ebec !default;
26 | $hr-border-color: #eff0f1 !default;
27 |
28 | @import "{{ site.theme }}";
29 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | env/
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # dotenv
85 | .env
86 |
87 | # virtualenv
88 | .venv
89 | venv/
90 | ENV/
91 |
92 | # Spyder project settings
93 | .spyderproject
94 | .spyproject
95 |
96 | # Rope project settings
97 | .ropeproject
98 |
99 | # mkdocs documentation
100 | /site
101 |
102 | # mypy
103 | .mypy_cache/
104 | visualization/
105 | log/
106 |
--------------------------------------------------------------------------------
/utils/radar.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Tue Jul 3 16:59:19 2018
4 |
5 | @author: Rodrigo Hernangomez
6 | """
7 |
8 | import numpy as np
9 |
10 | constants = {'c': 3e8, 'f0': 60e9}
11 |
12 | eps = 10.0 ** (- 100)
13 |
14 |
15 | def range_axis(bw, n_samples, c=constants['c']):
16 | """
17 | Calculate range resolution and scope.
18 | :param bw: FMCW bandwidth in Hz.
19 | :param n_samples: Number of samples per chirp.
20 | :param c: Speed of light in m/s.
21 | :return: Range resolution dr and scope rmax
22 | """
23 | dr = c / (2 * bw)
24 | rmax = dr * n_samples / 2
25 | return dr, rmax
26 |
27 |
28 | def doppler_axis(prt, n_chirps, lambda0=constants['c']/constants['f0']):
29 | """
30 | Calculate velocity resolution and scope.
31 | :param prt: Pulse Repetition Time in seconds.
32 | :param n_chirps: Number of chirps per frame.
33 | :param lambda0: Wavelength of the radar in meters. For 60 GHz, this equals around 5 mm.
34 | :return: Velocity resolution dv and scope vmax
35 | """
36 | vmax = lambda0 / (4 * prt)
37 | dv = 2 * vmax / n_chirps
38 | return dv, vmax
39 |
40 |
41 | def rx_mask(rx1, rx2, rx3):
42 | mask = 0b0
43 | for i, rx in enumerate([rx1, rx2, rx3]):
44 | mask += rx << i
45 | return mask
46 |
47 |
48 | def normalize_db(db_data, axis=None):
49 | return db_data - np.max(db_data, axis=axis, keepdims=True)
50 |
51 |
52 | def absolute(rdm, normalize=False):
53 | rdm_abs = np.abs(rdm)
54 | if normalize:
55 | rdm_abs /= rdm_abs.max()
56 | return rdm_abs
57 |
58 |
59 | def mag2db(rdm, normalize=True):
60 | rdm_ = np.abs(rdm)
61 | rdm_db = 20 * np.log10(rdm_ + eps)
62 | if normalize:
63 | rdm_db = normalize_db(rdm_db)
64 | return rdm_db
65 |
66 |
67 | def db2mag(db):
68 | return np.power(10, db / 20)
69 |
70 |
71 | def complex2vector(array_cx):
72 | return np.stack([array_cx.real, array_cx.imag], axis=-1)
73 |
74 |
75 | def vector2complex(array_vec):
76 | cx_vec = array_vec.astype(np.complex128)
77 | return cx_vec[..., 0] + cx_vec[..., 1] * 1j
78 |
79 |
80 | def affine_transform(x, a, b):
81 | return a * x + b
82 |
--------------------------------------------------------------------------------
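A minimal usage sketch of the two axis helpers above, plugging in configuration "I" from `configurations/radar_configs.csv` (units assumed as in `utils/synthesize.py`: frequencies in MHz, ADC sample rate in kHz); the chirp repetition time below is an illustrative placeholder, not a value from the repository:

```python
from utils.radar import range_axis, doppler_axis

# Configuration "I" from configurations/radar_configs.csv (frequencies assumed in MHz)
bw = (61000 - 59000) * 1e6                  # 2 GHz FMCW bandwidth
n_samples, n_chirps = 256, 64

dr, r_max = range_axis(bw, n_samples)       # ~0.075 m resolution, ~9.6 m scope
prt = 400e-6                                # illustrative chirp repetition time in s
dv, v_max = doppler_axis(prt, n_chirps)     # ~0.1 m/s resolution, ~3.1 m/s scope
print(dr, r_max, dv, v_max)
```
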
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use RACPIT's code or you take the publication as a reference for your research, please cite it using the preferred-citation information."
3 | authors:
4 | - family-names: Hernangómez
5 | given-names: Rodrigo
6 | orcid: https://orcid.org/0000-0002-1284-4951
7 | affiliation: Fraunhofer Heinrich Hertz Institute
8 | - family-names: Visentin
9 | given-names: Tristan
10 | affiliation: Fraunhofer Heinrich Hertz Institute
11 | - family-names: Servadei
12 | given-names: Lorenzo
13 | orcid: https://orcid.org/0000-0003-4322-834X
14 | affiliation: Infineon Technologies AG
15 | - family-names: Khodabakhshandeh
16 | given-names: Hamid
17 | affiliation: Fraunhofer Heinrich Hertz Institute
18 | - family-names: Stanczak
19 | given-names: Slawomir
20 | orcid: https://orcid.org/0000-0003-3829-4668
21 | affiliation: Fraunhofer Heinrich Hertz Institute
22 | title: "Code and additional information to 'RACPIT: Improving Radar Human Activity Classification Using Synthetic Data with Image Transformation'"
23 | version: 1.0.0
24 | doi: 10.3390/s22041519
25 | url: "https://www.mdpi.com/1424-8220/22/4/1519"
26 | preferred-citation:
27 | type: article
28 | title: 'Improving Radar Human Activity Classification Using Synthetic Data with Image Transformation'
29 | abstract:
30 | authors:
31 | - family-names: Hernangómez
32 | given-names: Rodrigo
33 | orcid: https://orcid.org/0000-0002-1284-4951
34 | affiliation: Fraunhofer Heinrich Hertz Institute
35 | - family-names: Visentin
36 | given-names: Tristan
37 | affiliation: Fraunhofer Heinrich Hertz Institute
38 | - family-names: Servadei
39 | given-names: Lorenzo
40 | orcid: https://orcid.org/0000-0003-4322-834X
41 | affiliation: Infineon Technologies AG
42 | - family-names: Khodabakhshandeh
43 | given-names: Hamid
44 | affiliation: Fraunhofer Heinrich Hertz Institute
45 | - family-names: Stanczak
46 | given-names: Slawomir
47 | orcid: https://orcid.org/0000-0003-3829-4668
48 | affiliation: Fraunhofer Heinrich Hertz Institute
49 | year: 2022
50 | month: 1
51 | volume: 22
52 | issue: 4
53 | start: 1519
54 | issn: '1424-8220'
55 | journal: Sensors
56 | doi: "10.3390/s22041519"
57 |
58 |
--------------------------------------------------------------------------------
/networks/img_transf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | # Conv Layer
5 | class ConvLayer(nn.Module):
6 | def __init__(self, in_channels, out_channels, kernel_size, stride):
7 | super(ConvLayer, self).__init__()
8 | padding = kernel_size // 2
9 | self.reflection_pad = nn.ReflectionPad2d(padding)
10 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride) #, padding)
11 |
12 | def forward(self, x):
13 | out = self.reflection_pad(x)
14 | out = self.conv2d(out)
15 | return out
16 |
17 | # Upsample Conv Layer
18 | class UpsampleConvLayer(nn.Module):
19 | def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
20 | super(UpsampleConvLayer, self).__init__()
21 | self.upsample = upsample
22 | if upsample:
23 | self.upsample = nn.Upsample(scale_factor=upsample, mode='nearest')
24 | reflection_padding = kernel_size // 2
25 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
26 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
27 |
28 | def forward(self, x):
29 | if self.upsample:
30 | x = self.upsample(x)
31 | out = self.reflection_pad(x)
32 | out = self.conv2d(out)
33 | return out
34 |
35 | # Residual Block
36 | # adapted from pytorch tutorial
37 | # https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-
38 | # intermediate/deep_residual_network/main.py
39 | class ResidualBlock(nn.Module):
40 | def __init__(self, channels):
41 | super(ResidualBlock, self).__init__()
42 | self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
43 | self.in1 = nn.InstanceNorm2d(channels, affine=True)
44 | self.relu = nn.ReLU()
45 | self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
46 | self.in2 = nn.InstanceNorm2d(channels, affine=True)
47 |
48 | def forward(self, x):
49 | residual = x
50 | out = self.relu(self.in1(self.conv1(x)))
51 | out = self.in2(self.conv2(out))
52 | out = out + residual
53 | out = self.relu(out)
54 | return out
55 |
56 | # Image Transform Network
57 | class ImageTransformNet(nn.Module):
58 | def __init__(self, num_channels=3):
59 | super(ImageTransformNet, self).__init__()
60 |
61 | # nonlineraity
62 | self.relu = nn.ReLU()
63 | self.tanh = nn.Tanh()
64 |
65 | # encoding layers
66 | self.conv1 = ConvLayer(num_channels, 32, kernel_size=9, stride=1)
67 | self.in1_e = nn.InstanceNorm2d(32, affine=True)
68 |
69 | self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
70 | self.in2_e = nn.InstanceNorm2d(64, affine=True)
71 |
72 | self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
73 | self.in3_e = nn.InstanceNorm2d(128, affine=True)
74 |
75 | # residual layers
76 | self.res1 = ResidualBlock(128)
77 | self.res2 = ResidualBlock(128)
78 | self.res3 = ResidualBlock(128)
79 | self.res4 = ResidualBlock(128)
80 | self.res5 = ResidualBlock(128)
81 |
82 | # decoding layers
83 | self.deconv3 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2 )
84 | self.in3_d = nn.InstanceNorm2d(64, affine=True)
85 |
86 | self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2 )
87 | self.in2_d = nn.InstanceNorm2d(32, affine=True)
88 |
89 | self.deconv1 = UpsampleConvLayer(32, num_channels, kernel_size=9, stride=1)
90 | self.in1_d = nn.InstanceNorm2d(num_channels, affine=True)
91 |
92 | def forward(self, x):
93 | # encode
94 | y = self.relu(self.in1_e(self.conv1(x)))
95 | y = self.relu(self.in2_e(self.conv2(y)))
96 | y = self.relu(self.in3_e(self.conv3(y)))
97 |
98 | # residual layers
99 | y = self.res1(y)
100 | y = self.res2(y)
101 | y = self.res3(y)
102 | y = self.res4(y)
103 | y = self.res5(y)
104 |
105 | # decode
106 | y = self.relu(self.in3_d(self.deconv3(y)))
107 | y = self.relu(self.in2_d(self.deconv2(y)))
108 | #y = self.tanh(self.in1_d(self.deconv1(y)))
109 | y = self.deconv1(y)
110 |
111 | return y
112 |
113 |
114 | class MultiTransformNet(nn.Module):
115 | def __init__(self, num_inputs=2, num_channels=1):
116 | super(MultiTransformNet, self).__init__()
117 | self.transformers = nn.ModuleList([ImageTransformNet(num_channels) for _ in range(num_inputs)])
118 |
119 | def forward(self, inputs):
120 | outputs = [trans(x) for x, trans in zip(inputs, self.transformers)]
121 | return outputs
122 |
--------------------------------------------------------------------------------
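As a quick sanity check (a sketch, assuming single-channel spectrogram inputs whose height and width are divisible by 4), the transformation network preserves the input shape, which is what allows it to be chained with a downstream classifier:

```python
import torch
from networks.img_transf import ImageTransformNet, MultiTransformNet

x = torch.randn(4, 1, 128, 64)                     # hypothetical spectrogram batch
y = ImageTransformNet(num_channels=1)(x)
print(y.shape)                                     # torch.Size([4, 1, 128, 64])

# MultiTransformNet applies one independent transformer per input branch
outs = MultiTransformNet(num_inputs=2, num_channels=1)([x, x.clone()])
print([o.shape for o in outs])
```
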
/utils/skeletons.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from ifxaion.daq import Daq
3 | import scipy.interpolate
4 | import xarray as xr
5 |
6 |
7 | def rcsellipsoid(a, b, c, phi,theta):
8 | rcs = (np.pi * a**2 * b**2 * c**2) / (a**2 * (np.sin(theta))**2 * (np.cos(phi))**2 +
9 | b**2 * (np.sin(theta))**2 * (np.sin(phi))**2 + c**2 * (np.cos(theta))**2)**2
10 | return rcs
11 |
12 |
13 | class Segment:
14 | """ Defines a body segment and allows to calculate RCS of segment """
15 |
16 | def __init__(self, segmentPositions, ellipsoidParams, aspect):
17 | self.ellipsoidParams = ellipsoidParams
18 | self.segmentPositions = segmentPositions
19 | self.aspect = aspect
20 |
21 |
22 | def calculateRange(self, radarloc):
23 | """
24 | Calculates the distance to the radar for all segment positions
25 | :return: vector of the distances [d1, d2, d3, ...]
26 | """
27 | self.r_dist = np.abs(self.segmentPositions - radarloc)
28 | self.r_total = np.sqrt(self.r_dist[:, 0] ** 2 + self.r_dist[:, 1] ** 2 + self.r_dist[:, 2] ** 2)
29 | return self.r_total
30 |
31 |
32 | def calculateAngles(self, radarloc):
33 | """
34 | Calculates the angles to the radar for all segment positions
35 | :return: vector of phi [phi1, phi2, phi3,...] and vector of theta [theta1,theta2,theta3,...]
36 | """
37 | A = np.column_stack((
38 | radarloc[0] - self.segmentPositions[:, 0],
39 | radarloc[1] - self.segmentPositions[:, 1],
40 | radarloc[2] - self.segmentPositions[:, 2]
41 | ))
42 |
43 | B = np.column_stack((
44 | self.aspect[:, 0],
45 | self.aspect[:, 1],
46 | self.aspect[:, 2]
47 | ))
48 |
49 | a_dot_b = np.sum(A * B, axis=1)
50 | a_sum_sqrt = np.sqrt(np.sum(A * A, axis=1))
51 | b_sum_sqrt = np.sqrt(np.sum(B * B, axis=1))
52 |
53 | theta = np.arccos(a_dot_b / (a_sum_sqrt * b_sum_sqrt))
54 | phi = np.arcsin((radarloc[1] - self.segmentPositions[:, 1]) /
55 | np.sqrt(self.r_dist[:, 0] ** 2 + self.r_dist[:, 1] ** 2))
56 |
57 | return phi, theta
58 |
59 | def calculateRCS(self, phiAngle, thetaAngle):
60 | """
61 | Calculates the RCS
62 | :return: vector of RCS [rcs1, rcs2, rcs3,...]
63 | """
64 | a = self.ellipsoidParams[0]
65 | b = self.ellipsoidParams[1]
66 | c = self.ellipsoidParams[2]
67 | rcs = rcsellipsoid(a, b, c, phiAngle, thetaAngle)
68 | return rcs
69 |
70 |
71 | def load(path, verbose=False):
72 | """
73 | Read the skeleton data with help of the Daq module.
74 | """
75 |
76 | daq = Daq(rec_dir=path)
77 |
78 | skeleton_df = daq.skeletons.data
79 |
80 | skeletonTimestamps = skeleton_df.index
81 |
82 | duration = (skeletonTimestamps[-1]-skeletonTimestamps[0]).total_seconds()
83 |
84 | if verbose:
85 | print("Reading skeleton data: {} skeletons, in {} second recording".format(len(skeleton_df), duration))
86 |
87 | return skeleton_df
88 |
89 |
90 | def interpolate(skeletons, timestamp_seconds):
91 | sk_data = np.array(skeletons.Data)
92 | sk_data = np.stack(sk_data)
93 | sk_timestamps = skeletons.index
94 |
95 | sk_interpolation = scipy.interpolate.interp1d(sk_timestamps.total_seconds(), sk_data,
96 | axis=0, fill_value="extrapolate", kind='linear')
97 | sk_new = sk_interpolation(timestamp_seconds)
98 | return sk_new
99 |
100 |
101 | def to_xarray(data, timestamps, name="Skeletons", attrs=None):
102 | skeleton_da = xr.DataArray(data, dims=("time", "space", "keypoints"),
103 | name=name, attrs=attrs,
104 | coords={"time": timestamps.to_numpy(),
105 | "space": ["x", "y", "z"],
106 | "keypoints": list(keypoints)}).assign_attrs(units="m")
107 | return skeleton_da
108 |
109 |
110 | def get_edges(skeleton):
111 | if isinstance(skeleton, xr.DataArray):
112 | sk_data = skeleton.values
113 | else:
114 | sk_data = skeleton
115 | return np.stack([[sk_frame[:, segment] for segment in coco_edges] for sk_frame in sk_data])
116 |
117 |
118 | keypoints = ("nose", "left_eye", "right_eye", "left_ear", "right_ear",
119 | "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
120 | "left_wrist", "right_wrist", "left_hip", "right_hip",
121 | "left_knee", "right_knee", "left_ankle", "right_ankle")
122 |
123 | # Edges between keypoints used for the COCO API to visualize skeletons
124 | # See https://github.com/facebookresearch/Detectron/issues/640
125 | coco_edges = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7],
126 | [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]
127 |
--------------------------------------------------------------------------------
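A small sanity check of `rcsellipsoid` (a sketch with an arbitrary radius): for a sphere, i.e. `a == b == c == r`, the expression collapses to the geometric-optics RCS `pi * r**2`, independent of the aspect angles:

```python
import numpy as np
from utils.skeletons import rcsellipsoid

r = 0.1                                            # arbitrary radius in meters
rcs = rcsellipsoid(r, r, r, phi=0.3, theta=1.0)
print(rcs, np.pi * r**2)                           # both ~0.0314 m^2
```
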
/README.md:
--------------------------------------------------------------------------------
1 | # RACPIT
2 | [NumPy][numpy]
3 | [PyTorch][pytorch]
4 | [Pandas][pandas]
5 |
6 | This repository contains supplementary material for our article
7 | ["Improving Radar Human Activity Classification
8 | Using Synthetic Data with Image Transformation"](https://www.mdpi.com/1424-8220/22/4/1519),
9 | published in
10 | [MDPI Sensors](https://www.mdpi.com/journal/sensors)
11 | as part of the
12 | [Special Issue "Advances in Radar Sensors"](https://www.mdpi.com/journal/sensors/special_issues/radar_application).
13 | There we introduce **RACPIT**:
14 | Radar Activity Classification with Perceptual Image Transformation,
15 | a deep-learning approach to human activity classification using
16 | [FMCW radar](https://community.infineon.com/t5/Knowledge-Base-Articles/Understanding-FMCW-Radars-Features-and-operational-principles/ta-p/767198)
17 | and enhanced with synthetic data.
18 |
19 | ## Background
20 |
21 | ### Radar data
22 |
23 | We use **Range Doppler Maps (RDMs)**
24 | as a basis for our input data. These can be either real data acquired
25 | with Infineon's
26 | [Radar sensors for IoT](https://www.infineon.com/cms/en/product/sensor/radar-sensors/radar-sensors-for-iot/)
27 | or synthetic data generated from kinematic data with the following model:
28 |
29 | $\Large s\left(t\right)=\sum_{k}{\sqrt{\frac{A_{k,t}}{L_{k,t}}}\sin{\left(2\pi f_{k,t}t+\phi_{k,t}\right)}}$
30 |
31 |
32 |
33 |
34 |
35 | $A_{k,t}$,
36 | $L_{k,t}$,
37 | $f_{k,t}$ and
38 | $\phi_{k,t}$
39 | represent the radar cross section, free-space path loss,
40 | instant frequency and instant phase, respectively,
41 | of the returned and mixed-down signal for every modelled human limb
42 | $k$
43 | and instant
44 | $t$.
45 | The latter three parameters depend
46 | on the instantaneous distance of the limb to the radar sensor,
47 | $d_{k,t}$,
48 | and are calculated using the customary
49 | [radar](https://www.radartutorial.eu/01.basics/The%20Radar%20Range%20Equation.en.html) and
50 | [FMCW](https://www.radartutorial.eu/02.basics/Frequency%20Modulated%20Continuous%20Wave%20Radar.en.html)
51 | equations.
52 |
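For a concrete reading of these relations, `utils/synthesize.py` in this repository computes the instant frequency, instant phase and path loss essentially as

$\Large f_{k,t}=\frac{2\alpha d_{k,t}}{c},\qquad \phi_{k,t}=\frac{4\pi f_\mathrm{low} d_{k,t}}{c}-\frac{4\pi\alpha d_{k,t}^2}{c^2},\qquad L_{k,t}=\left(\frac{4\pi d_{k,t}f_c}{c}\right)^2$

where $\alpha$ is the chirp slope (bandwidth over chirp duration), $f_\mathrm{low}$ the lower FMCW frequency and $f_c$ the center frequency; the squared path loss mimics the receiver's low-pass filtering, as noted in that file.
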
53 | 
54 |
55 | We further preprocess the RDMs by stacking them and summing over the Doppler and range axes
56 | to obtain range and Doppler spectrograms, respectively:
57 |
58 | 
59 |
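A minimal numpy sketch of this marginalization, assuming a stack of RDMs with shape `(time, Doppler, range)` (the batch version is handled by `utils/preprocess.py`, see the commands under Usage):

```python
import numpy as np

rdms = np.random.rand(128, 64, 256)      # dummy RDM stack: (time, Doppler, range)

range_spectrogram = rdms.sum(axis=1)     # sum over Doppler -> (time, range)
doppler_spectrogram = rdms.sum(axis=2)   # sum over range   -> (time, Doppler)
```
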
60 | ### Deep learning
61 |
62 | We train our image transformation networks with an adapted version of
63 | [Perceptual Losses for Real-Time Style Transfer and Super-Resolution][perceptual].
64 |
65 | [perceptual]: https://arxiv.org/abs/1603.08155
66 |
67 | 
68 |
69 | Since we are working with radar data, we substitute VGG16 as the perceptual network
70 | with our two-branch convolutional neural network from
71 | [Domain Adaptation Across Configurations of FMCW Radar for Deep Learning Based Human Activity Classification](https://doi.org/10.23919/IRS51887.2021.9466179).
72 |
73 |
74 |
75 | If we train the image transformation networks with real data as our input and synthetic data as our ground truth,
76 | we obtain a denoising behavior for the image transformation networks.
77 |
78 |
79 |
80 | ## Implementation
81 |
82 | The code has been written for
83 | PyTorch based on
84 | [Daniel Yang's implementation](https://github.com/dxyang/StyleTransfer)
85 | of [Perceptual loss][perceptual].
86 |
87 | Data preprocessing is heavily based on
88 | *x*array. You can take a closer look at it
89 | in our
90 | [example](./visualize).
91 |
92 | ### Prerequisites
93 | - [Python 3.8](https://www.python.org/)
94 | - [PyTorch 1.7.0][pytorch]
95 | - [*x*array](https://xarray.pydata.org)
96 | - [NumPy][numpy]
97 | - [Pandas][pandas]
98 | - [Matplotlib](https://matplotlib.org/)
99 | - [Cuda 11.0](https://developer.nvidia.com/cuda-11.0-download-archive)
100 | (For GPU training)
101 |
102 | [numpy]: http://www.numpy.org/
103 | [pytorch]: http://pytorch.org/
104 | [pandas]: https://pandas.pydata.org/
105 |
106 | ### Usage
107 |
108 | Radar data can be batch-preprocessed and stored
109 | for faster training:
110 |
111 | ```bash
112 | $ python utils/preprocess.py --raw "/path/to/data/raw" --output "/path/to/data/real" --value "db" --marginalize "incoherent"
113 | $ python utils/preprocess.py --raw "/path/to/data/raw" --output "/path/to/data/synthetic" --synthetic --value "db" --marginalize "incoherent"
114 | ```
115 |
116 | After this, you can train the CNN that will serve as the perceptual network:
117 |
118 | ```bash
119 | $ python main.py --log "cnn" train-classify --range --config "I" --gpu 0 --no-split --dataset "/path/to/data/synthetic"
120 | ```
121 |
122 | Then you can train the image transformation networks:
123 |
124 | ```bash
125 | $ python main.py --log "trans" train-transfer --range --config "I" --gpu 0 --visualize 5 --input "/path/to/data/real" --output "/path/to/data/synthetic" --recordings first --model "models/cnn.model"
126 | ```
127 |
128 | And finally test the whole pipeline:
129 |
130 | ```bash
131 | $ python main.py test --range --config "I" --gpu 0 --visualize 10 --dataset "/path/to/data/real" --recordings last --transformer "models/trans.model" --model "models/cnn.model"
132 | ```
133 |
134 | ## Citation
135 |
136 | If you use RACPIT's code or you take the publication as a reference for your research,
137 | please cite our work in the following way:
138 |
139 | ```bibtex
140 | @Article{s22041519,
141 | AUTHOR = {Hernang{\'o}mez, Rodrigo and Visentin, Tristan and Servadei, Lorenzo and Khodabakhshandeh, Hamid and Sta{\'n}czak, S{\l}awomir},
142 | TITLE = {Improving Radar Human Activity Classification Using Synthetic Data with Image Transformation},
143 | JOURNAL = {Sensors},
144 | VOLUME = {22},
145 | YEAR = {2022},
146 | NUMBER = {4},
147 | ARTICLE-NUMBER = {1519},
148 | URL = {https://www.mdpi.com/1424-8220/22/4/1519},
149 | ISSN = {1424-8220},
150 | DOI = {10.3390/s22041519}
151 | }
152 | ```
153 |
--------------------------------------------------------------------------------
/networks/perceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models
4 | from networks.img_transf import ImageTransformNet, MultiTransformNet
5 |
6 |
7 | class Vgg16(nn.Module):
8 | def __init__(self):
9 | super(Vgg16, self).__init__()
10 | # Convert 1-channel image to 3-channel
11 | model = models.vgg16(pretrained=True)
12 | first_conv_layer = [nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)]
13 | first_conv_layer.extend(list(model.features))
14 | model.features = nn.Sequential(*first_conv_layer)
15 | features = model.features
16 |
17 | self.to_relu_1_2 = nn.Sequential()
18 | self.to_relu_2_2 = nn.Sequential()
19 | self.to_relu_3_3 = nn.Sequential()
20 | self.to_relu_4_3 = nn.Sequential()
21 |
22 | for x in range(4):
23 | self.to_relu_1_2.add_module(str(x), features[x])
24 | for x in range(4, 9):
25 | self.to_relu_2_2.add_module(str(x), features[x])
26 | for x in range(9, 16):
27 | self.to_relu_3_3.add_module(str(x), features[x])
28 | for x in range(16, 23):
29 | self.to_relu_4_3.add_module(str(x), features[x])
30 |
31 | # don't need the gradients, just want the features
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 | def forward(self, x):
36 | h = self.to_relu_1_2(x)
37 | h_relu_1_2 = h
38 | h = self.to_relu_2_2(h)
39 | h_relu_2_2 = h
40 | h = self.to_relu_3_3(h)
41 | h_relu_3_3 = h
42 | h = self.to_relu_4_3(h)
43 | h_relu_4_3 = h
44 | out = (h_relu_1_2, h_relu_2_2, h_relu_3_3, h_relu_4_3)
45 | return out
46 |
47 |
48 | class RDConv(nn.Module):
49 | """
50 | Convolutional layer(s) for radar spectrograms
51 | """
52 | def __init__(self, num_branches):
53 | super(RDConv, self).__init__()
54 | self.branches = nn.ModuleList([self._branch() for _ in range(num_branches)])
55 |
56 | @staticmethod
57 | def _branch():
58 | conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(7, 7), padding=(3, 3))
59 | relu1 = nn.ReLU()
60 | max1 = nn.MaxPool2d(4, 3)
61 | conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5, 5), padding=(2, 2))
62 | relu2 = nn.ReLU()
63 | max2 = nn.MaxPool2d(3, 2)
64 | conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=(1, 1))
65 | relu3 = nn.ReLU()
66 | max3 = nn.MaxPool2d(3, 2)
67 |
68 | return nn.ModuleList([conv1, relu1, max1, conv2, relu2, max2,
69 | conv3, relu3, max3])
70 |
71 | def _branch_forward(self, x, branch=0, flatten=True):
72 | for layer in self.branches[branch]:
73 | x = layer(x)
74 | if flatten:
75 | x = torch.flatten(x, start_dim=1)
76 | return x
77 |
78 | def forward(self, inputs):
79 | if len(self.branches) > 1 or any(isinstance(inputs, _type) for _type in (list, tuple)):
80 | features = []
81 | for i, x in enumerate(inputs):
82 | features.append(self._branch_forward(x, i))
83 | return torch.cat(features, dim=1)
84 | else:
85 | return self._branch_forward(inputs)
86 |
87 | def output_num(self, input_shapes):
88 | inputs = [torch.zeros(1, *s[1:]) for s in input_shapes]
89 | if len(inputs) == 1:
90 | inputs = inputs[0]
91 | y = self.forward(inputs)
92 | return y.size()[-1]
93 |
94 |
95 | class RDNet(nn.Module):
96 | def __init__(self, input_shapes=None, class_num=5):
97 | super(RDNet, self).__init__()
98 | # set base network
99 |
100 | branches = len(input_shapes)
101 | self.conv_layers = RDConv(branches)
102 | features_num = self.conv_layers.output_num(input_shapes)
103 | self.classifier_layer_list = [nn.Linear(features_num, 16), nn.ReLU(), nn.Dropout(0.1),
104 | nn.Linear(16, 16), nn.ReLU(), nn.Dropout(0.1),
105 | nn.Linear(16, 8), nn.ReLU(), nn.Dropout(0.1),
106 | nn.Linear(8, class_num)]
107 | self.classifier_layer = nn.Sequential(*self.classifier_layer_list)
108 | self.softmax = nn.Softmax(dim=1)
109 |
110 | # initialization
111 | for dep in range(4):
112 | self.classifier_layer[dep * 3].weight.data.normal_(0, 0.01)
113 | self.classifier_layer[dep * 3].bias.data.fill_(0.0)
114 |
115 | def forward(self, inputs):
116 | features = self.conv_layers(inputs)
117 | outputs = self.classifier_layer(features)
118 | return outputs
119 |
120 | def predict(self, inputs):
121 | outputs = self.forward(inputs)
122 | return self.softmax(outputs)
123 |
124 |
125 | class RDPerceptual(RDNet):
126 | """
127 | Use a trained RD classifier for perceptual loss
128 | """
129 |
130 | def __init__(self, model_path, input_shapes=None, class_num=5):
131 | super(RDPerceptual, self).__init__(input_shapes=input_shapes, class_num=class_num)
132 | self.load_state_dict(torch.load(model_path))
133 |
134 | for param in self.parameters():
135 | param.requires_grad = False
136 |
137 | def forward(self, inputs):
138 | features = self.conv_layers(inputs)
139 | outputs = self.classifier_layer(features)
140 | return outputs, features
141 |
142 |
143 | class RACPIT(RDNet):
144 | """
145 | RD classifier including image transformer
146 | """
147 | def __init__(self, trans_path, model_path, input_shapes=None, class_num=5, train_classifier=False):
148 | super(RACPIT, self).__init__(input_shapes=input_shapes, class_num=class_num)
149 | self.load_state_dict(torch.load(model_path))
150 |
151 | if len(input_shapes) > 1:
152 | self.transformer = MultiTransformNet(num_inputs=len(input_shapes), num_channels=1)
153 | else:
154 | self.transformer = ImageTransformNet(num_channels=1)
155 | self.transformer.load_state_dict(torch.load(trans_path))
156 |
157 | frozen_params = self.transformer.parameters() if train_classifier else self.parameters()
158 | for param in frozen_params:
159 | param.requires_grad = False
160 |
161 | def forward(self, x):
162 | if len(x) == 1:
163 | x = x[0]
164 | x_trans = self.transformer(x)
165 | return super(RACPIT, self).forward(x_trans)
166 |
--------------------------------------------------------------------------------
/EUSIPCO2022/README.md:
--------------------------------------------------------------------------------
1 | # Unsupervised Domain Adaptation across FMCW Radar Configurations Using Margin Disparity Discrepancy
2 |
3 | The content of this page serves as supplementary material to the presentation of our work at [EUSIPCO2022](https://2022.eusipco.org/) in Belgrade. [A preprint is also available][preprint].
4 |
5 | [preprint]: https://arxiv.org/abs/2203.04588
6 |
7 | ## Radar ML
8 |
9 | Radar sensing has gained research interest over the past few years due to several
10 | factors. On the one hand, semiconductor companies have managed to produce highly
11 | integrated radar chipsets thanks to the
12 | [Frequency-modulated continuous
13 | wave (FMCW)](https://www.infineon.com/dgdl/Infineon-Radar%20FAQ-PI-v02_00-EN.pdf?fileId=5546d46266f85d6301671c76d2a00614) technology. An example of this is the
14 | [60 GHz sensor frontend](https://www.infineon.com/cms/en/product/promopages/60GHz/)
15 | developed by [Infineon](https://www.infineon.com/)
16 | that lies at the core of Google's [Soli Project](https://atap.google.com/soli/).
17 |
18 | 
19 |
20 | On the other hand, radar sensors enable interesting
21 | [IoT applications](https://www.infineon.com/cms/en/product/sensor/radar-sensors/radar-sensors-for-iot/), such as human activity classification, surveillance or hand gesture
22 | recognition, and they present some advantages over cameras when
23 | it comes to privacy concerns and ill-posed optical scenarios, including
24 | through-wall and bad-lighting situations.
25 |
26 | All the data used in this work has been recorded with Infineon's
27 | [BGT60TR13C](https://www.infineon.com/cms/en/product/sensor/radar-sensors/radar-sensors-for-iot/60ghz-radar/bgt60tr13c/)
28 | radar sensors.
29 |
30 | ### Signal preprocessing
31 |
32 | Range and speed (also known as Doppler shift) of FMCW radar targets can
33 | be explored through so-called Range-Doppler maps, which convey information
34 | from the radar targets over both dimensions.
35 |
36 | Since a series of Range-Doppler maps is 3-dimensional and thus requires large neural networks to process,
37 | we extract so-called range and Doppler spectrograms by summing over
38 | the corresponding axis. This allows us to train a light neural network that takes both range and
39 | Doppler dynamic signatures and extracts the most relevant features from
40 | them through two separate branches of convolutional layers.
41 |
42 | 
43 |
44 | 
45 |
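A minimal sketch of how the two-branch classifier implemented in `networks/perceptual.py` of this repository consumes the two spectrogram types; the shapes are illustrative assumptions:

```python
import torch
from networks.perceptual import RDNet

range_spec = torch.randn(8, 1, 128, 64)       # (batch, channel, range bins, time)
doppler_spec = torch.randn(8, 1, 64, 64)      # (batch, channel, Doppler bins, time)

model = RDNet(input_shapes=[range_spec.shape, doppler_spec.shape], class_num=5)
logits = model([range_spec, doppler_spec])    # one convolutional branch per input
probs = model.predict([range_spec, doppler_spec])
```
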
46 | ## Domain adaptation
47 |
48 | The complexity of the classification of radar signatures has driven radar approaches to
49 | resort to machine learning and deep learning techniques that require
50 | large amounts of data. This poses a challenge for radar, especially
51 | when aspects such as the sensor's configuration, the environment or user
52 | settings have a large diversity. Each of these aspects may indeed lead to changes in the data domain in the sense of **domain adaptation theory**.
53 |
54 | Domain adaptation techniques consider the case where the underlying probability distribution of data differs
55 | between a **source domain**, from which we can draw sufficient training
56 | data, and a **target domain** in which we ultimately aim to deploy.
57 |
58 | Such a framework allows an easier deployment of pre-trained models,
59 | where only a few labeled samples from the target domain are required for fine-tuning.
60 | This is known as **supervised domain adaptation**, which we
61 | already investigated for radar in [[3]](#ref3). For this paper, we have
62 | moved on to the more demanding situation where the target data is
63 | provided without labels, also called **unsupervised
64 | domain adaptation**.
65 |
66 | ### Margin Disparity Discrepancy
67 |
68 | The unsupervised domain adaptation
69 | method we have used is called Margin Disparity Discrepancy (MDD) and it has been recently
70 | developed by Zhang et al. [[1]](#ref1) based on theoretical guarantees and tested on
71 | computer vision datasets. The authors have proved that the true error in the
72 | target domain can be bounded by the empirical error in the source domain
73 | plus a residual ideal loss and the so-called MDD term.
74 |
75 | As a small **theoretical contribution to MDD**, we have managed to upper
76 | bound the empirical source error, which originally
77 | uses the ramp loss, by its cross-entropy counterpart.
78 | This is important since the ramp loss is non-convex and thus impractical for ML training.
79 | In fact, the authors of [[1]](#ref1) perform their experiments with the cross-entropy loss.
80 |
81 | The full proof for the cross-entropy bound is in [our paper][preprint]. As a visual intuition,
82 | though, one can just regard the cross-entropy as a smooth
83 | version of hinge loss, which in turn is nothing more
84 | than a convex relaxation of the ramp loss.
85 |
86 | 
87 |
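As a rough illustration of that relation (a sketch using the binary margin formulation for intuition only; the paper itself works with multiclass margin losses):

```python
import numpy as np
import matplotlib.pyplot as plt

m = np.linspace(-3, 3, 601)                        # classification margin
ramp = np.minimum(1.0, np.maximum(0.0, 1.0 - m))   # non-convex ramp loss
hinge = np.maximum(0.0, 1.0 - m)                   # convex relaxation of the ramp
xent = np.log(1.0 + np.exp(-m))                    # smooth logistic / cross-entropy surrogate

for curve, label in [(ramp, "ramp"), (hinge, "hinge"), (xent, "cross-entropy")]:
    plt.plot(m, curve, label=label)
plt.xlabel("margin")
plt.legend()
plt.show()
```
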
88 | ## References
89 |
90 | 1. Zhang, Y., Liu, T., Long, M. and Jordan, M. (2019) 'Bridging Theory
91 | and Algorithm for Domain Adaptation', in *International Conference
92 | on Machine Learning*. *International Conference on Machine
93 | Learning*, PMLR, pp. 7404--7413.
94 | 2. Liang, X., Wang, X., Lei, Z., Liao, S. and Li, S.Z. (2017)
95 | 'Soft-Margin Softmax for Deep Classification', in D. Liu, S. Xie, Y.
96 | Li, D. Zhao, and E.-S.M. El-Alfy (eds) *Neural Information
97 | Processing*. Cham: Springer International Publishing (Lecture Notes
98 | in Computer Science), pp. 413--421.
99 |
100 | 3. Khodabakhshandeh, H., Visentin, T., Hernangómez, R. and
101 | Pütz, M. (2021) 'Domain Adaptation Across Configurations of FMCW
102 | Radar for Deep Learning Based Human Activity Classification', in
103 | *2021 21st International Radar Symposium (IRS)*. *2021 21st
104 | International Radar Symposium (IRS)*, Berlin, Germany, pp. 1--10. Available at: https://doi.org/10.23919/IRS51887.2021.9466179.
105 | 4. Hernangómez, R., Bjelakovic, I., Servadei, L. and
106 | Stańczak, S. (2022) 'Unsupervised Domain Adaptation across FMCW
107 | Radar Configurations Using Margin Disparity Discrepancy', in *2022
108 | 30th European Signal Processing Conference (EUSIPCO)*. Belgrade,
109 | Serbia. Available at: https://arxiv.org/abs/2203.04588.
110 | 5. Hernangómez, R., Santra, A. and Stańczak, S. (2019) 'Human Activity
111 | Classification with Frequency Modulated Continuous Wave Radar Using
112 | Deep Convolutional Neural Networks', in *2019 International Radar
113 | Conference (RADAR)*. *2019 International Radar Conference (RADAR)*,
114 | Toulon, France: IEEE, pp. 1--6.
115 |
116 | 6. Hernangómez, R., Santra, A. and Stańczak, S. (2021) 'Study on
117 | feature processing schemes for deep-learning-based human activity
118 | classification using frequency-modulated continuous-wave radar',
119 | *IET Radar, Sonar & Navigation*, 15(8), pp. 932--944.
120 |
121 | 7. Lien, J., Gillian, N., Karagozler, M.E., Amihood, P., Schwesig, C.,
122 | Olson, E., Raja, H. and Poupyrev, I. (2016) 'Soli: ubiquitous
123 | gesture sensing with millimeter wave radar', *ACM Transactions on
124 | Graphics*, 35(4), p. 142:1-142:19.
125 |
126 | 8. Motiian, S., Jones, Q., Iranmanesh, S. and Doretto, G. (2017)
127 | 'Few-Shot Adversarial Domain Adaptation', *Advances in Neural
128 | Information Processing Systems*, 30.
129 |
130 | 9. Santra, A. and Hazra, S. (2020) *Deep learning applications of
131 | short-range radars*. Artech House.
132 |
--------------------------------------------------------------------------------
/utils/provider.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from collections.abc import Sequence
4 |
5 | from utils import preprocess
6 | from utils.slicer import train_test_slice
7 | from utils.radar import affine_transform, normalize_db as norm_db
8 | from utils.visualization import spec_plot
9 |
10 | import torch.utils.data as util_data
11 |
12 |
13 | class _TimeInfo(list):
14 | def __init__(self, dataset, time_labels=("date", "offset"), dim="batch_index", drop_attrs=True):
15 | self._ds = dataset[list(time_labels)]
16 | if drop_attrs:
17 | self._ds.attrs = {}
18 | self.iter_dim = dim
19 | super(_TimeInfo, self).__init__()
20 |
21 | def __getitem__(self, index):
22 | return self._ds[{self.iter_dim: index}]
23 |
24 | def __setitem__(self, key, value):
25 | raise TypeError("Assignment operation not permitted for TimeInfo")
26 |
27 | def __len__(self):
28 | return self._ds.sizes[self.iter_dim]
29 |
30 | def __str__(self):
31 | return self._ds.__str__()
32 |
33 |
34 | class RadarDataset(util_data.Dataset):
35 | """
36 | Base class to use as a dataset.
37 | Sequence superclass is straightforwardly exchangeable with the appropriate Tensorflow/Pytorch superclass
38 | """
39 | def __init__(self, recordings, slices=None, normalize_db=True, clip=(-40, 0), add_channel_dim=-3,
40 | dim="batch_index", class_attr="activities", label_var="label", ignore_dims=None,
41 | norm_values=None, feature_dtype=np.single, label_dtype=np.longlong):
42 |
43 | self.label = label_var
44 | self.iter_dim = dim
45 | self.class_attr = class_attr
46 |
47 | self.ftype = feature_dtype
48 | self.ltype = label_dtype
49 |
50 | process_funcs = []
51 | if normalize_db:
52 | preprocess.proc_register(process_funcs, norm_db, axis=(-1, -2))
53 | if clip is not None:
54 | preprocess.proc_register(process_funcs, np.clip, clip[0], clip[1])
55 | if norm_values is not None: # Change the range of the data from `clip` to `norm_values`
56 | a = (norm_values[1] - norm_values[0])/(clip[1] - clip[0])
57 | b = norm_values[0] - a * clip[0]
58 | preprocess.proc_register(process_funcs, affine_transform, a, b)
59 | self._process_funcs = process_funcs
60 |
61 | if ignore_dims is True:
62 | self.ignore_dims = tuple(c for c in recordings[0].coords)
63 | else:
64 | self.ignore_dims = ignore_dims
65 | self._slices = slices
66 | self.channel_dim = add_channel_dim
67 | self._recordings = recordings
68 | self.attrs = dict_intersection([r.attrs for r in recordings])
69 | self.total_bytes = preprocess.get_size(recordings, loaded_only=False)
70 |
71 | self.dataset = self.Iterable(self)
72 |
73 | super(RadarDataset, self).__init__()
74 |
75 | def _access_item(self, index):
76 | if isinstance(index, int):
77 | item = preprocess.slice_recording(self._recordings, self._slices[index])
78 | elif isinstance(index, slice):
79 | item = preprocess.recs2dataset(self._recordings, self._slices[index], ignore_dims=self.ignore_dims)
80 | elif isinstance(index, np.ndarray):
81 | item = preprocess.recs2dataset(self._recordings, [self._slices[i] for i in index],
82 | ignore_dims=self.ignore_dims)
83 | else:
84 | raise IndexError(f"Only ints, np.int arrays and slices are accepted, got {type(index)}")
85 | return item
86 |
87 | def __getitem__(self, index):
88 | item = self._access_item(index)
89 | labels = self._get_labels(item).astype(self.ltype).values
90 | features_ds = preprocess.apply_processing(self._get_features(item), self._process_funcs)
91 | if self.channel_dim is not None:
92 | features_ds = features_ds.expand_dims(dim="channel", axis=self.channel_dim)
93 |
94 | features = [f.astype(self.ftype).values for f in features_ds.values()]
95 | return features, labels
96 |
97 | def plot(self, index, axes=None, **kwargs):
98 | features = self.dataset[index]
99 | spec_plot(features, axes=axes, **kwargs)
100 |
101 | def __len__(self):
102 | return len(self._slices)
103 |
104 | def _get_features(self, dataset):
105 | return dataset.drop(self.label)
106 |
107 | def _get_labels(self, dataset):
108 | return dataset[self.label]
109 |
110 | @property
111 | def feature_shapes(self):
112 | features, _ = self.__getitem__(0)
113 | length = (self.__len__(),)
114 | shapes = [length + f.shape for f in features]
115 | return shapes
116 |
117 | def scope(self, unit):
118 | item = self._access_item(0)
119 | return preprocess.get_scope(item, unit)
120 |
121 | @property
122 | def loaded_bytes(self):
123 | return preprocess.get_size(self._recordings, loaded_only=True)
124 |
125 | @property
126 | def class_num(self):
127 | return len(self.attrs[self.class_attr])
128 |
129 | @property
130 | def branches(self):
131 | return len(self.feature_shapes)
132 |
133 | class Iterable(Sequence):
134 | def __init__(self, radar_dataset):
135 | self._rds = radar_dataset
136 | super(RadarDataset.Iterable, self).__init__()
137 |
138 | def __len__(self):
139 | assert isinstance(self._rds, RadarDataset)
140 | return len(self._rds._slices)
141 |
142 | def __getitem__(self, index):
143 | assert isinstance(self._rds, RadarDataset)
144 | item = self._rds._access_item(index)
145 | ds = preprocess.apply_processing(self._rds._get_features(item), self._rds._process_funcs)
146 | return ds
147 |
148 |
149 | def load_rd(config, preprocessed_path, spec_length, stride,
150 | range_length=None, doppler_length=None, train_load=0.8, gpu=False, split=None):
151 | """
152 | Split, slice and load preprocessed data as range/Doppler spectrograms to be used with Pytorch/Tensorflow
153 | :param config: Radar configuration to load, e.g. "E"
154 | :param preprocessed_path: Path to the preprocessed data
155 | :param spec_length: Length of the spectrograms in bins
156 | :param stride: Spectrogram stride or hop length for the slicing
157 | :param range_length: If provided, data will be cropped on the range axis from 0 to range_length
158 | :param doppler_length: If provided, data will be cropped on the doppler axis from 0 to doppler_length
159 | :param train_load: Target train load to perform train-test split
160 | :param gpu: If True, load the xarray.Datasets. This is normally desirable if training in the GPU cluster.
161 | :param split: Type of split.
162 | It can be None (no split), "deterministic", "single", "double" (see train_test_slice) or a dictionary
163 | containing already split segments (useful for reproducibility)
164 | :return: RadarDataset objects with the data. Segments are also returned for reproducibility
165 | """
166 | recordings = preprocess.open_recordings(config, preprocessed_path, load=gpu,
167 | range_length=range_length, doppler_length=doppler_length)
168 |
169 | if split is None:
170 | slices = train_test_slice(recordings, spec_length, stride, train_load, split=split)
171 | rds = RadarDataset(recordings, slices=slices)
172 | return rds
173 | elif isinstance(split, dict):
174 | slices = train_test_slice(recordings, spec_length, stride, train_load,
175 | split=split, return_segments=False)
176 | rd_datasets = [RadarDataset(recordings, slices=s) for s in slices]
177 | return rd_datasets
178 | else:
179 | slices = train_test_slice(recordings, spec_length, stride, train_load,
180 | split=split, return_segments=True)
181 | segments = slices.pop(-1)
182 | rd_datasets = [RadarDataset(recordings, slices=s) for s in slices]
183 | return rd_datasets, segments
184 |
185 |
186 | def dict_intersection(dictionaries):
187 | intersection = set() | dictionaries[0].items()
188 | for d in dictionaries:
189 | intersection &= d.items()
190 | return dict(intersection)
191 |
--------------------------------------------------------------------------------
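A minimal usage sketch of `load_rd` above; the paths and slicing parameters are placeholders, and a `"single"` split is assumed to yield one training and one test dataset:

```python
import torch.utils.data as util_data
from utils.provider import load_rd

datasets, segments = load_rd("I", "/path/to/data/synthetic", spec_length=64, stride=8,
                             train_load=0.8, split="single")
train_ds, test_ds = datasets                       # assumed: one train and one test RadarDataset
train_loader = util_data.DataLoader(train_ds, batch_size=16, shuffle=True)

features, label = train_ds[0]                      # list of spectrogram arrays and its label
```
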
/utils/synthesize.py:
--------------------------------------------------------------------------------
1 | from utils import radar, skeletons as sk
2 | import numpy as np
3 | # import scipy.io
4 |
5 |
6 | def synthetic_radar(skeletons, config, frame_times=None):
7 | """
8 | :param skeletons: skeleton data as provided by the Daq module
9 | :param config: the configuration of the radar
10 | :param frame_times: the start times of all frames in the format [f1,f2,...]
11 | :return: Synthetic data in the format (nFrames, 1, nChirps, nSamples)
12 | """
13 |
14 | radarLoc = [p[0] for p in config["position"]]
15 | timestamps = skeletons.index
16 | startTime = timestamps[0].total_seconds()
17 | endTime = timestamps[-1].total_seconds()
18 |
19 | name = config["cfg"]
20 | f_low = config["LowerFrequency"] * 1e6
21 | f_up = config["UpperFrequency"] * 1e6
22 | numChirps = config["ChirpsPerFrame"]
23 | numSamples = config["SamplesPerChirp"]
24 | framePeriod = config["FramePeriod"] * 1e-6
25 | chirpPeriod = config["ChirpToChirpTime"] * 1e-9
26 | samplePeriod = 1e-3 / config["AdcSamplerate"]
27 |
28 | print("Create synthetic data for configuration {}".format(name))
29 | if frame_times is None:
30 | frame_times = np.arange(startTime, endTime, framePeriod)
31 | print("Creating {} frames of synthetic data with framePeriod {}".format(len(frame_times), framePeriod))
32 | else:
33 | print("Creating {} frames of synthetic data from provided frameTimes".format(len(frame_times)))
34 |
35 | C = 299792458 # m / s
36 |
37 | numFrames = len(frame_times)
38 |
39 | # array containing the start time of all chirps in the format [frames, chirps]
40 | chirpTimes = np.linspace(frame_times, frame_times + (numChirps - 1) * chirpPeriod, numChirps).T
41 |
42 | # Interpolation to increase temporal resolution of the skeleton data
43 | # returns segment position at each chirp start time
44 | sk_data = sk.interpolate(skeletons, chirpTimes.flatten())
45 |
46 | # if saveMatlab:
47 | # print("Save skeletons for matlab to location: {}".format(saveMatlab.absolute()))
48 | # scipy.io.savemat(saveMatlab, {'data': data,
49 | # 'timestamp': timestamps.to_numpy()})
50 |
51 | # Height of the person
52 | heightNose = np.mean([sk_keypoints[2, 0] for sk_keypoints in skeletons.Data])
53 | Height = heightNose + 0.16
54 | # body segments length (meter)
55 | headlen = 0.130 * Height
56 | shoulderlen = (0.259 / 2) * Height
57 | torsolen = 0.288 * Height
58 | hiplen = (0.191 / 2) * Height
59 | upperleglen = 0.245 * Height
60 | # lowerleglen = 0.246 * Height
61 | # footlen = 0.143 * Height
62 | upperarmlen = 0.188 * Height
63 | lowerarmlen = 0.152 * Height
64 | # Ht = upperleglen + lowerleglen
65 |
66 | # Get coordinates of the person
67 | head = (sk_data[:, :, 3] + sk_data[:, :, 4]) / 2
68 |
69 | neck = head.copy()
70 | neck[:, 2] -= 0.17
71 |
72 | base = (sk_data[:, :, 11] + sk_data[:, :, 12]) / 2
73 | base[:, 2] += 0.1
74 |
75 | lshoulder = sk_data[:, :, 5]
76 | lelbow = sk_data[:, :, 7]
77 | lhand = sk_data[:, :, 9]
78 |
79 | lhip = sk_data[:, :, 11]
80 | lknee = sk_data[:, :, 13]
81 | lankle = sk_data[:, :, 15]
82 | # ltoe = lankle.copy()
83 |
84 | rshoulder = sk_data[:, :, 6]
85 | relbow = sk_data[:, :, 8]
86 | rhand = sk_data[:, :, 10]
87 |
88 | rhip = sk_data[:, :, 12]
89 | rknee = sk_data[:, :, 14]
90 | rankle = sk_data[:, :, 16]
91 | # rtoe = rankle.copy()
92 |
93 | torso = (neck + base) / 2
94 | lupperarm = (lshoulder + lelbow) / 2
95 | rupperarm = (rshoulder + relbow) / 2
96 | lupperleg = (lhip + lknee) / 2
97 | rupperleg = (rhip + rknee) / 2
98 | llowerleg = (lankle + lknee) / 2
99 | rlowerleg = (rankle + rknee) / 2
100 |
101 | # Based on "A global human walking model with real-time kinematic personification",
102 | # by R. Boulic, N.M. Thalmann, and D. Thalmann % The Visual Computer, vol .6, pp .344 - 358, 1990
103 | segments = [
104 | sk.Segment(
105 | segmentPositions=head,
106 | ellipsoidParams=[0.1, 0.1, headlen / 2],
107 | aspect=head - neck
108 | ),
109 | sk.Segment(
110 | segmentPositions=torso,
111 | ellipsoidParams=[0.15, 0.15, torsolen / 2],
112 | aspect=neck - base
113 | ),
114 | sk.Segment(
115 | segmentPositions=lshoulder,
116 | ellipsoidParams=[0.06, 0.06, shoulderlen / 2],
117 | aspect=lshoulder - neck
118 | ),
119 | sk.Segment(
120 | segmentPositions=rshoulder,
121 | ellipsoidParams=[0.06, 0.06, shoulderlen / 2],
122 | aspect=rshoulder - neck
123 | ),
124 | sk.Segment(
125 | segmentPositions=lupperarm,
126 | ellipsoidParams=[0.06, 0.06, upperarmlen / 2],
127 | aspect=lshoulder - lelbow
128 | ),
129 | sk.Segment(
130 | segmentPositions=rupperarm,
131 | ellipsoidParams=[0.06, 0.06, upperarmlen / 2],
132 | aspect=rshoulder - relbow
133 | ),
134 | sk.Segment(
135 | segmentPositions=lhand,
136 | ellipsoidParams=[0.05, 0.05, lowerarmlen / 2],
137 | aspect=lelbow - lhand
138 | ),
139 | sk.Segment(
140 | segmentPositions=rhand,
141 | ellipsoidParams=[0.05, 0.05, lowerarmlen / 2],
142 | aspect=relbow - rhand
143 | ),
144 | sk.Segment(
145 | segmentPositions=lhip,
146 | ellipsoidParams=[0.07, 0.07, hiplen / 2],
147 | aspect=lhip - base
148 | ),
149 | sk.Segment(
150 | segmentPositions=rhip,
151 | ellipsoidParams=[0.07, 0.07, hiplen / 2],
152 | aspect=rhip - base
153 | ),
154 | sk.Segment(
155 | segmentPositions=lupperleg,
156 | ellipsoidParams=[0.07, 0.07, upperleglen / 2],
157 | aspect=lknee - lhip
158 | ),
159 | sk.Segment(
160 | segmentPositions=rupperleg,
161 | ellipsoidParams=[0.07, 0.07, upperleglen / 2],
162 | aspect=rknee - rhip
163 | ),
164 | sk.Segment(
165 | segmentPositions=llowerleg,
166 | ellipsoidParams=[0.06, 0.06, upperleglen / 2],
167 | aspect=lankle - lknee
168 | ),
169 | sk.Segment(
170 | segmentPositions=rlowerleg,
171 | ellipsoidParams=[0.06, 0.06, upperleglen / 2],
172 | aspect=rankle - rknee
173 | )
174 | ]
175 |
176 | # define timestamps of a single chirp
177 | t = np.linspace(0, (numSamples - 1) * samplePeriod, numSamples)
178 |
179 | bw = f_up - f_low
180 | alpha = bw / t[-1]
181 | f_c = f_low + (f_up - f_low) / 2
182 | s = np.zeros((numSamples, np.size(chirpTimes)))
183 |
184 | dr, r_max = radar.range_axis(bw, numSamples)
185 |
186 | for segment in segments:
187 | r_total = segment.calculateRange(radarLoc)
188 | phi, theta = segment.calculateAngles(radarLoc)
189 | rcs = segment.calculateRCS(phi, theta)
190 | amp = np.sqrt(rcs)
191 |
192 | # Formally it should not be squared, but squaring resembles the effect of the low pass filter
193 | fspl = (4 * np.pi * r_total * f_c / C) ** 2
194 |
195 | # calculates the signal for all chirps c and samples s in the format [[chirp1s1, chirp2s1, ..., chirpns1]
196 | # [chirp1s2, chirp2s2, ..., chripns2 ]
197 | # [ ... ]
198 | # [chrip1sm, chrip2sm, ..., chripnsm]]
199 | s_segment = (amp / fspl) * np.cos(2 * np.pi * (2 * f_low * r_total / C +
200 | 2 * alpha * np.outer(t, r_total) / C -
201 | 2 * alpha * r_total.reshape(1, -1) ** 2 / C ** 2))
202 |
203 | # print("amp shape: {}".format(amp.shape))
204 | # print("np.outer shape: {}".format(np.outer(t, r_total).shape))
205 | # print("r_total.reshape: {}".format(r_total.reshape(1,-1).shape))
206 |
207 | # set whole chirp to zero if segment is out of sight
208 | s_segment[:, r_total > r_max] = 0
209 | s += s_segment
210 |
211 | s = s.T.reshape(numFrames, numChirps, numSamples)
212 | s = np.expand_dims(s, axis=1)
213 |
214 | print("Synthetic data with shape {} successfully created".format(s.shape))
215 | return s
216 |
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import pandas as pd
4 |
5 | import matplotlib.pyplot as plt
6 | plt.rcParams['svg.fonttype'] = 'none'
7 | import matplotlib.ticker as mticker
8 | import matplotlib.animation as manimation
9 |
10 | from mpl_toolkits import mplot3d # noqa: F401 unused import
11 |
12 | from utils import skeletons as sk
13 |
14 | # Put all visualizations here
15 |
16 |
17 | def spec_plot(ds, axes=None, ax_xlabel=None, cbar_global=None, **kwargs):
18 | """
19 | Plot radar spectrograms using xarray functionality
20 | :param ds: xarray.Dataset with spectrograms to plot
21 | :param axes: Matplotlib Axes to plot onto. If not given, a new Matplotlib Figure will be created
22 | :param ax_xlabel: Matplotlib Axes where the x label should be placed. Defaults to the last Axes
23 | :param cbar_global: Title of a global colorbar
24 | :param kwargs: Keyword arguments for xarray.DataArray.plot
25 | :return: None
26 | """
27 | ds_time = ds["time"]
28 | ds['time'] = tdelta2secs(ds_time)
29 | ds.time.attrs = {"long_name": "time offset", "units": "min:sec"}
30 |
31 | title = create_title(ds)
32 |
33 | num_features = len(ds)
34 |
35 | if axes is None:
36 | fig, axes = plt.subplots(num_features, 1)
37 | else:
38 | fig = axes[-1].figure
39 |
40 | if cbar_global:
41 | kwargs["add_colorbar"] = False
42 |
43 | if ax_xlabel is None:
44 | ax_xlabel = [axes[-1]]
45 | for ax, feat in zip(axes, ds.values()):
46 | plt.sca(ax)
47 | cbar_kwargs = {"label": f"Spectrogram [{feat.units}]"}
48 | try:
49 | if not kwargs["add_colorbar"]:
50 | cbar_kwargs = {}
51 | except KeyError:
52 | pass
53 | feat.plot.imshow(x="time", cbar_kwargs=cbar_kwargs, **kwargs)
54 | ax.xaxis.set_major_formatter(mticker.FuncFormatter(format_seconds))
55 | if ax not in ax_xlabel:
56 | ax.set_xlabel(None)
57 |
58 | fig.suptitle(title, wrap=True)
59 |
60 | if cbar_global:
61 | im = axes[-1].get_images()[0]
62 | fig.subplots_adjust(right=0.8)
63 | cbar_ax = fig.add_axes([0.85, 0.15, 0.03, 0.7])
64 | cbar = fig.colorbar(im, cax=cbar_ax)
65 | cbar.set_label(cbar_global)
66 |
67 | ds['time'] = ds_time
68 |
69 |
70 | def format_seconds(x, _pos):
71 | timestamp = "{:02d}:{:02d}".format(int(x // 60), int(x % 60))
72 | rest = x % 1.0
73 | if rest != 0.0:
74 | return ''
75 | else:
76 | return timestamp
77 |
78 |
79 | def animation_real_synthetic(real_rd, synthetic_rd, skeletons=None,
80 | sensor_loc=None, notebook=False, save_path=None, max_z=2.5, **kwargs):
81 | """
82 | Compare real and synthetic data with an animated plot
83 | :param real_rd: Real RDM sequence (in dB) as xarray.DataArray
84 | :param synthetic_rd: Synthetic RDM sequence (in dB) as xarray.DataArray.
85 | The time dimension must be consistent across all data arrays
86 | :param skeletons: Skeleton data as xarray.DataArray (optional)
87 | :param sensor_loc: Radar Sensor position as given by daq.env
88 | :param notebook: if True, return `anim_obj` to be used in a Jupyter notebook
89 | :param save_path: if given, save the video under this path
90 | :param max_z: Maximum height in meters
91 | :param kwargs: keyword arguments for RDM plotting.
92 | :return: `anim_obj` if `notebook=True`
93 | """
94 | fig = plt.figure(figsize=(10, 5))
95 |
96 | sk_lines = False
97 | if skeletons is None:
98 | subplot_real = 121
99 | subplot_synth = 122
100 | else:
101 | subplot_real = 222
102 | subplot_synth = 224
103 |
104 | ax_sk = fig.add_subplot(121, projection='3d')
105 | room_size = skeletons.room_size + [max_z]
106 | sk_edges = sk.get_edges(skeletons)
107 | sk_lines = plt_skeleton(sk_edges[0, ], room_size=room_size, axis=ax_sk, sensor_loc=sensor_loc)
108 |
109 | if "vmin" not in kwargs:
110 | kwargs["vmin"] = min(real_rd.min(), synthetic_rd.min())
111 | if "vmax" not in kwargs:
112 | kwargs["vmax"] = max(real_rd.max(), synthetic_rd.max())
113 | ax_real = fig.add_subplot(subplot_real)
114 | im_real = real_rd.isel(time=0).plot.imshow(add_colorbar=False, **kwargs)
115 | if skeletons is not None:
116 | ax_real.set_xticklabels([])
117 | ax_real.set_xlabel(None)
118 | plt.title("Real Data")
119 | ax_synth = fig.add_subplot(subplot_synth)
120 | im_synth = synthetic_rd.isel(time=0).plot.imshow(add_colorbar=False, **kwargs)
121 | plt.title("Synthetic Data")
122 |
123 | ts_format = "%M:%S.%f"
124 | timestamps = pd.to_timedelta(real_rd.time.values) + pd.Timestamp(0)
125 | frame_period_ms = real_rd.FramePeriod // 1000
126 | num_frames = len(timestamps)
127 |
128 | title = create_title(real_rd)
129 |
130 | def suptitle(timestamp):
131 | return f"{title}+{timestamp.strftime(ts_format)[:-3]}"
132 |
133 | fig.suptitle(suptitle(timestamps[0]), wrap=True)
134 |
135 | cbar_orient = 'horizontal' if skeletons is None else 'vertical'
136 | cbar = fig.colorbar(im_synth, ax=(ax_real, ax_synth), orientation=cbar_orient)
137 | cbar.set_label('Amplitude [dB]')
138 |
139 | # animation function. This is called sequentially
140 | def animate(i):
141 | im_real.set_array(real_rd.isel(time=i))
142 | im_synth.set_array(synthetic_rd.isel(time=i))
143 | if sk_lines:
144 | update_skeleton(sk_lines, sk_edges[i, ])
145 | fig.suptitle(suptitle(timestamps[i]))
146 | return [im_real, im_synth]
147 |
148 | anim_obj = manimation.FuncAnimation(
149 | fig,
150 | # The function that does the updating of the Figure
151 | animate,
152 | frames=num_frames,
153 | # Frame-time in ms
154 | interval=frame_period_ms,
155 | blit=True
156 | )
157 | if save_path is not None:
158 | anim_obj.save(save_path)
159 | elif notebook:
160 | return anim_obj
161 | else:
162 | plt.show()
163 |
164 |
165 | def animation(real_rd, skeletons=None,
166 | sensor_loc=None, notebook=False, save_path=None, max_z=2.5, **kwargs):
167 | """
168 | Animate a Range Doppler Map sequence, optionally alongside the skeleton reconstruction
169 | :param real_rd: Real RDM sequence (in dB) as xarray.DataArray
170 | :param skeletons: Skeleton data as xarray.DataArray (optional)
171 | :param sensor_loc: Radar Sensor position as given by daq.env
172 | :param notebook: if True, return `anim_obj` to be used in a Jupyter notebook
173 | :param save_path: if given, save the video under this path
174 | :param max_z: Maximum height in meters
175 | :param kwargs: keyword arguments for RDM plotting.
176 | :return: `anim_obj` if `notebook=True`
177 | """
178 | fig = plt.figure(figsize=(11, 5))
179 |
180 | sk_lines = False
181 | if skeletons is None:
182 | subplot_real = 111
183 | else:
184 | subplot_real = 122
185 |
186 | ax_sk = fig.add_subplot(121, projection='3d')
187 | room_size = skeletons.room_size + [max_z]
188 | sk_edges = sk.get_edges(skeletons)
189 | sk_lines = plt_skeleton(sk_edges[0, ], room_size=room_size, axis=ax_sk, sensor_loc=sensor_loc)
190 |
191 | if "vmin" not in kwargs:
192 | kwargs["vmin"] = real_rd.min()
193 | if "vmax" not in kwargs:
194 | kwargs["vmax"] = real_rd.max()
195 | ax_real = fig.add_subplot(subplot_real)
196 | im_real = real_rd.isel(time=0).plot.imshow(add_colorbar=False, **kwargs)
197 | plt.title("Range Doppler Map")
198 |
199 | ts_format = "%M:%S.%f"
200 | timestamps = pd.to_timedelta(real_rd.time.values) + pd.Timestamp(0)
201 | frame_period_ms = real_rd.FramePeriod // 1000
202 | num_frames = len(timestamps)
203 |
204 | title = create_title(real_rd)
205 |
206 | def suptitle(timestamp):
207 | return f"{title}+{timestamp.strftime(ts_format)[:-3]}"
208 |
209 | fig.suptitle(suptitle(timestamps[0]), wrap=True)
210 |
211 | cbar = fig.colorbar(im_real, ax=ax_real, orientation='vertical')
212 | cbar.set_label('Amplitude [dB]')
213 |
214 | # animation function. This is called sequentially
215 | def animate(i):
216 | im_real.set_array(real_rd.isel(time=i))
217 | if sk_lines:
218 | update_skeleton(sk_lines, sk_edges[i, ])
219 | fig.suptitle(suptitle(timestamps[i]))
220 | return [im_real]
221 |
222 | anim_obj = manimation.FuncAnimation(
223 | fig,
224 | # The function that does the updating of the Figure
225 | animate,
226 | frames=num_frames,
227 | # Frame-time in ms
228 | interval=frame_period_ms,
229 | blit=True
230 | )
231 | if save_path is not None:
232 | anim_obj.save(save_path)
233 | elif notebook:
234 | return anim_obj
235 | else:
236 | plt.show()
237 |
238 |
239 | def animation_spec(rdm, spectrograms, gap=False, marker_color='white', marker_style="d",
240 | notebook=False, save_path=None, **kwargs):
241 | """
242 | Animation to show how spectrograms are extracted from the RDM sequence
243 | :param rdm: Range Doppler Map sequence as xarray.DataArray
244 | :param spectrograms: Range and Doppler spectrograms as xarray.Dataset, extracted from `rdm`
245 | :param gap: If True, leave a gap between RDM and spectrograms
246 | :param marker_color: Color of the time marker on the spectrograms
247 | :param marker_style: Style of the marker tips on the spectrograms
248 | :param notebook: if True, return `anim_obj` to be used in a Jupyter notebook
249 | :param save_path: if given, save the video under this path
250 | :param kwargs: keyword arguments for RDM plotting.
251 | :return: `anim_obj` if `notebook=True`
252 | """
253 |
254 | if gap:
255 | fig = plt.figure(figsize=(16, 5))
256 | subplot_rdm = 131
257 | subplot_range = 233
258 | subplot_doppler = 236
259 | else:
260 | fig = plt.figure(figsize=(10, 5))
261 | subplot_rdm = 121
262 | subplot_range = 222
263 | subplot_doppler = 224
264 |
265 | ts_format = "%M:%S.%f"
266 | timestamps = pd.to_timedelta(rdm.time.values) + pd.Timestamp(0)
267 | frame_period_ms = rdm.FramePeriod // 1000
268 | num_frames = len(timestamps)
269 |
270 | time_spect = tdelta2secs(spectrograms.time)
271 | rng = spectrograms.range.values
272 | dopp = spectrograms.doppler.values
273 |
274 | def set_rdm_title(timestamp):
275 | return f"{rdm_title}at time offset {timestamp.strftime(ts_format)[:-3]}"
276 |
277 | if "vmin" not in kwargs:
278 | kwargs["vmin"] = min(rdm.min(), min(s.min() for s in spectrograms.values()))
279 | if "vmax" not in kwargs:
280 | kwargs["vmax"] = max(rdm.max(), max(s.max() for s in spectrograms.values()))
281 | ax_rdm = fig.add_subplot(subplot_rdm)
282 | im_rdm = rdm.isel(time=0).plot.imshow(add_colorbar=False, **kwargs)
283 |
284 | rdm_title = "Range Doppler Map\n" if gap else "RDM "
285 | plt.title(set_rdm_title(timestamps[0]))
286 |
287 | ax_range = fig.add_subplot(subplot_range)
288 | ax_doppler = fig.add_subplot(subplot_doppler)
289 |
290 | spec_plot(spectrograms, axes=[ax_range, ax_doppler], add_colorbar=False, **kwargs)
291 | ax_range.set_title("Radar Spectrograms")
292 | fig.align_ylabels([ax_range, ax_doppler])
293 | plt.sca(ax_range)
294 | r_marker, = plt.plot([time_spect[0], time_spect[0]], [rng[0], rng[-1]], c=marker_color, marker=marker_style)
295 | d_marker, = ax_doppler.plot([time_spect[0], time_spect[0]],
296 | [dopp[0], dopp[-1]], c=marker_color, marker=marker_style)
297 |
298 | fig.suptitle(create_title(rdm), wrap=True)
299 |
300 | if gap:
301 | cbar_rdm = fig.colorbar(im_rdm, ax=ax_rdm, orientation='vertical')
302 | cbar_rdm.set_label('Amplitude [dB]')
303 | cbar_spect = fig.colorbar(ax_range.get_images()[0], ax=(ax_range, ax_doppler), orientation='vertical')
304 | cbar_spect.set_label('Amplitude [dB]')
305 |
306 | # animation function. This is called sequentially
307 | def animate(i):
308 | im_rdm.set_array(rdm.isel(time=i))
309 | ax_rdm.set_title(set_rdm_title(timestamps[i]))
310 | t_marker = [time_spect[i]] * 2
311 | r_marker.set_xdata(t_marker)
312 | d_marker.set_xdata(t_marker)
313 | return [im_rdm, r_marker, d_marker]
314 |
315 | anim_obj = manimation.FuncAnimation(
316 | fig,
317 | # The function that does the updating of the Figure
318 | animate,
319 | frames=num_frames,
320 | # Frame-time in ms
321 | interval=frame_period_ms,
322 | blit=True
323 | )
324 | if save_path is not None:
325 | anim_obj.save(save_path)
326 | elif notebook:
327 | return anim_obj
328 | else:
329 | plt.show()
330 |
331 |
332 | def plt_skeleton(sk_edges, room_size=None, axis=None, sensor_loc=None):
333 | if axis is None:
334 | fig = plt.figure()
335 | axis = fig.add_subplot(projection='3d')
336 | sk_lines = []
337 | for edge in sk_edges:
338 | line, = axis.plot(*edge, color='black', marker='o')
339 | sk_lines.append(line)
340 | if sensor_loc is not None:
341 | sensor_x, sensor_y, sensor_z = (sl[0] for sl in sensor_loc)
342 | axis.scatter(sensor_x, sensor_y, sensor_z, color="red", marker="*")
343 | axis.text(sensor_x, sensor_y, sensor_z, "Radar")
344 |
345 | axis.set_xlabel('x')
346 | axis.set_ylabel('y')
347 | axis.set_zlabel('z')
348 | if room_size is not None:
349 | axis.set_xlim(0, room_size[0])
350 | axis.set_ylim(0, room_size[1])
351 | axis.set_zlim(0, room_size[2])
352 |
353 | plt.title("Room reconstruction")
354 |
355 | return sk_lines
356 |
357 |
358 | def update_skeleton(sk_lines, sk_edges):
359 | for line, edge in zip(sk_lines, sk_edges):
360 | line.set_xdata(edge[0])
361 | line.set_ydata(edge[1])
362 | line.set_3d_properties(edge[2])
363 |
364 |
365 | def create_title(ds):
366 | act = ds.attrs['activity']
367 | cfg = ds.attrs['cfg']
368 | date = pd.to_datetime(ds.date).strftime("on %d-%m-%Y at %H:%M")
369 | title = f"{act} from confg. {cfg} {date}"
370 | return title
371 |
372 |
373 | def tdelta2secs(time_delta):
374 | return time_delta / np.timedelta64(1, 's')
375 |
--------------------------------------------------------------------------------
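A minimal usage sketch for the plotting helpers above, not a verbatim excerpt from the repository. It assumes `spects` is a range/Doppler spectrogram `xarray.Dataset` and `rdm_db` the matching RDM sequence in dB, both carrying the attributes set by `utils.preprocess` (`activity`, `cfg`, `date`, `FramePeriod`) and per-variable `units`; the output file names are hypothetical.

    import matplotlib.pyplot as plt
    from utils.visualization import spec_plot, animation_spec

    # Static figure: one subplot per spectrogram variable, one shared colorbar
    fig, axes = plt.subplots(2, 1, figsize=(8, 6))
    spec_plot(spects, axes=axes, cbar_global="Amplitude [dB]", vmin=-40, vmax=0)
    fig.savefig("spectrograms.png")

    # Animate how the spectrograms are extracted from the RDM sequence and save the video
    animation_spec(rdm_db, spects, vmin=-40, save_path="rdm_spectrograms.mp4")
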
/utils/slicer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.random import default_rng
3 |
4 | from utils.preprocess import recs2dataset
5 |
6 | rng = default_rng()
7 |
8 |
9 | def verbose_split(train_idx, test_idx):
10 | i = 0
11 | for tr, ts in zip(train_idx, test_idx):
12 | i += 1
13 |
14 | partial_load = sum(len(t) for t in tr)
15 | rec_len = partial_load + sum(len(t) for t in ts)
16 | print(f"Recording {i}:")
17 | print(f"\tPartial load:\t{partial_load / rec_len:02},\t{partial_load}/{rec_len} frames")
18 |
19 | total_load = sum(sum(len(t) for t in tr) for tr in train_idx)
20 | total_length = total_load + sum(sum(len(t) for t in ts) for ts in test_idx)
21 | print(f"Obtained load:\t{total_load / total_length:02},\t{total_load}/{total_length} frames")
22 |
23 |
24 | def train_split_no_cut(recording_lengths, train_load, labels, verbose=True):
25 | """
26 | Perform a train-test split of a list of recordings without cutting any of them, only assigning whole
27 | recordings to either set in a label-balanced way
28 | Args:
29 | recording_lengths: List of int, each one of them represents the length of a recording file
30 | train_load: float, Desired train load in [0.5, 1)
31 | labels: List of int representing the label assigned to each recording to take into account for splitting
32 | verbose: Verbose output to verify the split
33 |
34 | Returns:
35 | train_idx, test_idx
36 | The indices for the train and test sets as two nested lists of range objects
37 |
38 | """
39 |
40 | if not (0.5 <= train_load < 1):
41 | raise ValueError(f"Invalid value for train_load={train_load}.\n"
42 | f"It must lie within [0.5 , 1).")
43 |
44 | label_lengths = {}
45 | for rec_len, lbl in zip(recording_lengths, labels):
46 | try:
47 | label_lengths[lbl].append(rec_len)
48 | except KeyError:
49 | label_lengths[lbl] = [rec_len]
50 | total_loads = {k: round(sum(v) * train_load) for k, v in label_lengths.items()}
51 |
52 | train_idx = []
53 | test_idx = []
54 | acc_lengths = {k: 0 for k in total_loads}
55 | for r_len, lbl in zip(recording_lengths, labels):
56 | if acc_lengths[lbl] < total_loads[lbl]:
57 | train_idx.append([range(0, r_len)])
58 | test_idx.append([range(0, 0)])
59 | else:
60 | train_idx.append([range(0, 0)])
61 | test_idx.append([range(0, r_len)])
62 | acc_lengths[lbl] += r_len
63 |
64 | if verbose:
65 | print(f"Required load:\t{train_load}")
66 | verbose_split(train_idx, test_idx)
67 |
68 | return train_idx, test_idx
69 |
70 |
71 | def train_split_singlecut(recording_lengths, train_load, beta=10.0, verbose=True):
72 | """
73 | Perform a train-test split of a list of recordings through a single random cut in every recording
74 | Args:
75 | recording_lengths: List of int, each one of them represents the length of a recording file
76 | train_load: float, Desired train load in [0.5, 1)
77 | beta: Parameter of the beta distribution. The larger the value, the smaller the standard deviation of the random cuts.
78 | If None, every recording is divided deterministically according to train_load
79 | verbose: Verbose output to verify the split
80 |
81 | Returns:
82 | train_idx, test_idx
83 | The indices for the train and test sets as two nested lists of range objects
84 |
85 | """
86 |
87 | if not (0.5 <= train_load < 1):
88 | raise ValueError(f"Invalid value for train_load={train_load}.\n"
89 | f"It must lie within [0.5 , 1).")
90 |
91 | if beta is None:
92 | deterministic = True
93 | pass
94 | elif beta < 1:
95 | raise ValueError(f"Invalid value for beta={beta}.\n"
96 | f"It must be equal or greater than one to enforce unimodality.")
97 | else:
98 | deterministic = False
99 |
100 | rec_lengths = np.array(recording_lengths)
101 | n_recs = rec_lengths.size
102 | total_length = rec_lengths.sum()
103 |
104 | # Use the longest recording to even out the obtained load
105 | i_max = rec_lengths.argmax()
106 | max_length = rec_lengths[i_max]
107 | max_out_idx = np.arange(n_recs) != i_max
108 |
109 | total_load = np.round(total_length * train_load)
110 |
111 | max_out_loads = None
112 | max_load = -1
113 | if deterministic:
114 | max_out_loads = np.round(train_load * rec_lengths[max_out_idx])
115 | max_load = total_load - max_out_loads.sum()
116 | else: # Use a beta distribution to generate random loads with rejection
117 | while not 0 < max_load < max_length:
118 | alpha = beta * train_load / (1 - train_load)
119 | beta_loads = rng.beta(alpha, beta, size=n_recs - 1)
120 | max_out_loads = np.round(beta_loads * rec_lengths[max_out_idx])
121 | max_load = total_load - max_out_loads.sum()
122 |
123 | partial_loads = np.zeros(n_recs, dtype=np.int32)
124 | partial_loads[max_out_idx] = max_out_loads
125 | partial_loads[i_max] = max_load
126 |
127 | head_tail = rng.binomial(1, 0.5, size=n_recs) # Choose randomly the head or the tail of the chunk
128 |
129 | train_idx = []
130 | test_idx = []
131 |
132 | for head_train, partial_load, rec_len in zip(head_tail, partial_loads, rec_lengths):
133 | if head_train:
134 | train_idx.append([range(0, partial_load)])
135 | test_idx.append([range(partial_load, rec_len)])
136 | else:
137 | train_idx.append([range(rec_len - partial_load, rec_len)])
138 | test_idx.append([range(0, rec_len - partial_load)])
139 |
140 | if verbose:
141 | print(f"Longest recording is {i_max + 1}, {rec_lengths[i_max]} frames")
142 | print(f"Required load:\t{train_load}")
143 | verbose_split(train_idx, test_idx)
144 |
145 | return train_idx, test_idx
146 |
147 |
148 | def train_split_doublecut(recording_lengths, train_load, min_len=1, stride=1, verbose=True):
149 | """
150 | Perform a train-test split of a list of recordings through a double cut of fixed length
151 | at a random position in every recording
152 | Args:
153 | recording_lengths: List of int, each one of them represents the length of a recording file
154 | train_load: float, Desired train load in [0.5, 1)
155 | min_len: int, Minimum length of a chunk, important if the data will be further sliced
156 | stride: int, Stride to consider when choosing the random position of the cuts
157 | verbose: Verbose output to verify the split
158 |
159 | Returns:
160 | train_idx, test_idx
161 | The indices for the train and test sets as two nested lists of range objects
162 |
163 | """
164 |
165 | if not (0.5 <= train_load < 1):
166 | raise ValueError(f"Invalid value for train_load={train_load}.\n"
167 | f"It must lie within [0.5 , 1).")
168 |
169 | rec_lengths = np.array(recording_lengths)
170 | n_recs = rec_lengths.size
171 | total_length = rec_lengths.sum()
172 |
173 | # Use the longest recording to even out the obtained load
174 | i_max = rec_lengths.argmax()
175 | max_out_idx = np.arange(n_recs) != i_max
176 |
177 | total_load = np.round(total_length * train_load)
178 |
179 | max_out_loads = np.round(train_load * rec_lengths[max_out_idx])
180 | max_load = total_load - max_out_loads.sum()
181 |
182 | partial_loads = np.zeros(n_recs, dtype=np.int32)
183 | partial_loads[max_out_idx] = max_out_loads
184 | partial_loads[i_max] = max_load
185 |
186 | train_idx = []
187 | test_idx = []
188 |
189 | for partial_load, rec_len in zip(partial_loads, rec_lengths):
190 |
191 | trimmed_length = rec_len - partial_load
192 |
193 | offset_choices = [0, trimmed_length] # head and tail offsets
194 | offset_choices.extend(range(min_len, trimmed_length - min_len, stride)) # centered train chunk
195 | offset_choices.extend(range(trimmed_length + min_len, rec_len - min_len, stride)) # centered test chunk
196 |
197 | offset = rng.choice(offset_choices)
198 | second_cut = (offset + partial_load) % rec_len
199 |
200 | if offset == 0: # head train chunk
201 | train_idx.append([range(offset, partial_load)])
202 | test_idx.append([range(partial_load, rec_len)])
203 | elif offset == trimmed_length: # tail train chunk
204 | train_idx.append([range(offset, rec_len)])
205 | test_idx.append([range(0, offset)])
206 | elif offset < second_cut: # centered train chunk
207 | train_idx.append([range(offset, second_cut)])
208 | test_idx.append([range(0, offset),
209 | range(second_cut, rec_len)])
210 | else: # centered test chunk
211 | train_idx.append([range(0, second_cut),
212 | range(offset, rec_len)])
213 | test_idx.append([range(second_cut, offset)])
214 |
215 | if verbose:
216 | print(f"Longest recording is {i_max + 1}, {rec_lengths[i_max]} frames")
217 | print(f"Required load:\t{train_load}")
218 | verbose_split(train_idx, test_idx)
219 |
220 | return train_idx, test_idx
221 |
222 |
223 | def serialize_ranges(ranges):
224 | if isinstance(ranges, range):
225 | return {"start": ranges.start, "stop": ranges.stop}
226 | elif isinstance(ranges, (tuple, list)):
227 | return [serialize_ranges(r) for r in ranges]
228 | else:
229 | raise TypeError(f"Unexpected class {type(ranges)}")
230 |
231 |
232 | def deserialize_ranges(ranges):
233 | if isinstance(ranges, dict):
234 | return range(ranges["start"], ranges["stop"])
235 | elif isinstance(ranges, (tuple, list)):
236 | return [deserialize_ranges(r) for r in ranges]
237 | else:
238 | raise TypeError(f"Unexpected class {type(ranges)}")
239 |
240 |
241 | def index_ranges(ranges, recordings):
242 | lut_date = {r.date: i for i, r in enumerate(recordings)}
243 | new_ranges = []
244 | for date, r in ranges.items():
245 | new_ranges.append((lut_date[date], deserialize_ranges(r)))
246 | return new_ranges
247 |
248 |
249 | def train_test_slice(recordings, time_length, time_hop, train_load, split="single", beta=20.0,
250 | verbose=True, merge=False, return_segments=False):
251 | """Slice recordings into train and test datasets
252 |
253 | Args:
254 | recordings: Sequence of Dataset objects holding the preprocessed recordings
255 | time_length: Length of the data examples along the time axis
256 | time_hop: Hop size between the slices across the time axis
257 | train_load: float, Desired train load in [0.5, 1)
258 | return_segments: If True, also return the serialized split segments (reusable via the dict form of 'split')
259 | verbose: Verbose output to verify the split
260 | merge: If True, merge slices into a new lazy dataset, otherwise return a flat indexed slice list
261 | split: Split variant, either "deterministic", "no-cut", "single" or "double", a dict of serialized segments to reuse a previous split, or None for no train-test split
262 | beta: float, beta value for the random single split
263 |
264 | Returns: train, test[, segments]
265 | Train and test Dataset objects or slice lists, depending on 'merge'; if 'return_segments' is True, the serialized train/test segments are appended
266 |
267 | """
268 |
269 | rec_lens = [r.sizes["time"] for r in recordings]
270 | rec_labels = [r.label.item() for r in recordings]
271 |
272 | def slice_indices(indices):
273 | sliced_indices = []
274 | for i, rec_segments in enumerate(indices):
275 | for range_seg in rec_segments:
276 | for offset in range(0, len(range_seg) - time_length + 1, time_hop):
277 | sliced_indices.append((i, range_seg[offset:offset+time_length]))
278 | return sliced_indices
279 |
280 | train_splits = {"deterministic": lambda rcl, tld: train_split_singlecut(rcl, tld, beta=None, verbose=verbose),
281 | "no-cut": lambda rcl, tld: train_split_no_cut(rcl, tld, labels=rec_labels, verbose=verbose),
282 | "single": lambda rcl, tld: train_split_singlecut(rcl, tld, beta=beta, verbose=verbose),
283 | "double": lambda rcl, tld: train_split_doublecut(rcl, tld, min_len=time_length,
284 | stride=time_hop, verbose=verbose)}
285 |
286 | if split is None:
287 | full_idx = [[range(rl)] for rl in rec_lens]
288 | full_slices = slice_indices(full_idx)
289 | if merge:
290 | return recs2dataset(recordings, full_slices)
291 | else:
292 | return full_slices
293 | else:
294 | if isinstance(split, dict):
295 | return_segments = False
296 | rec_dates = [r.date for r in recordings]
297 | train_split = split['train']
298 | test_split = split['test']
299 |
300 | train_idx = [deserialize_ranges(train_split[d]) for d in rec_dates]
301 | test_idx = [deserialize_ranges(test_split[d]) for d in rec_dates]
302 | else:
303 | try:
304 | train_idx, test_idx = train_splits[split](rec_lens, train_load)
305 | except KeyError as ke:
306 | ts_keys = tuple(train_splits.keys())
307 | raise ValueError(f"Unrecognized split argument '{split}'."
308 | f" Argument must belong to {ts_keys}") from ke
309 |
310 | train_sliced = slice_indices(train_idx)
311 | test_sliced = slice_indices(test_idx)
312 |
313 | if merge:
314 | train_ds = recs2dataset(recordings, train_sliced)
315 | test_ds = recs2dataset(recordings, test_sliced)
316 | ret = [train_ds, test_ds]
317 | else:
318 | ret = [train_sliced, test_sliced]
319 |
320 | if return_segments:
321 | serial_segments = {"train": {r.date: s for r, s in zip(recordings, serialize_ranges(train_idx))},
322 | "test": {r.date: s for r, s in zip(recordings, serialize_ranges(test_idx))}}
323 | ret.append(serial_segments)
324 |
325 | return ret
326 |
--------------------------------------------------------------------------------
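To make the splitting helpers above concrete, a small illustrative sketch (recording lengths and labels are invented). `train_split_no_cut` assigns whole recordings to either set, and the nested `range` lists survive the serialization round trip used to persist splits:

    from utils.slicer import train_split_no_cut, serialize_ranges, deserialize_ranges

    lengths = [300, 280, 320, 310]     # frames per recording (illustrative)
    labels = [0, 0, 1, 1]              # one activity label per recording
    train_idx, test_idx = train_split_no_cut(lengths, train_load=0.75, labels=labels)

    # Ranges serialize to JSON-friendly dicts and back without loss
    blob = serialize_ranges(train_idx)
    assert deserialize_ranges(blob) == train_idx
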
/visualize/notebook.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "581bcae7-0e04-478c-8d92-ea4f253b86bc",
6 | "metadata": {
7 | "pycharm": {
8 | "name": "#%% md\n"
9 | },
10 | "tags": []
11 | },
12 | "source": [
13 | "# RACPIT visualization"
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "id": "b2f5d92a-8f06-458e-9f98-f362d750e100",
19 | "metadata": {
20 | "pycharm": {
21 | "name": "#%% md\n"
22 | },
23 | "tags": []
24 | },
25 | "source": [
26 | " This notebook demonstrates our simulation, preprocessing and visualization pipeline for radar data.\n",
27 | "See it rendered\n",
28 | "[here](https://fraunhoferhhi.github.io/racpit/visualize)."
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "id": "fe5aa707",
35 | "metadata": {
36 | "collapsed": false,
37 | "jupyter": {
38 | "outputs_hidden": false
39 | },
40 | "pycharm": {
41 | "name": "#%%\n"
42 | }
43 | },
44 | "outputs": [],
45 | "source": [
46 | "import os\n",
47 | "os.chdir(\"..\")"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "id": "cca2b70b",
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "from pathlib import Path\n",
58 | "import pandas as pd\n",
59 | "\n",
60 | "import numpy as np\n",
61 | "import xarray as xr\n",
62 | "\n",
63 | "from utils.synthesize import synthetic_radar\n",
64 | "from utils.preprocess import open_recordings, identify_config, raw2rdm\n",
65 | "\n",
66 | "from utils import radar\n",
67 | "from utils import skeletons as sk\n",
68 | "\n",
69 | "from utils.visualization import spec_plot, animation_real_synthetic, animation_spec\n",
70 | "\n",
71 | "from ifxaion.daq import Daq\n",
72 | "\n",
73 | "import matplotlib.pyplot as plt\n",
74 | "from IPython.display import HTML\n",
75 | "\n",
76 | "from networks.img_transf import MultiTransformNet\n",
77 | "\n",
78 | "import torch\n",
79 | "from torch.autograd import Variable"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "id": "5c85fb6b-358a-4850-9968-6c20eab5e09b",
86 | "metadata": {
87 | "tags": []
88 | },
89 | "outputs": [],
90 | "source": [
91 | "data_path = Path(\"/mnt/infineon-radar\")\n",
92 | "\n",
93 | "raw_dir = data_path / \"daq_x-har\"\n",
94 | "activity = \"5_Walking_Boxing\"\n",
95 | "path = raw_dir / f\"{activity}_converted/recording-2020-01-28_12-37-12\"\n",
96 | "\n",
97 | "real_path = data_path / \"preprocessed/fixed_size/real\"\n",
98 | "synth_path = data_path / \"preprocessed/fixed_size/synthetic\"\n",
99 | "\n",
100 | "itn_config = \"D\"\n",
101 | "itn_recording = \"2020-02-05T15:16:08\"\n",
102 | "itn_path = \"models/publication_I.model\""
103 | ]
104 | },
105 | {
106 | "cell_type": "markdown",
107 | "id": "e9862806-9955-4297-b062-e3717f5cc0c7",
108 | "metadata": {
109 | "pycharm": {
110 | "name": "#%% md\n"
111 | },
112 | "tags": []
113 | },
114 | "source": [
115 | "## Simulation"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "id": "161437b9-1232-442c-b1ff-653b573d8262",
121 | "metadata": {
122 | "pycharm": {
123 | "name": "#%% md\n"
124 | },
125 | "tags": []
126 | },
127 | "source": [
128 | "### Radar data loading"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "id": "05eec437",
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "daq = Daq(rec_dir=path)\n",
139 | "env = daq.env\n",
140 | "recording = daq.radar[2]\n",
141 | "rec_config = daq.radar[2].cfg\n",
142 | "timestamps = recording.data.index\n",
143 | "ts_seconds = timestamps.total_seconds()"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": null,
149 | "id": "6195148a",
150 | "metadata": {},
151 | "outputs": [],
152 | "source": [
153 | "config_name = identify_config(rec_config)\n",
154 | "\n",
155 | "rec_config['RadarName'] = rec_config.pop(\"Name\")\n",
156 | "rec_config['cfg'] = config_name\n",
157 | "rec_config[\"activity\"] = activity\n",
158 | "\n",
159 | "n_samples = rec_config['SamplesPerChirp']\n",
160 | "m_chirps = rec_config['ChirpsPerFrame']\n",
161 | "\n",
162 | "print(f\"Synthetizing data for configuration {config_name}\")"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": null,
168 | "id": "6c39f9e6",
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "frame_interval_ms = np.mean((timestamps[1:] - timestamps[:-1]).total_seconds()) * 1e3\n",
173 | "duration_sec = (timestamps[-1] - timestamps[0]).total_seconds()\n",
174 | "print(f'Mean frame interval:\\t{frame_interval_ms} ms')\n",
175 | "print(f'Total duration:\\t{duration_sec} seconds')"
176 | ]
177 | },
178 | {
179 | "cell_type": "markdown",
180 | "id": "42f66451-ea13-4591-8f4f-b28ae83fcad4",
181 | "metadata": {
182 | "pycharm": {
183 | "name": "#%% md\n"
184 | },
185 | "tags": []
186 | },
187 | "source": [
188 | "### Data synthesis"
189 | ]
190 | },
191 | {
192 | "cell_type": "markdown",
193 | "id": "e6717ffc-ea38-4325-90d1-8ea074be7e58",
194 | "metadata": {
195 | "pycharm": {
196 | "name": "#%% md\n"
197 | },
198 | "tags": []
199 | },
200 | "source": [
201 | "Load skeleton data"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "id": "611fca25",
208 | "metadata": {
209 | "tags": []
210 | },
211 | "outputs": [],
212 | "source": [
213 | "skeletons = sk.load(path, verbose=True)\n",
214 | "sk_interp = sk.interpolate(skeletons, timestamps.total_seconds())\n",
215 | "sk_da = sk.to_xarray(sk_interp, timestamps, attrs=env)"
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "id": "9310c5a7",
221 | "metadata": {
222 | "pycharm": {
223 | "name": "#%% md\n"
224 | }
225 | },
226 | "source": [
227 | "Synthesize raw data from skeleton points using a radar configuration"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": null,
233 | "id": "8d1f79ce",
234 | "metadata": {
235 | "pycharm": {
236 | "name": "#%%\n"
237 | }
238 | },
239 | "outputs": [],
240 | "source": [
241 | "syntheticData = synthetic_radar(skeletons, rec_config, ts_seconds)\n",
242 | "\n",
243 | "assert syntheticData.shape[-2] == m_chirps, \"Number of chirps of synthetic data not correct\"\n",
244 | "assert syntheticData.shape[-1] == n_samples, \"Number of samples per chirp of synthetic data not correct\"\n",
245 | "\n",
246 | "smin = syntheticData.min()\n",
247 | "smax = syntheticData.max()\n",
248 | "snorm = (syntheticData - smin) / (smax - smin)\n",
249 | "raw_synth = pd.DataFrame({\"Timestamps\": timestamps, \"NormData\": [sn for sn in snorm]}).set_index(\"Timestamps\")"
250 | ]
251 | },
252 | {
253 | "cell_type": "markdown",
254 | "id": "fd06c539",
255 | "metadata": {
256 | "pycharm": {
257 | "name": "#%% md\n"
258 | }
259 | },
260 | "source": [
261 | "The result is a DataFrame"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": null,
267 | "id": "d5ec385d",
268 | "metadata": {
269 | "pycharm": {
270 | "name": "#%%\n"
271 | }
272 | },
273 | "outputs": [],
274 | "source": [
275 | "raw_synth"
276 | ]
277 | },
278 | {
279 | "cell_type": "markdown",
280 | "id": "5df5924d-3e2a-47c6-90b8-a3ca70d26d16",
281 | "metadata": {
282 | "pycharm": {
283 | "name": "#%% md\n"
284 | },
285 | "tags": []
286 | },
287 | "source": [
288 | "### Preprocessing"
289 | ]
290 | },
291 | {
292 | "cell_type": "markdown",
293 | "id": "c46026bd-5dd5-4008-be07-64f9f4c48c2d",
294 | "metadata": {
295 | "pycharm": {
296 | "name": "#%% md\n"
297 | }
298 | },
299 | "source": [
300 | "The raw data is processed and converted to an [*x*array](http://xarray.pydata.org/en/stable/) DataArray"
301 | ]
302 | },
303 | {
304 | "cell_type": "code",
305 | "execution_count": null,
306 | "id": "c777f60f",
307 | "metadata": {},
308 | "outputs": [],
309 | "source": [
310 | "rdm_synth = raw2rdm(raw_synth, rec_config, env, name=f\"{activity}-{config_name}\")"
311 | ]
312 | },
313 | {
314 | "cell_type": "markdown",
315 | "id": "97315310",
316 | "metadata": {
317 | "pycharm": {
318 | "name": "#%% md\n"
319 | }
320 | },
321 | "source": [
322 | "The Range Doppler Maps can be converted to dB"
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": null,
328 | "id": "06c3813f",
329 | "metadata": {},
330 | "outputs": [],
331 | "source": [
332 | "rdm_db = xr.apply_ufunc(radar.mag2db, rdm_synth, keep_attrs=True, kwargs={\"normalize\": True})\n",
333 | "rdm_db.assign_attrs(units=\"dB\")"
334 | ]
335 | },
336 | {
337 | "cell_type": "markdown",
338 | "id": "01230d8e",
339 | "metadata": {},
340 | "source": [
341 | "Range & Doppler spectrograms in dB can also be calculated"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": null,
347 | "id": "153dd8c5",
348 | "metadata": {
349 | "collapsed": false,
350 | "jupyter": {
351 | "outputs_hidden": false
352 | },
353 | "pycharm": {
354 | "name": "#%%\n"
355 | }
356 | },
357 | "outputs": [],
358 | "source": [
359 | "rdm_abs = np.abs(rdm_synth)\n",
360 | "rspect = rdm_abs.sum(dim=\"doppler\").assign_attrs({\"long_name\": \"Range spectrogram\", \"units\": \"dB\"})\n",
361 | "dspect = rdm_abs.sum(dim=\"range\").assign_attrs({\"long_name\": \"Doppler spectrogram\", \"units\": \"dB\"})\n",
362 | "synth_spects = xr.Dataset({\"range_spect\": rspect, \"doppler_spect\": dspect}, attrs=rdm_synth.attrs)\n",
363 | "synth_spects = xr.apply_ufunc(radar.mag2db, synth_spects, keep_attrs=True, kwargs={\"normalize\": True})\n",
364 | "\n",
365 | "synth_spects"
366 | ]
367 | },
368 | {
369 | "cell_type": "markdown",
370 | "id": "9d5011b5-7fc7-4ba5-94f9-03c79190c9b2",
371 | "metadata": {
372 | "tags": []
373 | },
374 | "source": [
375 | "### Data animations"
376 | ]
377 | },
378 | {
379 | "cell_type": "markdown",
380 | "id": "325edabd-469d-44ff-b603-114fe1c41d77",
381 | "metadata": {
382 | "tags": []
383 | },
384 | "source": [
385 | "Process range & Doppler information from the real recording and extract a short time slice from it"
386 | ]
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": null,
391 | "id": "22205cdc",
392 | "metadata": {},
393 | "outputs": [],
394 | "source": [
395 | "time_slice = slice(\"00:00:52\", \"00:01:07\")\n",
396 | "\n",
397 | "skeleton_slice = sk_da.sel(time=time_slice)\n",
398 | "\n",
399 | "rdm_real = raw2rdm(recording.data, rec_config, env, name=f\"{activity}-{config_name}\")\n",
400 | "rdm_rabs = np.abs(rdm_real.sel(time=time_slice))\n",
401 | "\n",
402 | "rng_spect = rdm_rabs.sum(dim=\"doppler\").assign_attrs({\"long_name\": \"Range spectrogram\", \"units\": \"dB\"})\n",
403 | "dopp_spect = rdm_rabs.sum(dim=\"range\").assign_attrs({\"long_name\": \"Doppler spectrogram\", \"units\": \"dB\"})\n",
404 | "real_spects = xr.Dataset({\"range_spect\": rng_spect, \"doppler_spect\": dopp_spect}, attrs=rdm_real.attrs)\n",
405 | "real_spects = xr.apply_ufunc(radar.mag2db, real_spects, keep_attrs=True, kwargs={\"normalize\": True})\n",
406 | "\n",
407 | "rdm_real_db = xr.apply_ufunc(radar.mag2db, rdm_rabs, keep_attrs=True, kwargs={\"normalize\": True})\n",
408 | "rdm_synth_db = xr.apply_ufunc(radar.normalize_db, rdm_db.sel(time=time_slice), keep_attrs=True)\n"
409 | ]
410 | },
411 | {
412 | "cell_type": "code",
413 | "execution_count": null,
414 | "id": "540ab81f",
415 | "metadata": {
416 | "tags": []
417 | },
418 | "outputs": [],
419 | "source": [
420 | "%%capture\n",
421 | "anim_rs = animation_real_synthetic(rdm_real_db, rdm_synth_db, skeleton_slice,\n",
422 | " sensor_loc=rec_config[\"position\"], vmin=-40, notebook=True)\n",
423 | "anim_spects = animation_spec(rdm_real_db, real_spects, vmin=-40, notebook=True)"
424 | ]
425 | },
426 | {
427 | "cell_type": "code",
428 | "execution_count": null,
429 | "id": "031dfc22",
430 | "metadata": {},
431 | "outputs": [],
432 | "source": [
433 | "HTML(anim_rs.to_html5_video())"
434 | ]
435 | },
436 | {
437 | "cell_type": "code",
438 | "execution_count": null,
439 | "id": "f6623b79",
440 | "metadata": {},
441 | "outputs": [],
442 | "source": [
443 | "HTML(anim_spects.to_html5_video())"
444 | ]
445 | },
446 | {
447 | "cell_type": "markdown",
448 | "id": "9398ee61-47c8-4398-9ee7-9660e497e54a",
449 | "metadata": {
450 | "pycharm": {
451 | "name": "#%% md\n"
452 | },
453 | "tags": []
454 | },
455 | "source": [
456 | "## Image transformation"
457 | ]
458 | },
459 | {
460 | "cell_type": "markdown",
461 | "id": "69c57ed0-5214-4e4d-915d-6693f2bfa976",
462 | "metadata": {
463 | "pycharm": {
464 | "name": "#%% md\n"
465 | }
466 | },
467 | "source": [
468 | ""
469 | ]
470 | },
471 | {
472 | "cell_type": "markdown",
473 | "id": "f9e98f4f-e139-47f9-81cd-4b9996a4f35e",
474 | "metadata": {
475 | "pycharm": {
476 | "name": "#%% md\n"
477 | },
478 | "tags": []
479 | },
480 | "source": [
481 | "### Open recordings as a list of `Xarray.Datasets`"
482 | ]
483 | },
484 | {
485 | "cell_type": "markdown",
486 | "id": "a918b594-c353-48bb-824c-d24010a84a78",
487 | "metadata": {
488 | "pycharm": {
489 | "name": "#%% md\n"
490 | }
491 | },
492 | "source": [
493 | "The example uses lazy load, but in the GPU `load=True` boosts performance"
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "execution_count": null,
499 | "id": "1b4be7b0-291d-4e12-b0ef-818b80a2ad73",
500 | "metadata": {},
501 | "outputs": [],
502 | "source": [
503 | "real_recs = open_recordings(itn_config, real_path, load=False)\n",
504 | "synth_recs = open_recordings(itn_config, synth_path, load=False)"
505 | ]
506 | },
507 | {
508 | "cell_type": "code",
509 | "execution_count": null,
510 | "id": "20c2487e",
511 | "metadata": {
512 | "tags": []
513 | },
514 | "outputs": [],
515 | "source": [
516 | "print(f\"{len(real_recs)} recordings have been lazy loaded\")\n",
517 | "i_rec = [i for i, rec in enumerate(real_recs) if rec.date == itn_recording][0]\n",
518 | "print(f\"Recording {itn_recording} found at index {i_rec}\")"
519 | ]
520 | },
521 | {
522 | "cell_type": "code",
523 | "execution_count": null,
524 | "id": "13139269",
525 | "metadata": {},
526 | "outputs": [],
527 | "source": [
528 | "real_rec = real_recs[i_rec]\n",
529 | "synth_rec = synth_recs[i_rec]"
530 | ]
531 | },
532 | {
533 | "cell_type": "markdown",
534 | "id": "679b6f57-36f4-44a9-a224-0a08048663c3",
535 | "metadata": {},
536 | "source": [
537 | "Extract short spectrograms from the recordings"
538 | ]
539 | },
540 | {
541 | "cell_type": "code",
542 | "execution_count": null,
543 | "id": "15ebc9b2-8dd9-4df8-9a93-fae0425775ca",
544 | "metadata": {},
545 | "outputs": [],
546 | "source": [
547 | "tslice = slice(\"00:01:39\", None)\n",
548 | "time_length = 64\n",
549 | "\n",
550 | "real_spects = real_rec.drop_vars('label').sel(time=tslice).isel(time=slice(0,time_length))\n",
551 | "synth_spects = synth_rec.drop_vars('label').sel(time=tslice).isel(time=slice(0,time_length))\n",
552 | "\n",
553 | "real_spects = xr.apply_ufunc(radar.normalize_db, real_spects.load(), keep_attrs=True)\n",
554 | "synth_spects = xr.apply_ufunc(radar.normalize_db, synth_spects.load(), keep_attrs=True)"
555 | ]
556 | },
557 | {
558 | "cell_type": "markdown",
559 | "id": "89442fe7-ba23-4d46-90ac-75b3431056e4",
560 | "metadata": {
561 | "tags": []
562 | },
563 | "source": [
564 | "### Transform images"
565 | ]
566 | },
567 | {
568 | "cell_type": "markdown",
569 | "id": "732e14cf-44c5-46b3-ac74-54a327eecbd6",
570 | "metadata": {
571 | "tags": []
572 | },
573 | "source": [
574 | "Load Image Transformation Network"
575 | ]
576 | },
577 | {
578 | "cell_type": "code",
579 | "execution_count": null,
580 | "id": "2961332a-afef-410b-a0dd-df8d72415632",
581 | "metadata": {
582 | "tags": []
583 | },
584 | "outputs": [],
585 | "source": [
586 | "dtype = torch.FloatTensor\n",
587 | "\n",
588 | "transformer = MultiTransformNet(num_inputs=2, num_channels=1)\n",
589 | "transformer.load_state_dict(torch.load(itn_path))\n",
590 | "_ = transformer.eval()"
591 | ]
592 | },
593 | {
594 | "cell_type": "markdown",
595 | "id": "2d575fb9-3dc6-4c73-a6cf-7747e28cf742",
596 | "metadata": {},
597 | "source": [
598 | "Transform real data with the ITN"
599 | ]
600 | },
601 | {
602 | "cell_type": "code",
603 | "execution_count": null,
604 | "id": "7fe74785-25d5-4de3-a45f-b7a108812d42",
605 | "metadata": {},
606 | "outputs": [],
607 | "source": [
608 | "range_inp = torch.from_numpy(real_spects.range_spect.values[None, None, :, :])\n",
609 | "dopp_inp = torch.from_numpy(real_spects.doppler_spect.values[None, None, :, :])\n",
610 | "spec_input = [Variable(range_inp, requires_grad=False).type(dtype), Variable(dopp_inp, requires_grad=False).type(dtype)]\n",
611 | "\n",
612 | "range_hat, doppler_hat = transformer(spec_input)\n",
613 | "range_trans = range_hat.detach().numpy()\n",
614 | "doppler_trans = doppler_hat.detach().numpy()"
615 | ]
616 | },
617 | {
618 | "cell_type": "code",
619 | "execution_count": null,
620 | "id": "e148ead0-3356-4d10-bd54-22dfde6e114f",
621 | "metadata": {},
622 | "outputs": [],
623 | "source": [
624 | "spectrograms = xr.merge([real_spects.rename_vars(range_spect=\"range_real\", doppler_spect=\"doppler_real\"),\n",
625 | " synth_spects.rename_vars(range_spect=\"range_synth\", doppler_spect=\"doppler_synth\")],\n",
626 | " combine_attrs=\"drop_conflicts\")\n",
627 | "spectrograms[\"range_trans\"] = (['time', 'range'], np.squeeze(range_trans), {\"units\": \"dB\"})\n",
628 | "spectrograms[\"doppler_trans\"] = (['time', 'doppler'], np.squeeze(doppler_trans), {\"units\": \"dB\"})"
629 | ]
630 | },
631 | {
632 | "cell_type": "markdown",
633 | "id": "ba9a2ac9-334d-43aa-b7f5-446b989a43ff",
634 | "metadata": {
635 | "jupyter": {
636 | "outputs_hidden": false
637 | },
638 | "pycharm": {
639 | "name": "#%%\n"
640 | },
641 | "tags": []
642 | },
643 | "source": [
644 | "### Plotting"
645 | ]
646 | },
647 | {
648 | "cell_type": "code",
649 | "execution_count": null,
650 | "id": "a4e3609c-fcb3-4bcd-b5de-8460c966aff8",
651 | "metadata": {
652 | "collapsed": false,
653 | "jupyter": {
654 | "outputs_hidden": false
655 | },
656 | "pycharm": {
657 | "name": "#%%\n"
658 | }
659 | },
660 | "outputs": [],
661 | "source": [
662 | "fig, axes = plt.subplots(2, 3, figsize=(12, 8))\n",
663 | "\n",
664 | "spec_plot(spectrograms[[\"range_real\", \"range_trans\", \"range_synth\", \"doppler_real\", \"doppler_trans\", \"doppler_synth\"]],\n",
665 | " axes=axes.flatten(), ax_xlabel=axes[-1], vmin=-40, vmax=0, add_colorbar=False)\n",
666 | "\n",
667 | "im = axes[-1][-1].get_images()[0]\n",
668 | "fig.align_ylabels(axes[:,0])\n",
669 | "_ = [ax.set_ylabel(None) for ax in axes[:,1:].flatten()]\n",
670 | "\n",
671 | "cbar = plt.colorbar(im, ax=axes, orientation=\"horizontal\")\n",
672 | "cbar.set_label(\"Amplitude [dB]\")\n",
673 | "\n",
674 | "for title, ax in zip((\"Real data $x$\", \"Transformed data $\\widehat{y}$\", \"Synthetic data $y$\"), axes[0]):\n",
675 | " ax.set_title(title)"
676 | ]
677 | }
678 | ],
679 | "metadata": {
680 | "kernelspec": {
681 | "display_name": "Python 3 (ipykernel)",
682 | "language": "python",
683 | "name": "python3"
684 | },
685 | "language_info": {
686 | "codemirror_mode": {
687 | "name": "ipython",
688 | "version": 3
689 | },
690 | "file_extension": ".py",
691 | "mimetype": "text/x-python",
692 | "name": "python",
693 | "nbconvert_exporter": "python",
694 | "pygments_lexer": "ipython3",
695 | "version": "3.8.12"
696 | },
697 | "pycharm": {
698 | "stem_cell": {
699 | "cell_type": "raw",
700 | "source": [],
701 | "metadata": {
702 | "collapsed": false
703 | }
704 | }
705 | },
706 | "toc-autonumbering": false,
707 | "toc-showcode": false,
708 | "toc-showmarkdowntxt": true
709 | },
710 | "nbformat": 4,
711 | "nbformat_minor": 5
712 | }
--------------------------------------------------------------------------------
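Before the preprocessing module listed next, a minimal sketch of how its two entry points might be driven. The paths mirror the defaults used throughout the repository, while the FFT lengths and the choice of configuration "D" are illustrative assumptions.

    from utils.preprocess import preprocess, open_recordings

    # One-off conversion of the raw recordings into netCDF spectrograms in dB
    preprocess(raw_path="/mnt/infineon-radar/daq_x-har",
               preprocessed_path="/mnt/infineon-radar/preprocessed/daq_x-har",
               range_length=64, doppler_length=64,
               value="db", marginalize="incoherent")

    # Later, lazily open every recording of one radar configuration for training
    recordings = open_recordings("D", "/mnt/infineon-radar/preprocessed/daq_x-har", load=False)
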
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | from ifxaion.daq import Daq
2 | from ifxaion.radar.utils import processing_functions as pf
3 |
4 | from utils import radar
5 |
6 | import json
7 | import numpy as np
8 | import pandas as pd
9 | import xarray as xr
10 | from dask import is_dask_collection
11 |
12 | from pathlib import Path
13 | from datetime import datetime
14 |
15 | from utils.synthesize import synthetic_radar
16 | from utils.skeletons import load as skload
17 |
18 | _configs_path = Path.cwd() / 'configurations' / 'radar_configs.csv'
19 | _interference_path = None
20 |
21 | _base_path = Path("/mnt/infineon-radar")
22 | _raw_path = _base_path / "daq_x-har"
23 | _preprocessed_path = _base_path / "preprocessed" / "daq_x-har"
24 |
25 | _margin_opts = ("none", "coherent", "incoherent")
26 | _margin_default = _margin_opts[-1]
27 |
28 | _radar_transformations = {
29 | 'complex': {
30 | "func": lambda d: xr.apply_ufunc(radar.complex2vector, d,
31 | output_core_dims=[["complex"]], keep_attrs=True),
32 | "units": "Complex amplitude"},
33 | 'magnitude': {
34 | "func": lambda d: xr.apply_ufunc(radar.absolute, d,
35 | keep_attrs=True, kwargs={"normalize": True}),
36 | "units": "Magnitude"},
37 | 'db': {
38 | "func": lambda d: xr.apply_ufunc(radar.mag2db, d,
39 | keep_attrs=True, kwargs={"normalize": True}),
40 | "units": "dB"}}
41 | _radar_val_keys = tuple(_radar_transformations.keys())
42 | _radar_val_default = _radar_val_keys[-1]
43 |
44 | configs = pd.read_csv(_configs_path, index_col=0, sep=",").T
45 | config_names = configs.index
46 |
47 |
48 | def dict_included(subset_dict, other_dict):
49 | try:
50 | return all(other_dict[k] == v for k, v in subset_dict.items())
51 | except KeyError:
52 | return False
53 |
54 |
55 | def identify_config(radar_config):
56 | found_config = configs[configs.apply(dict_included, axis=1, args=[radar_config])]
57 | try:
58 | config_name = found_config.index[0]
59 | except IndexError:
60 | config_name = "UNKNOWN"
61 | return config_name
62 |
63 |
64 | class StrEncoder(json.JSONEncoder):
65 | _str_cls = (datetime, Path)
66 |
67 | def default(self, obj):
68 | if any(isinstance(obj, c) for c in self._str_cls):
69 | return str(obj) # Let the base class default method raise the TypeError
70 | return json.JSONEncoder.default(self, obj)
71 |
72 |
73 | def preprocess(raw_path=None, preprocessed_path=None, synthetic=False,
74 | range_length=None, doppler_length=None,
75 | range_scope=None, doppler_scope=None, resample_ms=None,
76 | value=_radar_val_default, marginalize=_margin_default,
77 | interference_path=_interference_path):
78 | """Preprocess and save radar data for later training
79 |
80 | :param raw_path: path to dataset
81 | :param preprocessed_path: path to output dataset
82 | :param synthetic: if True, generate radar signals from skeleton data and use it instead of real data
83 | :param range_length: Length of range FFT
84 | :param doppler_length: Length of Doppler FFT
85 | :param range_scope: If given (in meters), range_length will be referred only until this point
86 | :param doppler_scope: If given (in m/s), doppler_length will be referred only until this point
87 | :param resample_ms: New frame period to resample in milliseconds
88 | :param marginalize: One of "none", "coherent" or "incoherent"; unless "none", save range and Doppler
89 | spectrograms marginalized (in)coherently, otherwise save Range Doppler Maps
90 | :param value: Save data as either complex amplitude, magnitude or decibels
91 | :param interference_path: Path to a JSON file indicating interference chunks to discard
92 | :return: metadata dictionary
93 | """
94 |
95 | created_at = datetime.now()
96 |
97 | if raw_path is None:
98 | raw_path = _raw_path
99 | else:
100 | raw_path = Path(raw_path)
101 |
102 | metadata = locals()
103 | del metadata["preprocessed_path"]
104 |
105 | if marginalize not in _margin_opts:
106 | raise ValueError(f"Unrecognized marginalize option '{marginalize}', choose from {_margin_opts}")
107 |
108 | try:
109 | rdm_transform = _radar_transformations[value]
110 | except KeyError as ke:
111 | raise ValueError(f"Unrecognized value '{value}'. Value must belong to {_radar_val_keys}") from ke
112 |
113 | if preprocessed_path is None:
114 | preprocessed_path = _preprocessed_path
115 | else:
116 | preprocessed_path = Path(preprocessed_path)
117 |
118 | metadata_path = preprocessed_path / "metadata.json"
119 |
120 | try:
121 | with open(metadata_path, "w") as wf:
122 | json.dump(metadata, wf, cls=StrEncoder, indent=4)
123 | except FileNotFoundError:
124 | preprocessed_path.mkdir(parents=True)
125 | with open(metadata_path, "w") as wf:
126 | json.dump(metadata, wf, cls=StrEncoder, indent=4)
127 |
128 | if interference_path is None:
129 | interferences = None
130 | else:
131 | with open(interference_path, 'r') as rf:
132 | interferences = json.load(rf)
133 |
134 | label_suffix = 10  # length of the "_converted" suffix on activity directory names
135 | unknown_config = "UNKNOWN"
136 |
137 | for activity_path in raw_path.iterdir():
138 | activity_dir = str(activity_path.relative_to(raw_path))
139 | activity = activity_dir[:-label_suffix]
140 | print(f"Converting {activity}")
141 | for rec_path in activity_path.iterdir():
142 | rec_name = str(rec_path.relative_to(activity_path).with_suffix(''))
143 |
144 | if synthetic:
145 | skeletons = skload(rec_path, verbose=True)
146 | else:
147 | skeletons = None
148 |
149 | daq = Daq(rec_dir=rec_path)
150 | env = daq.env
151 | radars = daq.radar
152 | n = len(radars)
153 | print(f"{n} radar files were found in {rec_path}")
154 | print("Environment:", env, sep="\n")
155 | del env["software"]
156 | del env["room_size"]
157 | env["synthetic"] = synthetic
158 |
159 | rec_signals = []
160 |
161 | for recording in radars:
162 | rec_config = recording.cfg
163 | config_name = identify_config(rec_config)
164 |
165 | rdm_dir = preprocessed_path / config_name / activity
166 | rdm_path = rdm_dir / rec_name
167 |
168 | rec_config['RadarName'] = rec_config.pop("Name")
169 | rec_config['cfg'] = config_name
170 |
171 | if not rdm_dir.is_dir():
172 | try:
173 | rdm_dir.mkdir(parents=False)
174 | except FileNotFoundError:
175 | rdm_dir.mkdir(parents=True)
176 | if config_name != unknown_config:
177 | with open(rdm_dir.parent / "config.json", 'w') as fp:
178 | json.dump(rec_config, fp, indent=4)
179 |
180 | rec_config["activity"] = activity
181 |
182 | print(f"Reading data with configuration {config_name}")
183 |
184 | interf_slices = []
185 | if interferences is not None:
186 | interf_slices = interferences[activity_dir][rec_name][config_name]
187 |
188 | named_slices = []
189 | if len(interf_slices) == 0:
190 | named_slices.append((rdm_path, slice(None, None)))
191 | else:
192 | print("Removing interferences")
193 | data_slices = complementary_slices(interf_slices)
194 | path_base = str(rdm_path)
195 | for i, data_slice in enumerate(data_slices):
196 | rdm_file = Path(f"{path_base}_{i+1}")
197 | named_slices.append((rdm_file, data_slice))
198 |
199 | for rdm_file, time_slice in named_slices:
200 | rec_data = recording.data[time_slice]
201 |
202 | timestamps = rec_data.index
203 | frame_interval_ms = np.mean((timestamps[1:] - timestamps[:-1]).total_seconds()) * 1e3
204 | duration_sec = (timestamps[-1] - timestamps[0]).total_seconds()
205 | print(f'Mean frame interval:\t{frame_interval_ms} ms')
206 | print(f'Total duration:\t{duration_sec} seconds')
207 |
208 | if synthetic:
209 | n_samples = rec_config['SamplesPerChirp']
210 | m_chirps = rec_config['ChirpsPerFrame']
211 |
212 | sk_slice = skeletons[time_slice]
213 | syntheticData = synthetic_radar(sk_slice, rec_config, timestamps.total_seconds())
214 |
215 | assert syntheticData.shape[-2] == m_chirps, "Incorrect #chirps of synthetic data"
216 | assert syntheticData.shape[-1] == n_samples, "Incorrect #samples per chirp of synthetic data"
217 |
218 | sMin = syntheticData.min()
219 | sMax = syntheticData.max()
220 | dynamic_range = sMax - sMin
221 | if dynamic_range > 0:
222 | sNorm = (syntheticData - sMin) / dynamic_range
223 | else:
224 | print(f"\n*** WARNING ***\nSynthetic {rdm_file}-{config_name} "
225 | f"has null dynamic range\n***************\n")
226 | sNorm = syntheticData
227 | rec_data = pd.DataFrame({"Timestamps": timestamps,
228 | "NormData": [sn for sn in sNorm]}).set_index("Timestamps")
229 | rec_signals.append((rdm_file, rec_config, rec_data))
230 |
231 | for rdm_file, rec_config, rData in rec_signals:
232 |
233 | n_samples = rec_config['SamplesPerChirp']
234 | m_chirps = rec_config['ChirpsPerFrame']
235 | frame_period_ms = rec_config['FramePeriod'] // 1000
236 | bw_MHz = rec_config['UpperFrequency'] - rec_config['LowerFrequency']
237 | prt_ns = rec_config['ChirpToChirpTime']
238 |
239 | config_name = rec_config['cfg']
240 | rdm_name = f"{rdm_file}-config_{config_name}"
241 |
242 | dr, r_max = radar.range_axis(bw_MHz * 10 ** 6, n_samples)
243 | dv, v_max = radar.doppler_axis(prt_ns * 10 ** (-9), m_chirps)
244 |
245 | if range_scope is None:
246 | r_length = range_length
247 | else:
248 | if r_max < range_scope:
249 | raise ValueError(f"Configuration {config_name} has a max. range of {r_max} m, "
250 | f"below the desired scope of {range_scope} m")
251 | r_length = round(range_length * r_max / range_scope)
252 | if doppler_scope is None:
253 | d_length = doppler_length
254 | else:
255 | if v_max < doppler_scope:
256 | raise ValueError(f"Configuration {config_name} has a max. velocity of {v_max} m/s, "
257 | f"below the desired scope of {doppler_scope} m/s")
258 | d_length = round(doppler_length * v_max / doppler_scope)
259 |
260 | if range_scope or doppler_scope:
261 | print(f"FFT parameters:\n\tRange scope: {range_scope} m,\tDoppler scope: {doppler_scope} m/s")
262 | print(f"\tRange length: {range_length} under scope, {r_length} in total")
263 | print(f"\tDoppler length: {doppler_length} under scope, {d_length} in total")
264 |
265 | rdm_da = raw2rdm(rData, rec_config, env, r_length=r_length, d_length=d_length, name=rdm_name)
266 |
267 | if marginalize == _margin_opts[0]: # none
268 | ds = xr.Dataset({"rdm": rdm_da}, attrs=rdm_da.attrs)
269 | ds.rdm.attrs["long_name"] = "Range Doppler map"
270 | else:
271 | if marginalize == _margin_opts[2]: # incoherent
272 | print("Incoherent marginalization")
273 | rdm_da = np.abs(rdm_da)
274 | else: # coherent
275 | print("Coherent marginalization")
276 | rspect = rdm_da.sum(dim="doppler")
277 | dspect = rdm_da.sum(dim="range")
278 | ds = xr.Dataset({"range_spect": rspect, "doppler_spect": dspect}, attrs=rdm_da.attrs)
279 | ds.range_spect.attrs["long_name"] = "Range spectrogram"
280 | ds.doppler_spect.attrs["long_name"] = "Doppler spectrogram"
281 |
282 | if resample_ms is not None and resample_ms != frame_period_ms:
283 | old_frames = ds.sizes["time"]
284 | print(f"Interpolating frame period from {frame_period_ms}ms to {resample_ms}ms")
285 | ds = ds.resample(time=f"{resample_ms}ms").interpolate("cubic")
286 | new_frames = ds.sizes["time"]
287 | print(f"{old_frames} frames resampled into {new_frames} frames")
288 | ds = ds.assign_attrs(frame_period_ms=resample_ms)
289 | else:
290 | ds = ds.assign_attrs(frame_period_ms=frame_period_ms)
291 |
292 | del ds.attrs["units"]
293 | for k, da in ds.data_vars.items():
294 | ds[k] = rdm_transform["func"](da)
295 | ds[k].attrs["units"] = rdm_transform["units"]
296 |
297 | if "complex" in ds.coords:
298 | ds["complex"] = ["real", "imag"]
299 |
300 | rdm_path = rdm_file.with_suffix(".nc")
301 | ds.to_netcdf(rdm_path)
302 | print(f"Data saved under {rdm_path}\n")
303 | return metadata
304 |
305 |
306 | def raw2rdm(raw_data, rec_config, env, antennas=0, r_length=None, d_length=None, name="rdm"):
307 | """
308 | Convert raw_data to a sequence of Range Doppler Maps (RDM) and return it as an xarray.DataArray
309 | :param raw_data: Radar raw data as returned by daq
310 | :param rec_config: RadarConfig of the recording as returned by daq
311 | :param env: Environment data as returned by daq
312 | :param antennas: Antenna indices to be preprocessed (Only a single index is supported)
313 | :param r_length: Effective length of the FFT over the range axis
314 | :param d_length: Effective length of the FFT over the doppler axis
315 | :param name: Name of the resulting xarray.DataArray
316 | :return: xarray.DataArray with a sequence of RDMs and embedded metadata.
317 | The dimensions of each RDM are (d_length, r_length) if given, otherwise (ChirpsPerFrame, SamplesPerChirp/2)
318 | """
319 | rdm_data = pf.preprocess_radar_data(raw_data,
320 | range_length=r_length,
321 | doppler_length=d_length,
322 | antennas=antennas).squeeze()
323 | doppler_bins, range_bins = rdm_data.shape[-2:]
324 |
325 | print(f"Data shape: {rdm_data.shape}")
326 |
327 | timestamps = raw_data.index
328 |
329 | n_samples = rec_config['SamplesPerChirp']
330 | m_chirps = rec_config['ChirpsPerFrame']
331 |
332 | bw_MHz = rec_config['UpperFrequency'] - rec_config['LowerFrequency']
333 | prt_ns = rec_config['ChirpToChirpTime']
334 |
335 | dr, r_max = radar.range_axis(bw_MHz * 10 ** 6, n_samples)
336 | dv, v_max = radar.doppler_axis(prt_ns * 10 ** (-9), m_chirps)
337 |
338 | range_coords = np.linspace(0, r_max, range_bins)
339 | doppler_coords = np.linspace(-v_max, v_max, doppler_bins)
340 |
341 | rdm_da = xr.DataArray(rdm_data, dims=("time", "doppler", "range"),
342 | name=name, attrs={"units": "Complex amplitude"},
343 | coords={"time": timestamps.to_numpy(),
344 | "doppler": doppler_coords,
345 | "range": range_coords})
346 |
347 | rdm_da.range.attrs["units"] = 'm'
348 | rdm_da.doppler.attrs["units"] = "m/s"
349 |
350 | rdm_da = rdm_da.assign_attrs(rec_config)
351 | rdm_da = rdm_da.assign_attrs(env)
352 |
353 | # Delete problematic attributes
354 | del rdm_da.attrs["position"]
355 | del rdm_da.attrs["orientation"]
356 | del rdm_da.attrs["transformation"]
357 |
358 | return rdm_da
359 |
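# Hedged usage sketch (not part of the original file): assuming rData, rec_config and env
# were obtained from the daq loader as in preprocess() above, a single recording could be
# converted and inspected like this (the lengths are placeholder values):
#
#     rdm = raw2rdm(rData, rec_config, env, antennas=0, r_length=128, d_length=128)
#     print(rdm.sizes)                    # e.g. {'time': ..., 'doppler': 128, 'range': 128}
#     np.abs(rdm.isel(time=0)).plot()     # magnitude of the first Range Doppler map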
360 |
361 | def open_recordings(radar_configs, preprocessed_path=None, range_length=None, doppler_length=None,
362 | categorical=True, load=False):
363 | """
364 |     Open preprocessed recordings for all activities of one or several radar configurations using xarray
365 | :param radar_configs: One or more radar configurations, e.g. "E" or ["A", "B", "C"]
366 | :param preprocessed_path: Path to the preprocessed data
367 | :param range_length: If provided, data will be cropped on the range axis from 0 to range_length
368 | :param doppler_length: If provided, data will be cropped on the doppler axis from 0 to doppler_length
369 | :param categorical: If True, the activities are categorized as int labels
370 | :param load: If True, Datasets are loaded, otherwise use xarray's lazy load
371 | :return: List of xarray.datasets with the recordings
372 | """
373 | ds_kwargs = {} if load else dict(chunks="auto", cache=False)
374 |
375 | if preprocessed_path is None:
376 | preprocessed_path = _preprocessed_path
377 | else:
378 | preprocessed_path = Path(preprocessed_path)
379 |
380 | def concat_recs(cfg):
381 | cfg_path = preprocessed_path / cfg
382 | activities = [a for a in cfg_path.iterdir() if a.is_dir()]
383 | cfg_recs = []
384 | for act_path in activities:
385 | for rec_file in act_path.iterdir():
386 | with xr.open_dataset(rec_file, **ds_kwargs) as rec:
387 | if range_length is not None:
388 | rec = rec.isel(range=slice(range_length))
389 | if doppler_length is not None:
390 | rec = rec.isel(doppler=slice(doppler_length))
391 | if load:
392 | rec = rec.persist()
393 | cfg_recs.append(rec)
394 | if categorical:
395 | cfg_recs = categorize(cfg_recs, "activity", map_name="activities")
396 | return cfg_recs
397 |
398 |     if isinstance(radar_configs, str):
399 | recordings = concat_recs(radar_configs)
400 | else:
401 | recordings = {c: concat_recs(c) for c in radar_configs}
402 |
403 | return recordings
404 |
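# Hedged usage sketch (assumption, the path is a placeholder): lazily open every recording
# of configurations "E" and "F", cropped to 128 range and 128 Doppler bins:
#
#     recs = open_recordings(["E", "F"], preprocessed_path="/data/preprocessed",
#                            range_length=128, doppler_length=128, load=False)
#     # recs is a dict {config: [xr.Dataset, ...]}; a single config string returns a plain list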
405 |
406 | def categorize(recordings, attr, map_name="categories", var_name="label"):
407 | """ Add a global variable from a list of datasets as a categoric variable
408 |
409 | :param recordings: List of Datasets holding processed recordings
410 | :param attr: str, global attribute to categorize as variable
411 | :param map_name: str, name of the global attribute holding the category map
412 | :param var_name: str, name of the index variable
413 | :return: List of the same Datasets enriched with the categorized attribute
414 | """
415 | indices, category_map = pd.factorize([r.attrs[attr] for r in recordings], sort=True)
416 | return [r.assign({var_name: i}).assign_attrs({map_name: tuple(category_map)})
417 | for r, i in zip(recordings, indices)]
418 |
419 |
420 | def recs2dataset(recordings, sliced_indices, ignore_dims=None):
421 | """ Concatenate recording slices into a Dataset
422 |
423 | :param recordings: List of Datasets holding processed recordings
424 | :param sliced_indices: Double nested list of Range objects to slice recordings
425 | :param ignore_dims: list with the names of dimensions to be reset to plain indices
426 | :return: Concatenated Dataset with all sliced recordings
427 | """
428 | if ignore_dims is not None:
429 | ignore_dims = {dim: f"{dim}_" for dim in ignore_dims}
430 | rec_slices = []
431 | for sl in sliced_indices:
432 | rec = slice_recording(recordings, sl)
433 | rec = reset_time(rec)
434 | if ignore_dims is not None:
435 | rec = rec.rename_vars(ignore_dims).reset_coords(ignore_dims.values(), drop=True)
436 | rec_slices.append(rec)
437 | ds = concat_slices(rec_slices, combine_attrs="drop_conflicts")
438 | ds.label.attrs["long_name"] = "Activity label"
439 | return ds
440 |
441 |
442 | def slice_recording(recordings, indexed_slice):
443 | i, r_slice = indexed_slice
444 | try:
445 | sliced_rec = recordings[i].isel(time=r_slice)
446 | except IndexError as ie:
447 | t_len = recordings[i].sizes["time"]
448 | shift = t_len - r_slice.stop
449 | print(f"Caught index error: {ie}, shifting slice by {shift}")
450 | sliced_rec = recordings[i].isel(time=range(r_slice.start + shift, r_slice.stop + shift))
451 | return sliced_rec
452 |
453 |
454 | def concat_slices(slices, dim="batch_index", combine_attrs="drop"):
455 | return xr.concat(slices, dim=dim, combine_attrs=combine_attrs)
456 |
457 |
458 | def reset_time(dataset, add_offset=True):
459 | if add_offset:
460 | dataset = dataset.assign_attrs(date=pd.to_datetime(dataset.date) + dataset.time[0].values)
461 | return dataset.assign_coords(time=lambda d: np.arange(d.sizes["time"]) * np.timedelta64(d.frame_period_ms, 'ms'))
462 |
463 |
464 | def apply_processing(dataset, funcs, dask="allowed"):
465 | """Apply a sequence of processing functions to an xarray in a lazy fashion
466 |
467 | Args:
468 | dataset: xarray Object
469 | funcs: sequence of ufuncs
470 | dask: Dask option to pass to apply_ufunc
471 |
472 | Returns: The processed xarray object
473 |
474 | """
475 | for fn, args, kwargs in funcs:
476 | dataset = xr.apply_ufunc(fn, dataset, *args, keep_attrs=True, kwargs=kwargs, dask=dask)
477 | return dataset
478 |
479 |
480 | def proc_register(processing_functions, func, *args, **kwargs):
481 | processing_functions.append((func, args, kwargs))
482 |
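# Hedged usage sketch (assumption): register a couple of NumPy ufunc-style steps and apply
# them lazily to a Dataset with apply_processing():
#
#     funcs = []
#     proc_register(funcs, np.abs)
#     proc_register(funcs, np.clip, a_min=1e-6, a_max=None)
#     processed = apply_processing(ds, funcs)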
483 |
484 | def get_size(dataset, loaded_only=True, human=False):
485 | def get_bytes(ds, human_b=False):
486 | byte_size = 0
487 | for var in ds.variables.values():
488 | if not loaded_only or not is_dask_collection(var):
489 | byte_size += var.nbytes
490 | if human_b:
491 | byte_size = human_bytes(byte_size)
492 | return byte_size
493 |
494 |     if isinstance(dataset, (list, tuple)):
495 | summed_bytes = sum(get_bytes(d) for d in dataset)
496 | if human:
497 | summed_bytes = human_bytes(summed_bytes)
498 | return summed_bytes
499 | else:
500 | return get_bytes(dataset, human_b=human)
501 |
502 |
503 | def human_bytes(byte_size, digits=2, units=('bytes', 'kB', 'MB', 'GB', 'TB')):
504 | if byte_size <= 0:
505 | return False
506 | exponent = int(np.floor(np.log2(byte_size) / 10))
507 | mantissa = byte_size / (1 << (exponent * 10))
508 | hbytes = f"{mantissa:.{digits}f}{units[exponent]}"
509 |
510 | return hbytes
511 |
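# Hedged examples (assumption): human_bytes(3_500_000) returns "3.34MB"; get_size(ds, human=True)
# only counts loaded (non-dask) variables unless loaded_only=False.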
512 |
513 | def get_scope(ds, unit):
514 | with xr.set_options(keep_attrs=True):
515 | return ds.coords[unit][-1] - ds.coords[unit][0]
516 |
517 |
518 | def complementary_slices(slices):
519 | comp_slices = []
520 | next_start = None
521 | for sl in slices:
522 | if isinstance(sl, slice):
523 | sl = {'start': sl.start, 'stop': sl.stop}
524 | if next_start != sl["start"]:
525 | comp_slices.append(slice(next_start, sl["start"]))
526 | next_start = sl["stop"]
527 | if next_start is not None:
528 | comp_slices.append(slice(next_start, None))
529 | return comp_slices
530 |
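# Hedged example (assumption): the complement of two interference chunks within a recording,
#     complementary_slices([slice(10, 20), slice(30, 40)])
# evaluates to [slice(None, 10), slice(20, 30), slice(40, None)].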
531 |
532 | if __name__ == '__main__':
533 |
534 | import argparse
535 |
536 | parser = argparse.ArgumentParser(description='FMCW radar preprocessing')
537 | parser.add_argument("--raw", type=str, default=str(_raw_path), help="Path to the raw radar data")
538 | parser.add_argument("--output", type=str, default=str(_preprocessed_path),
539 | help="Path to the output the preprocessed data")
540 | parser.add_argument("--interference", type=str, default=str(_interference_path),
541 | help="Path to a JSON file indicating interference chunks to discard")
542 | parser.add_argument("--synthetic", action='store_true', help="Use to create radar signals from skeleton data")
543 | parser.add_argument("--range-length", type=int, default=None,
544 | help="Length of range FFT, default to half the number of samples")
545 | parser.add_argument("--doppler-length", type=int, default=None,
546 | help="Length of Doppler FFT, default to number of chirps")
547 | parser.add_argument("--range-scope", type=float, default=None,
548 | help="If given (in meters), range-length will be referred only until this point")
549 | parser.add_argument("--doppler-scope", type=float, default=None,
550 | help="If given (in m/s), doppler-length will be referred only until this point")
551 | parser.add_argument("--resample", type=float, default=None, help="New frame period to resample in milliseconds")
552 | parser.add_argument("--value", type=str, choices=_radar_val_keys, default=_radar_val_default,
553 | help="Save data as either complex amplitude, magnitude or decibels")
554 | parser.add_argument("--marginalize", type=str, choices=_margin_opts, default=_margin_default,
555 | help ="'none' to save range Doppler maps, "
556 | "otherwise (in)coherent range and Doppler spectrograms")
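    # Hedged example invocation (paths and values are placeholders, not from the original docs):
    #   python utils/preprocess.py --raw /data/raw --output /data/preprocessed \
    #       --range-length 128 --doppler-length 128 --resample 32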
557 |
558 | p_args = parser.parse_args()
559 | preprocess_kwargs = vars(p_args)
560 |
561 | preprocess_kwargs["raw_path"] = preprocess_kwargs.pop("raw")
562 | preprocess_kwargs["preprocessed_path"] = preprocess_kwargs.pop("output")
563 | preprocess_kwargs["interference_path"] = preprocess_kwargs.pop("interference")
564 | preprocess_kwargs["resample_ms"] = preprocess_kwargs.pop("resample")
565 |
566 | preprocess(**preprocess_kwargs)
567 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import xarray as xr
3 | import torch
4 | import os
5 | import argparse
6 | import time
7 | import json
8 |
9 | import matplotlib.pyplot as plt
10 |
11 | from torch.autograd import Variable
12 | from torch.optim import Adam
13 | from torch.utils.data import DataLoader
14 | import torch.nn as nn
15 |
16 | from networks.img_transf import ImageTransformNet, MultiTransformNet
17 | from networks.perceptual import Vgg16, RDNet, RDPerceptual, RACPIT
18 |
19 | from utils.slicer import train_test_slice
20 | from utils.provider import RadarDataset
21 | from utils.preprocess import open_recordings
22 | from utils.visualization import spec_plot
23 |
24 | # Global Variables
25 | BATCH_TRANSFER = 4
26 | BATCH_CLASSIFY = 32
27 | LEARNING_RATE = 1e-3
28 | EPOCHS = 100
29 | CONTENT_WEIGHT = 1e0
30 | TV_WEIGHT = 1e-7
31 |
32 | REAL_PATH = "/mnt/infineon-radar/preprocessed/real"
33 | SYNTH_PATH = "/mnt/infineon-radar/preprocessed/synthetic"
34 |
35 | # Radar Processing variables
36 | range_length = 128
37 | doppler_length = 128
38 | time_length = 64
39 | hop_length = 8
40 | ignore_dims = False
41 |
42 |
43 | def train_transfer(args):
44 | # GPU enabling
45 | if args.gpu is None:
46 | use_cuda = False
47 | dtype = torch.FloatTensor
48 | label_type = torch.LongTensor
49 | print("No GPU training")
50 | else:
51 | use_cuda = True
52 | dtype = torch.cuda.FloatTensor
53 | label_type = torch.cuda.LongTensor
54 | torch.cuda.set_device(args.gpu)
55 | print("Current device: %d" % torch.cuda.current_device())
56 |
57 | epochs = args.epochs
58 |
59 | # visualization of training controlled by flag
60 | visualize = 0 if args.visualize is None else args.visualize
61 |
62 | # define network
63 | if args.range:
64 | image_transformer = MultiTransformNet(num_inputs=2, num_channels=1).type(dtype)
65 | else:
66 | image_transformer = ImageTransformNet(num_channels=1).type(dtype)
67 | optimizer = Adam(image_transformer.parameters(), LEARNING_RATE)
68 | criterion = nn.CrossEntropyLoss()
69 |
70 | loss_mse = torch.nn.MSELoss()
71 |
72 | # get training dataset
73 | config = args.config
74 | input_path = args.input
75 | output_path = args.output
76 | print(f"Training with configurations {', '.join(config)}.")
77 | print(f"Using data from {input_path} as the input and data from {output_path} as an output.")
78 |
79 | train_load = 0.8
80 | num_workers = 4
81 |
82 | if args.recordings is not None:
83 | split = 'no-cut'
84 | train_load = 0.5
85 | elif args.segments is None:
86 | split = 'single'
87 | else:
88 | with open(args.segments, "r") as f:
89 | split = json.load(f)
90 |
91 | recordings_input = open_recordings(config, input_path,
92 | load=True, range_length=range_length, doppler_length=doppler_length)
93 | recordings_output = open_recordings(config, output_path,
94 | load=True, range_length=range_length, doppler_length=doppler_length)
95 |
96 | # Merge recordings from all configs
97 | recordings_input = [r for recs in recordings_input.values() for r in recs]
98 | recordings_output = [r for recs in recordings_output.values() for r in recs]
99 | if args.range:
100 | recordings_input = [r.rename_vars(range_spect="range_real") for r in recordings_input]
101 | recordings_output = [r.rename_vars(range_spect="range_synth") for r in recordings_output]
102 | else: # drop range spectrograms
103 | recordings_input = [r.drop_vars("range_spect") for r in recordings_input]
104 | recordings_output = [r.drop_vars("range_spect") for r in recordings_output]
105 |
106 | recordings = [xr.merge([rec_real.rename_vars(doppler_spect="doppler_real"),
107 | rec_synth.rename_vars(doppler_spect="doppler_synth")], combine_attrs="drop_conflicts")
108 | for rec_real, rec_synth in zip(recordings_input, recordings_output)]
109 |
110 | if args.classes is not None:
111 | classes = args.classes
112 | print(f"Selecting classes {classes}")
113 | new_recs = []
114 | for r in recordings:
115 | if r.label in classes:
116 | new_recs.append(r)
117 | print(f"{len(new_recs)} out of {len(recordings)} selected")
118 | recordings = new_recs
119 |
120 | slice_kwargs = dict(spec_length=time_length, stride=hop_length, train_load=train_load, copy_split=0)
121 | loader_kwargs = dict(batch_size=BATCH_TRANSFER, shuffle=True, num_workers=num_workers, pin_memory=True)
122 |
123 | print("Preloading datasets...")
124 | slice_output = slice_datasets(recordings, split=split, **slice_kwargs)
125 | if args.recordings is None:
126 | if args.segments is None:
127 | [train_dataset, test_dataset], tgt_segments = slice_output
128 | else:
129 | train_dataset, test_dataset = slice_output
130 | elif args.recordings == 'first':
131 | [train_dataset, test_dataset], tgt_segments = slice_output
132 | elif args.recordings == 'last':
133 | [test_dataset, train_dataset], tgt_segments = slice_output
134 | else:
135 | raise ValueError(f"Unrecognized recordings option {args.recordings}")
136 |
137 | train_loader = DataLoader(train_dataset, **loader_kwargs)
138 |
139 | test_indices = np.random.choice(len(test_dataset), size=visualize, replace=False)
140 |
141 | class_num = train_dataset.class_num
142 | input_shapes = train_dataset.feature_shapes
143 | input_shapes = input_shapes[:len(input_shapes)//2]
144 | print(f"Number of classes: {class_num}")
145 | print(f"Feature shapes: {input_shapes}\n")
146 |
147 | # load perceptual loss network
148 | if args.model is None:
149 | perceptual_net = Vgg16().type(dtype) # Only works with single input i.e. without range
150 | else:
151 | perceptual_net = RDPerceptual(args.model, input_shapes=input_shapes, class_num=class_num).type(dtype)
152 |
153 | log_id = args.log
154 |
155 | # calculate gram matrices for style feature layer maps we care about
156 | # style_features = vgg(style)
157 | # style_gram = [utils.gram(fmap) for fmap in style_features]
158 |
159 | loss_logs = []
160 |
161 | for e in range(epochs):
162 |
163 | # track values for...
164 | img_count = 0
165 | aggregate_content_loss = 0.0
166 | aggregate_classify_loss = 0.0
167 | aggregate_tv_loss = 0.0
168 | batch_num = 0
169 |
170 | # train network
171 | image_transformer.train()
172 | for batch_num, (feature_batch, label) in enumerate(train_loader):
173 | img_batch_read = len(label)
174 | img_count += img_batch_read
175 |
176 | if args.range:
177 | [real_range, real_doppler, synth_range, synth_doppler] = feature_batch
178 | x = [Variable(real_feat).type(dtype) for real_feat in (real_range, real_doppler)]
179 | y_c = [Variable(synth_feat).type(dtype) for synth_feat in (synth_range, synth_doppler)]
181 | else:
182 | [real_batch, synth_batch] = feature_batch
183 | x = Variable(real_batch).type(dtype)
184 | y_c = Variable(synth_batch).type(dtype)
185 | label_true = Variable(label).type(label_type)
186 |
187 | # zero out gradients
188 | optimizer.zero_grad()
189 |
190 | # input batch to transformer network
191 | y_hat = image_transformer(x)
192 |
193 | # get vgg features
194 | y_c_features = perceptual_net(y_c)
195 | y_hat_features = perceptual_net(y_hat)
196 |
197 | # calculate classification loss w.r.t. input
198 | label_pred = y_hat_features[0]
199 | classify_loss = CONTENT_WEIGHT*criterion(label_pred, label_true)
200 | aggregate_classify_loss += classify_loss.item()
201 |
202 | # calculate content loss (h_relu_2_2)
203 | recon = y_c_features[1]
204 | recon_hat = y_hat_features[1]
205 | content_loss = CONTENT_WEIGHT*loss_mse(recon_hat, recon)
206 | aggregate_content_loss += content_loss.item()
207 |
208 | # calculate total variation regularization (anisotropic version)
209 | # https://www.wikiwand.com/en/Total_variation_denoising
210 | if args.range:
211 | diff_i = 0.0
212 | diff_j = 0.0
213 | for y_h in y_hat:
214 | diff_i += torch.sum(torch.abs(y_h[:, :, :, 1:] - y_h[:, :, :, :-1]))
215 | diff_j += torch.sum(torch.abs(y_h[:, :, 1:, :] - y_h[:, :, :-1, :]))
216 | else:
217 | diff_i = torch.sum(torch.abs(y_hat[:, :, :, 1:] - y_hat[:, :, :, :-1]))
218 | diff_j = torch.sum(torch.abs(y_hat[:, :, 1:, :] - y_hat[:, :, :-1, :]))
219 | tv_loss = args.tv_weight*(diff_i + diff_j)
220 | aggregate_tv_loss += tv_loss.item()
221 |
222 | # total loss
223 | total_loss = content_loss + tv_loss
224 |
225 | # backprop
226 | total_loss.backward()
227 | optimizer.step()
228 |
229 | # print out status message
230 | if (batch_num + 1) % 100 == 0:
231 | status = f"{time.ctime()} Epoch {e + 1}: " \
232 | f"[{img_count}/{len(train_dataset)}] Batch:[{batch_num+1}] " \
233 | f"agg_content: {aggregate_content_loss/(batch_num+1.0):.6f} " \
234 | f"agg_class: {aggregate_classify_loss / (batch_num + 1.0):.6f} " \
235 | f"agg_tv: {aggregate_tv_loss/(batch_num+1.0):.6f} " \
236 | f"content: {content_loss:.6f} class: {classify_loss:.6f} tv: {tv_loss:.6f} "
237 | print(status)
238 |
239 |             if ((batch_num + 1) % 1000 == 0) and visualize > 0:
240 | image_transformer.eval()
241 |
242 | if not os.path.exists("visualization"):
243 | os.makedirs("visualization")
244 | if not os.path.exists("visualization/%s" %log_id):
245 | os.makedirs("visualization/%s" %log_id)
246 |
247 | for img_index in test_indices:
248 | test_ds = test_dataset.dataset[int(img_index)]
249 | doppler_test = torch.from_numpy(test_ds.doppler_real.values[None, None, :, :])
250 |
251 | plt_path = f"visualization/{log_id}/" \
252 | f"{test_ds.activity}_{test_ds.date.replace(':','-')}_e{e+1}_b{batch_num+1}.png"
253 |
254 | x_test = Variable(doppler_test, requires_grad=False).type(dtype)
255 | titles = ("Real data", "Synthetic data", "Generated data")
256 | if args.range:
257 | range_test = torch.from_numpy(test_ds.range_real.values[None, None, :, :])
258 | x_test = [Variable(range_test, requires_grad=False).type(dtype), x_test]
259 | range_hat, doppler_hat = image_transformer(x_test)
260 | range_hat = range_hat.cpu().detach().numpy()
261 | doppler_hat = doppler_hat.cpu().detach().numpy()
262 | test_ds["range_gen"] = (['time', 'range'], np.squeeze(range_hat), {"units": "dB"})
263 | test_ds["doppler_gen"] = (['time', 'doppler'], np.squeeze(doppler_hat), {"units": "dB"})
264 | range_ds = test_ds[["range_real", "range_synth", "range_gen"]]
265 | doppler_ds = test_ds[["doppler_real", "doppler_synth", "doppler_gen"]]
266 | fig, axes = plt.subplots(3, 2, figsize=(11, 6))
267 | spec_plot(range_ds, axes=[ax[0] for ax in axes], vmin=-40, vmax=0, add_colorbar=False)
268 | spec_plot(doppler_ds, axes=[ax[1] for ax in axes], vmin=-40, vmax=0, add_colorbar=False)
269 | for ax_pair, title in zip(axes, titles):
270 | if title != titles[-1]:
271 | for ax in ax_pair:
272 | ax.axes.get_xaxis().set_visible(False)
273 | ax_pair[0].set_title(title)
274 | cbar = fig.colorbar(axes[0][0].get_images()[0], ax=axes, orientation='vertical')
275 | cbar.set_label('Amplitude [dB]')
276 | else:
277 | y_hat_test = image_transformer(x_test).cpu().detach().numpy()
278 | test_ds["doppler_gen"] = (['time', 'doppler'], np.squeeze(y_hat_test), {"units": "dB"})
279 | spec_plot(test_ds, vmin=-40, vmax=0, cbar_global="Amplitude [dB]")
280 |
281 | axes = plt.gcf().axes
282 | for ax, title in zip(axes, titles):
283 | if title != titles[-1]:
284 | ax.axes.get_xaxis().set_visible(False)
285 | ax.set_title(title)
286 |
287 | plt.savefig(plt_path)
288 | plt.close()
289 |
290 | print("images saved")
291 | image_transformer.train()
292 |
293 | loss_logs.append({'content': aggregate_content_loss/(batch_num+1.0),
294 | 'class': aggregate_classify_loss/(batch_num+1.0),
295 | 'tv': aggregate_tv_loss/(batch_num+1.0)})
296 |
297 | # save model
298 | image_transformer.eval()
299 |
300 | if use_cuda:
301 | image_transformer.cpu()
302 |
303 | with open(f"log/{args.log}_loss.json", "w") as wf:
304 | json.dump(loss_logs, wf, indent=4)
305 |
306 | if args.plot:
307 | content_loss = [log['content'] for log in loss_logs]
308 | class_loss = [log['class'] for log in loss_logs]
309 | fig, [ax_content, ax_class] = plt.subplots(2, 1)
310 |
311 | ax_content.plot(content_loss)
312 | ax_content.set_xlabel("Epochs")
313 | ax_content.set_ylabel("Content loss")
314 |
315 | ax_class.plot(class_loss)
316 | ax_class.set_xlabel("Epochs")
317 | ax_class.set_ylabel("Classification loss")
318 |
319 | plt.savefig(f"log/{args.log}.png")
320 |
321 | filename = "models/" + args.log + ".model"
322 | log_dir = os.path.dirname(filename)
323 | if not os.path.exists(log_dir):
324 | os.makedirs(log_dir)
325 | torch.save(image_transformer.state_dict(), filename)
326 |
327 | if use_cuda:
328 | image_transformer.cuda()
329 |
330 |
331 | def train_classify(args):
332 | # GPU enabling
333 | if args.gpu is None:
334 | use_cuda = False
335 | in_type = torch.FloatTensor
336 | out_type = torch.LongTensor
337 | print("No GPU training")
338 | else:
339 | use_cuda = True
340 | in_type = torch.cuda.FloatTensor
341 | out_type = torch.cuda.LongTensor
342 | torch.cuda.set_device(args.gpu)
343 | print("Current device: %d" % torch.cuda.current_device())
344 |
345 | # get training dataset
346 | config = args.config
347 | data_path = args.dataset
348 | print(f"Training with configurations {', '.join(config)} from the dataset under {data_path}")
349 |
350 | epochs = args.epochs
351 |
352 | num_workers = 4
353 |
354 | if args.no_split:
355 | train_load = 'fake-load'
356 | split = None
357 | elif args.recordings is None:
358 | split = 'single'
359 | train_load = 0.8
360 | else:
361 | split = 'no-cut'
362 | train_load = 0.5
363 |
364 | recordings = open_recordings(config, data_path, load=True, range_length=range_length, doppler_length=doppler_length)
365 | # Merge recordings from all configs
366 | train_recordings = [r for recs in recordings.values() for r in recs]
367 | if not args.range: # drop range spectrograms
368 | train_recordings = [r.drop_vars("range_spect") for r in train_recordings]
369 |
370 | slice_kwargs = dict(spec_length=time_length, stride=hop_length, train_load=train_load, copy_split=0)
371 | loader_kwargs = dict(batch_size=BATCH_CLASSIFY, shuffle=True, num_workers=num_workers, pin_memory=True)
372 |
373 | print("Preloading datasets...")
374 | if args.no_split:
375 | train_dataset = slice_datasets(train_recordings, split=split, **slice_kwargs)
376 | test_dataset = train_dataset
377 | tgt_segments = {}
378 | elif args.recordings is None or args.recordings == 'first':
379 | [train_dataset, test_dataset], tgt_segments = slice_datasets(train_recordings, split=split, **slice_kwargs)
380 | elif args.recordings == 'last':
381 | [test_dataset, train_dataset], tgt_segments = slice_datasets(train_recordings, split=split, **slice_kwargs)
382 | else:
383 | raise ValueError(f"Unexpected recordings argument {args.recordings}")
384 | with open(f"log/{args.log}_segments.json", "w") as wf:
385 | json.dump(tgt_segments, wf, indent=4)
386 |
387 | train_loader = DataLoader(train_dataset, **loader_kwargs)
388 | test_loader = DataLoader(test_dataset, **loader_kwargs)
389 |
390 | class_num = train_dataset.class_num
391 | input_shapes = train_dataset.feature_shapes
392 | print(f"Number of classes: {class_num}")
393 | print(f"Feature shapes: {input_shapes}\n")
394 |
395 | # define network
396 | c_net = RDNet(input_shapes=input_shapes, class_num=class_num).type(in_type)
397 |
398 | optimizer = Adam(c_net.parameters(), LEARNING_RATE)
399 | criterion = nn.CrossEntropyLoss()
400 |
401 | loss_logs = []
402 |
403 | early_stop_thresh = .96
404 |
405 | for e in range(epochs):
406 |
407 | # track values for...
408 | img_count = 0
409 | train_loss = 0.0
410 | train_acc = 0.0
411 | test_acc = 0.0
412 | batch_num = 0
413 |
414 | # train network
415 | c_net.train()
416 | for batch_num, batch in enumerate(train_loader):
417 | feature_batch, label_batch = batch
418 | if args.range:
419 | x = [Variable(feat).type(in_type) for feat in feature_batch]
420 | else:
421 | x = Variable(feature_batch[0]).type(in_type)
422 |
423 | img_batch_read = len(label_batch)
424 | img_count += img_batch_read
425 |
426 | # zero out gradients
427 | optimizer.zero_grad()
428 |
429 | # input batch to classifier network
430 | y_true = Variable(label_batch).type(out_type)
431 | y_pred = c_net(x)
432 |
433 | loss = criterion(y_pred, y_true)
434 | loss.backward()
435 | optimizer.step()
436 |
437 | # print statistics
438 | train_loss += loss.item()
439 | train_acc += evaluate(c_net, [batch]).item() * img_batch_read
440 | if (batch_num + 1) % 100 == 0:
441 | test_acc = evaluate(c_net, test_loader).item()
442 | status = f"{time.ctime()} Epoch {e + 1}: " \
443 | f"[{img_count}/{len(train_dataset)}] Batch:[{batch_num + 1}] " \
444 | f"train_loss: {train_loss / (batch_num + 1.0):.6f} " \
445 | f"train_acc: {train_acc / img_count:.6f} test_acc: {test_acc:.6f}"
446 | print(status)
447 |
448 | loss_logs.append({'train_loss': train_loss / (batch_num + 1.0),
449 | 'train_acc': train_acc / img_count, 'test_acc': test_acc})
450 |
451 | if train_acc / img_count > early_stop_thresh:
452 | print("***Early stopping training***")
453 | break
454 |
455 | # save model
456 | c_net.eval()
457 |
458 | if use_cuda:
459 | c_net.cpu()
460 |
461 | with open(f"log/{args.log}_loss.json", "w") as wf:
462 | json.dump(loss_logs, wf, indent=4)
463 |
464 | if args.plot:
465 | loss = []
466 | train_acc = []
467 | test_acc = []
468 | for log in loss_logs:
469 | loss.append(log['train_loss'])
470 | train_acc.append(log['train_acc'])
471 | test_acc.append(log['test_acc'])
472 | fig, [ax_loss, ax_acc] = plt.subplots(2, 1)
473 | ax_loss.plot(loss)
474 | ax_loss.set_xlabel("Epochs")
475 | ax_loss.set_ylabel("Train loss")
476 |
477 | ax_acc.plot(train_acc)
478 | ax_acc.plot(test_acc)
479 | ax_acc.set_xlabel("Epochs")
480 | ax_acc.set_ylabel("Accuracy")
481 | ax_acc.legend(["train", "test"])
482 | plt.savefig(f"log/{args.log}.png")
483 |
484 | filename = "models/" + args.log + ".model"
485 | log_dir = os.path.dirname(filename)
486 | if not os.path.exists(log_dir):
487 | os.makedirs(log_dir)
488 | torch.save(c_net.state_dict(), filename)
489 |
490 | if use_cuda:
491 | c_net.cuda()
492 |
493 | return loss_logs[-1]["train_acc"]
494 |
495 |
496 | def test(args):
497 | # GPU enabling
498 | if args.gpu is None:
499 | dtype = torch.FloatTensor
500 | print("No GPU in use")
501 | else:
502 | dtype = torch.cuda.FloatTensor
503 | torch.cuda.set_device(args.gpu)
504 | print("Current device: %d" % torch.cuda.current_device())
505 |
506 | # get training dataset
507 | config = args.config
508 | data_path = args.dataset
509 |
510 | train_load = 0.5
511 | num_workers = 4
512 |
513 | recordings = open_recordings(config, data_path, load=True, range_length=range_length, doppler_length=doppler_length)
514 | # Merge recordings from all configs
515 | recordings = [r for recs in recordings.values() for r in recs]
516 | if not args.range: # drop range spectrograms
517 | recordings = [r.drop_vars("range_spect") for r in recordings]
518 |
519 | slice_kwargs = dict(spec_length=time_length, stride=hop_length, train_load=train_load)
520 | loader_kwargs = dict(batch_size=BATCH_TRANSFER, shuffle=False, num_workers=num_workers, pin_memory=True)
521 |
522 | print("Preloading datasets...")
523 | if args.recordings is None:
524 | if args.segments is None:
525 | test_dataset = slice_datasets(recordings, split=None, **slice_kwargs)
526 | else:
527 | with open(args.segments, "r") as f:
528 | segments = json.load(f)
529 | _, test_dataset = slice_datasets(recordings, split=segments, **slice_kwargs)
530 | elif args.recordings == 'first':
531 | [test_dataset, _], tgt_segments = slice_datasets(recordings, split='no-cut', **slice_kwargs)
532 | elif args.recordings == 'last':
533 | [_, test_dataset], tgt_segments = slice_datasets(recordings, split='no-cut', **slice_kwargs)
534 | else:
535 | raise ValueError(f"Unrecognized recordings option {args.recordings}")
536 |
537 | test_loader = DataLoader(test_dataset, **loader_kwargs)
538 |
539 | class_num = test_dataset.class_num
540 | input_shapes = test_dataset.feature_shapes
541 | print(f"Number of classes: {class_num}")
542 | print(f"Feature shapes: {input_shapes}\n")
543 |
544 | # load network including transformer and classifier
545 | if args.transformer is None:
546 | trans_c_net = RDNet(input_shapes=input_shapes, class_num=class_num).type(dtype)
547 | trans_c_net.load_state_dict(torch.load(args.classifier))
548 | else:
549 | trans_c_net = RACPIT(trans_path=args.transformer, model_path=args.classifier,
550 | input_shapes=input_shapes, class_num=class_num).type(dtype)
551 | trans_c_net.eval()
552 |
553 | accuracy, predict_info = evaluate(trans_c_net, test_loader, predict_info=True)
554 | accuracy = accuracy.item()
555 | print(f"{accuracy:.6f} accuracy on configurations {', '.join(config)} from {data_path}")
556 |
557 | true_labels = predict_info["real"]
558 | predictions = predict_info["predict"]
559 |     correct = predict_info["correct"].cpu().detach().numpy().astype(bool)
560 | misclassified = np.nonzero(np.logical_not(correct))[0]
561 |
562 | confusion_matrix(predictions, true_labels, nb_classes=class_num)
563 |
564 | visualize = args.visualize
565 |
566 | if visualize <= 0:
567 | return accuracy
568 |
569 | if not os.path.exists("visualization/%s" % args.log):
570 | os.makedirs("visualization/%s" % args.log)
571 |
572 |     true_labels = true_labels.cpu().detach().numpy().astype(int)
573 |     predictions = predictions.cpu().detach().numpy().astype(int)
574 |
575 | test_indices = np.random.choice(misclassified, size=visualize, replace=False)
576 | activities = test_dataset.attrs['activities']
577 |
578 | for img_index in test_indices:
579 | test_ds = test_dataset.dataset[int(img_index)]
580 | doppler_test = torch.from_numpy(test_ds.doppler_spect.values[None, None, :, :])
581 |
582 | true_activity = activities[true_labels[img_index]]
583 | pred_activity = activities[predictions[img_index]]
584 |
585 | assert true_activity != pred_activity, "Prediction is correct"
586 | assert true_activity == test_ds.activity, "The true activity does not coincide with the embedded activity"
587 |
588 | plt_path = f"visualization/{args.log}/" \
589 | f"{true_activity}_{test_ds.date.replace(':', '-')}_{pred_activity}.png"
590 |
591 | x_test = Variable(doppler_test, requires_grad=False).type(dtype)
592 | titles = ("Real data", f"Generated data, classified as {pred_activity}")
593 | if args.range:
594 | range_test = torch.from_numpy(test_ds.range_spect.values[None, None, :, :])
595 | x_test = [Variable(range_test, requires_grad=False).type(dtype), x_test]
596 | range_hat, doppler_hat = trans_c_net.transformer(x_test)
597 | range_hat = range_hat.cpu().detach().numpy()
598 | doppler_hat = doppler_hat.cpu().detach().numpy()
599 | test_ds["range_gen"] = (['time', 'range'], np.squeeze(range_hat), {"units": "dB"})
600 | test_ds["doppler_gen"] = (['time', 'doppler'], np.squeeze(doppler_hat), {"units": "dB"})
601 | range_ds = test_ds[["range_spect", "range_gen"]]
602 | doppler_ds = test_ds[["doppler_spect", "doppler_gen"]]
603 | fig, axes = plt.subplots(2, 2, figsize=(11, 6))
604 | spec_plot(range_ds, axes=[ax[0] for ax in axes], vmin=-40, vmax=0, add_colorbar=False)
605 | spec_plot(doppler_ds, axes=[ax[1] for ax in axes], vmin=-40, vmax=0, add_colorbar=False)
606 | for ax_pair, title in zip(axes, titles):
607 | if title != titles[-1]:
608 | for ax in ax_pair:
609 | ax.axes.get_xaxis().set_visible(False)
610 | ax_pair[0].set_title(title)
611 | cbar = fig.colorbar(axes[0][0].get_images()[0], ax=axes, orientation='vertical')
612 | cbar.set_label('Amplitude [dB]')
613 | else:
614 | y_hat_test = trans_c_net.transformer(x_test).cpu().detach().numpy()
615 | test_ds["doppler_gen"] = (['time', 'doppler'], np.squeeze(y_hat_test), {"units": "dB"})
616 | spec_plot(test_ds, vmin=-40, vmax=0, cbar_global="Amplitude [dB]")
617 |
618 | axes = plt.gcf().axes
619 | for ax, title in zip(axes, titles):
620 | if title != titles[-1]:
621 | ax.axes.get_xaxis().set_visible(False)
622 | ax.set_title(title)
623 |
624 | plt.savefig(plt_path)
625 | plt.close()
626 |
627 | return accuracy
628 |
629 |
630 | def slice_datasets(recordings, spec_length, stride, train_load=0.8, split=None, copy_split=0):
631 | if copy_split > 1:
632 | effective_len = len(recordings) // copy_split
633 | else:
634 | effective_len = len(recordings)
635 | if split is None:
636 | slices = train_test_slice(recordings, spec_length, stride, train_load, split=split)
637 | rd_dataset = RadarDataset(recordings, slices=slices, ignore_dims=ignore_dims)
638 | return rd_dataset
639 | elif isinstance(split, dict):
640 | slices = train_test_slice(recordings[:effective_len], spec_length, stride, train_load, verbose=False,
641 | split=split, return_segments=False)
642 | if copy_split > 1:
643 | slices = [copy_slices(sl, effective_len, copy_split) for sl in slices]
644 | rd_datasets = [RadarDataset(recordings, slices=s, ignore_dims=ignore_dims) for s in slices]
645 | # assert set(index for index, sl in slices[0]) == set(range(len(recordings))), \
646 | # "Slices do not include all recordings"
647 | return rd_datasets
648 | else:
649 | slices = train_test_slice(recordings[:effective_len], spec_length, stride, train_load, verbose=False,
650 | split=split, return_segments=True)
651 | segments = slices.pop(-1)
652 | if copy_split > 1:
653 | slices = [copy_slices(sl, effective_len, copy_split) for sl in slices]
654 | rd_datasets = [RadarDataset(recordings, slices=s, ignore_dims=ignore_dims) for s in slices]
655 |         # assert set(index for index, sl in slices[0]) == set(range(len(recordings)))
656 | return rd_datasets, segments
657 |
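# Hedged usage sketch (assumption): build train/test RadarDatasets from opened recordings
# with the module-level windowing parameters:
#
#     [train_ds, test_ds], segments = slice_datasets(recordings, spec_length=time_length,
#                                                    stride=hop_length, train_load=0.8,
#                                                    split='single')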
658 |
659 | def copy_slices(slices, num_recs, repeat):
660 | new_slices = []
661 | for n in range(repeat):
662 | new_slices += [(index + num_recs * n, sl) for index, sl in slices]
663 | return new_slices
664 |
665 |
666 | # ============== eval
667 | def evaluate(model_instance, input_loader, gpu=True, predict_info=False):
668 | ori_train_state = model_instance.training
669 | model_instance.eval()
670 | first_test = True
671 | all_probs, all_labels = None, None
672 |
673 | for data in input_loader:
674 | inputs = data[0]
675 | labels = data[1]
676 | if gpu:
677 | inputs = [inp.cuda() for inp in inputs]
678 | labels = labels.cuda()
679 |
680 | probabilities = model_instance.predict(inputs)
681 |
682 | probabilities = probabilities.data.float()
683 | labels = labels.data.float()
684 | if first_test:
685 | all_probs = probabilities
686 | all_labels = labels
687 | first_test = False
688 | else:
689 | all_probs = torch.cat((all_probs, probabilities), 0)
690 | all_labels = torch.cat((all_labels, labels), 0)
691 |
692 | _, predict = torch.max(all_probs, 1)
693 | predict = torch.squeeze(predict).float()
694 | correct = predict == all_labels
695 | accuracy = torch.sum(correct) / float(all_labels.size()[0])
696 |
697 | model_instance.train(ori_train_state)
698 |
699 | if predict_info:
700 | predictions = {"predict": predict, "real": all_labels, "correct": correct}
701 | return accuracy, predictions
702 | else:
703 | return accuracy
704 |
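# Hedged usage sketch (assumption): evaluate a trained classifier on a DataLoader and print
# its confusion matrix:
#
#     acc, info = evaluate(c_net, test_loader, gpu=False, predict_info=True)
#     confusion_matrix(info["predict"], info["real"], nb_classes=class_num)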
705 |
706 | def confusion_matrix(predicted, true_label, nb_classes):
707 | cm = torch.zeros(nb_classes, nb_classes)
708 | with torch.no_grad():
709 | for t, p in zip(true_label.view(-1), predicted.view(-1)):
710 | cm[t.long(), p.long()] += 1
711 | print(cm)
712 |
713 |
714 | def save_params(parameters):
715 | log_file = f"log/{parameters['log']}_params.json"
716 | log_dir = os.path.dirname(log_file)
717 | if not os.path.exists(log_dir):
718 | os.makedirs(log_dir)
719 | with open(log_file, "w") as wf:
720 | json.dump(parameters, wf, indent=4)
721 |
722 |
723 | def main():
724 |     parser = argparse.ArgumentParser(description='Radar style transfer and activity classification in PyTorch')
725 | parser.add_argument("--log", type=str, default=None, help="ID to mark output files and logs. Default to timestamp")
726 | subparsers = parser.add_subparsers(title="subcommands", dest="subcommand")
727 |
728 | train_parser = subparsers.add_parser("train-transfer", help="train a model to do style transfer")
729 | train_parser.add_argument("--plot", action='store_true', help="Plot and save a training report")
730 | train_parser.add_argument("--gpu", type=int, default=None, help="ID of GPU to be used")
731 | train_parser.add_argument("--epochs", type=int, default=EPOCHS, help="Number of epochs for training")
732 | train_parser.add_argument("--tv-weight", type=float, default=TV_WEIGHT, help="Weight of TV regularization")
733 | train_parser.add_argument("--model", type=str, default=None, help="Path to a saved model to use "
734 | "for perceptual loss. VGG16 as default.")
735 | train_parser.add_argument("--recordings", type=str, default=None, help="Select to use 'first' or 'last' recordings")
736 | train_parser.add_argument("--segments", type=str, default=None, help="path to a segment file")
737 | train_parser.add_argument("--visualize", type=int, default=None, help="Set to 1 if you want to visualize training")
738 | train_parser.add_argument("--input", type=str, default=REAL_PATH, help="Path to input training dataset")
739 | train_parser.add_argument("--output", type=str, default=SYNTH_PATH, help="Path to output training dataset")
740 | train_parser.add_argument("--config", type=str, nargs='*', default=["F"], help="Radar configurations to train with")
741 | train_parser.add_argument("--range", action='store_true', help="Use range information alongside doppler")
742 | train_parser.add_argument("--classes", type=int, nargs='*', default=None,
743 | help="Classes to train with, default to all")
744 |
745 | classify_parser = subparsers.add_parser("train-classify", help="train a model to classify human activity")
746 | classify_parser.add_argument("--plot", action='store_true', help="Plot and save a training report")
747 | classify_parser.add_argument("--gpu", type=int, default=None, help="ID of GPU to be used")
748 | classify_parser.add_argument("--epochs", type=int, default=EPOCHS, help="Number of epochs for training")
749 | classify_parser.add_argument("--min-acc", type=float, default=0.0, help="Retrain until min. accuracy is reached")
750 | classify_parser.add_argument("--recordings", type=str, default=None,
751 | help="Select to use 'first' or 'last' recordings")
752 | classify_parser.add_argument("--dataset", type=str, default=SYNTH_PATH, help="Path to training dataset")
753 | classify_parser.add_argument("--config", type=str, nargs='*', default=["E", "F"],
754 | help="Radar configurations to train with")
755 | classify_parser.add_argument("--range", action='store_true', help="Use range information alongside doppler")
756 | classify_parser.add_argument("--no-split", action='store_true', help="Do not split data into test/train")
757 |
758 | test_parser = subparsers.add_parser("test", help="test a model to apply human activity classification")
759 | test_parser.add_argument("--gpu", type=int, default=None, help="ID of GPU to be used")
760 | test_parser.add_argument("--visualize", type=int, default=0,
761 | help="Number of misclassified spectrograms to show")
762 | test_parser.add_argument("--transformer", type=str, default=None, help="Path to a saved model to use "
763 | "for image transform")
764 | test_parser.add_argument("--classifier", type=str, required=True, help="Path to a saved model to use "
765 | "for classification.")
766 | test_parser.add_argument("--recordings", type=str, default=None, help="Select to use 'first' or 'last' recordings")
767 | test_parser.add_argument("--segments", type=str, default=None, help="path to a segment file")
768 | test_parser.add_argument("--dataset", type=str, default=REAL_PATH, help="Path to the dataset "
769 | "to feed the transformer")
770 | test_parser.add_argument("--config", type=str, nargs='*', default=["E"], help="Radar configurations to test with")
771 | test_parser.add_argument("--range", action='store_true', help="Use range information alongside doppler")
772 |
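    # Hedged example invocations (assumption; log IDs, paths and configurations are placeholders):
    #   python main.py --log cls_synth train-classify --gpu 0 --config E F
    #   python main.py --log transfer_F train-transfer --gpu 0 --config F --model models/cls_synth.model
    #   python main.py --log eval test --gpu 0 --config E \
    #       --classifier models/cls_synth.model --transformer models/transfer_F.model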
773 | args = parser.parse_args()
774 |
775 | params = {"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S")}
776 | print(f"Process started at {params['timestamp']}")
777 | if args.log is None:
778 | args.log = params["timestamp"].replace(':', '')
779 | params.update(vars(args))
780 | save_params(params)
781 |
782 | # command
783 | if args.subcommand == "train-transfer":
784 | print("Training image transfer!")
785 | train_transfer(args)
786 | elif args.subcommand == "train-classify":
787 | print("Training classifier!")
788 |         acc = train_classify(args)
789 |         # retrain until the minimum accuracy given by --min-acc is reached
790 | while acc < args.min_acc:
791 | acc = train_classify(args)
792 | elif args.subcommand == "test":
793 | print("Testing!")
794 | acc = test(args)
795 | params["accuracy"] = acc
796 | save_params(params)
797 | else:
798 | print("invalid command")
799 |
800 |
801 | if __name__ == '__main__':
802 | main()
803 |
--------------------------------------------------------------------------------