├── LICENSE ├── README.md ├── conda_env ├── env_complete.yml ├── env_explicit.txt └── env_short.yml ├── configs ├── data │ ├── ant_fspl_floor.yaml │ ├── ant_fspl_floor_top.yaml │ ├── ant_fspl_slices_4m.yaml │ ├── ant_gain_floor.yaml │ ├── ant_slices_1m.yaml │ ├── ant_slices_4m.yaml │ ├── base.yaml │ ├── base_cyl.yaml │ ├── base_cyl_los_binary.yaml │ ├── base_cyl_los_min.yaml │ ├── base_eucl.yaml │ ├── base_eucl_los_binary.yaml │ ├── base_eucl_los_min.yaml │ ├── base_ga.yaml │ ├── base_ga_los_binary.yaml │ ├── base_ga_los_min.yaml │ ├── base_los_binary.yaml │ ├── base_los_min.yaml │ ├── base_slices.yaml │ ├── base_sph.yaml │ ├── base_sph_los_binary.yaml │ ├── base_sph_los_min.yaml │ ├── img.yaml │ ├── img_ga_azi_dist.yaml │ ├── img_ndsm.yaml │ ├── img_ndsm_ga_azi_dist.yaml │ └── img_rgb.yaml └── training │ └── lower_lr.yaml ├── lib ├── PMNet │ └── PMNet.py ├── RadioUNet │ └── RadioUNet.py ├── __init__.py ├── dcn.py ├── pl_callbacks.py ├── pl_datamodule.py ├── pl_lightningmodule.py ├── torch_datasets.py ├── torch_layers.py └── utils_coords.py ├── main_cli.py └── sample.png /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Fabian Jaensch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Radio Map Estimation - An Open Dataset with Directive Transmitter Antennas and Initial Experiments 2 | 3 | This is the official implementation of our experiments on learning radio map estimation with CNNs, described in ["Radio Map Estimation - An Open Dataset with Directive Transmitter Antennas and Initial Experiments"](https://arxiv.org/abs/2402.00878). 4 | 5 | ![alt text](sample.png "Sample") 6 | 7 | ## Requirements 8 | 9 | The dataset can be downloaded from [zenodo](https://zenodo.org/uploads/10210089) and is expected to be unpacked to the directory *./dataset*. 10 | 11 | To install the required packages via conda run: 12 | 13 | ``` 14 | conda env create -f conda_env/env_complete.yml 15 | ``` 16 | 17 | The environment has been used on Linux computers with CUDA 11.8 and A100 GPUs. On different OS/hardware, you may need to use the less restrictive file [conda_env/env_short.yml](conda_env/env_short.yml) or adjust some packages. 
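For example, the environment (named `rmbd_env` in [conda_env/env_complete.yml](conda_env/env_complete.yml)) can be created and activated with:

```
conda env create -f conda_env/env_complete.yml
conda activate rmbd_env
```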
18 | 19 | ## Basic Usage 20 | 21 | To replicate the experiments from the paper, run this command: 22 | 23 | ``` 24 | python main_cli.py fit --model=<model> --config=<path_to_config> 25 | ``` 26 | 27 | Here, `<model>` can be any of _LitRadioUNet_, _LitPMNet_ or _LitUNetDCN_, and the configs for the dataset class corresponding to our experiments can be found in the directory [configs/data](configs/data). The training procedure will save the results, including a model checkpoint, log file, config and TensorBoard log, in a subdirectory of [./logs](./logs). 28 | 29 | Instead of training from scratch, you can [download](https://zenodo.org/uploads/10210089) the checkpoints and configs for some of the trained models. 30 | 31 | Trained models can be evaluated on the test set by running 32 | 33 | ``` 34 | python main_cli.py test --config=<path_to_config> --ckpt_path=<path_to_checkpoint> --trainer.logger.sub_dir=test 35 | ``` 36 | 37 | and inference on the test set is possible with: 38 | 39 | ``` 40 | python main_cli.py predict --config=<path_to_config> --ckpt_path=<path_to_checkpoint> --trainer.logger.sub_dir=predict 41 | ``` 42 | 43 | ## More options 44 | 45 | Arguments for the dataset class (inputs for the model) and hyperparameters of the models and the training procedure can be set with flags. To get an overview of all available options, run: 46 | 47 | ``` 48 | python main_cli.py fit --help 49 | python main_cli.py fit --model.help <model> 50 | python main_cli.py fit --data.help LitRM_directional 51 | ``` 52 | 53 | More information can be found in the documentation of [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) and in particular the [CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html#lightning-cli). 54 | 55 | -------------------------------------------------------------------------------- /conda_env/env_complete.yml: -------------------------------------------------------------------------------- 1 | name: rmbd_env 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | - conda-forge 7 | dependencies: 8 | - _libgcc_mutex=0.1 9 | - _openmp_mutex=4.5 10 | - absl-py=1.4.0 11 | - aiohttp=3.8.4 12 | - aiosignal=1.3.1 13 | - alsa-lib=1.2.8 14 | - aom=3.5.0 15 | - asttokens=2.2.1 16 | - async-timeout=4.0.2 17 | - attr=2.5.1 18 | - attrs=23.1.0 19 | - backcall=0.2.0 20 | - backports=1.0 21 | - backports.functools_lru_cache=1.6.4 22 | - blas=1.0 23 | - blinker=1.6.2 24 | - blosc=1.21.4 25 | - brotli=1.0.9 26 | - brotli-bin=1.0.9 27 | - brotlipy=0.7.0 28 | - brunsli=0.1 29 | - bzip2=1.0.8 30 | - c-ares=1.19.1 31 | - c-blosc2=2.9.2 32 | - ca-certificates=2023.7.22 33 | - cachetools=5.3.0 34 | - cairo=1.16.0 35 | - certifi=2023.7.22 36 | - cffi=1.15.1 37 | - cfitsio=4.2.0 38 | - charls=2.4.2 39 | - charset-normalizer=3.1.0 40 | - click=8.1.3 41 | - cloudpickle=2.2.1 42 | - colorama=0.4.6 43 | - comm=0.1.3 44 | - contextlib2=21.6.0 45 | - contourpy=1.1.0 46 | - cryptography=41.0.1 47 | - cuda-cudart=11.8.89 48 | - cuda-cupti=11.8.87 49 | - cuda-libraries=11.8.0 50 | - cuda-nvrtc=11.8.89 51 | - cuda-nvtx=11.8.86 52 | - cuda-runtime=11.8.0 53 | - cycler=0.11.0 54 | - cytoolz=0.12.0 55 | - dask-core=2023.6.0 56 | - dav1d=1.2.1 57 | - dbus=1.13.6 58 | - debugpy=1.6.7 59 | - decorator=5.1.1 60 | - einops=0.6.1 61 | - executing=1.2.0 62 | - expat=2.5.0 63 | - ffmpeg=4.3 64 | - fftw=3.3.10 65 | - filelock=3.12.2 66 | - font-ttf-dejavu-sans-mono=2.37 67 | - font-ttf-inconsolata=3.000 68 | - font-ttf-source-code-pro=2.038 69 | - font-ttf-ubuntu=0.83 70 | - fontconfig=2.14.2 71 | - fonts-conda-ecosystem=1 72 | - fonts-conda-forge=1 73 | - fonttools=4.40.0 74 | - freetype=2.12.1 75 | - 
frozenlist=1.3.3 76 | - fsspec=2023.6.0 77 | - gettext=0.21.1 78 | - giflib=5.2.1 79 | - glib=2.76.3 80 | - glib-tools=2.76.3 81 | - gmp=6.2.1 82 | - gmpy2=2.1.2 83 | - gnutls=3.6.13 84 | - google-auth=2.20.0 85 | - google-auth-oauthlib=1.0.0 86 | - graphite2=1.3.13 87 | - grpcio=1.55.1 88 | - gst-plugins-base=1.22.0 89 | - gstreamer=1.22.0 90 | - gstreamer-orc=0.4.34 91 | - harfbuzz=6.0.0 92 | - icu=70.1 93 | - idna=3.4 94 | - imagecodecs=2023.1.23 95 | - imageio=2.31.1 96 | - importlib-metadata=6.7.0 97 | - importlib_metadata=6.7.0 98 | - intel-openmp=2022.1.0 99 | - ipykernel=6.23.1 100 | - ipython=8.14.0 101 | - jack=1.9.22 102 | - jedi=0.18.2 103 | - jinja2=3.1.2 104 | - jpeg=9e 105 | - jupyter_client=8.2.0 106 | - jupyter_core=5.3.1 107 | - jxrlib=1.1 108 | - keyutils=1.6.1 109 | - kiwisolver=1.4.4 110 | - krb5=1.20.1 111 | - lame=3.100 112 | - lazy_loader=0.2 113 | - lcms2=2.15 114 | - ld_impl_linux-64=2.40 115 | - lerc=4.0.0 116 | - libabseil=20230125.2 117 | - libaec=1.0.6 118 | - libavif=0.11.1 119 | - libblas=3.9.0 120 | - libbrotlicommon=1.0.9 121 | - libbrotlidec=1.0.9 122 | - libbrotlienc=1.0.9 123 | - libcap=2.67 124 | - libcblas=3.9.0 125 | - libclang=15.0.7 126 | - libclang13=15.0.7 127 | - libcublas=11.11.3.6 128 | - libcufft=10.9.0.58 129 | - libcufile=1.6.1.9 130 | - libcups=2.3.3 131 | - libcurand=10.3.2.106 132 | - libcurl=8.1.2 133 | - libcusolver=11.4.1.48 134 | - libcusparse=11.7.5.86 135 | - libdb=6.2.32 136 | - libdeflate=1.17 137 | - libedit=3.1.20191231 138 | - libev=4.33 139 | - libevent=2.1.10 140 | - libexpat=2.5.0 141 | - libffi=3.4.2 142 | - libflac=1.4.2 143 | - libgcc-ng=13.1.0 144 | - libgcrypt=1.10.1 145 | - libgfortran-ng=13.1.0 146 | - libgfortran5=13.1.0 147 | - libglib=2.76.3 148 | - libgomp=13.1.0 149 | - libgpg-error=1.47 150 | - libgrpc=1.55.1 151 | - libiconv=1.17 152 | - liblapack=3.9.0 153 | - libllvm15=15.0.7 154 | - libnghttp2=1.52.0 155 | - libnpp=11.8.0.86 156 | - libnsl=2.0.0 157 | - libnvjpeg=11.9.0.86 158 | - libogg=1.3.4 159 | - libopus=1.3.1 160 | - libpng=1.6.39 161 | - libpq=15.3 162 | - libprotobuf=4.23.2 163 | - libsndfile=1.2.0 164 | - libsodium=1.0.18 165 | - libsqlite=3.42.0 166 | - libssh2=1.11.0 167 | - libstdcxx-ng=13.1.0 168 | - libsystemd0=253 169 | - libtiff=4.5.0 170 | - libtool=2.4.7 171 | - libudev1=253 172 | - libuuid=2.38.1 173 | - libvorbis=1.3.7 174 | - libwebp-base=1.3.0 175 | - libxcb=1.13 176 | - libxkbcommon=1.5.0 177 | - libxml2=2.10.3 178 | - libzlib=1.2.13 179 | - libzopfli=1.0.3 180 | - lightning-utilities=0.8.0 181 | - locket=1.0.0 182 | - lz4-c=1.9.4 183 | - markdown=3.4.3 184 | - markupsafe=2.1.3 185 | - matplotlib=3.7.1 186 | - matplotlib-base=3.7.1 187 | - matplotlib-inline=0.1.6 188 | - mkl=2022.1.0 189 | - ml-collections=0.1.1 190 | - mpc=1.3.1 191 | - mpfr=4.2.0 192 | - mpg123=1.31.3 193 | - mpmath=1.3.0 194 | - multidict=6.0.4 195 | - munkres=1.1.4 196 | - mysql-common=8.0.33 197 | - mysql-libs=8.0.33 198 | - ncurses=6.4 199 | - nest-asyncio=1.5.6 200 | - nettle=3.6 201 | - networkx=3.1 202 | - nspr=4.35 203 | - nss=3.89 204 | - numpy=1.25.0 205 | - oauthlib=3.2.2 206 | - openh264=2.1.1 207 | - openjpeg=2.5.0 208 | - openssl=3.1.4 209 | - packaging=23.1 210 | - pandas=2.1.3 211 | - parso=0.8.3 212 | - partd=1.4.0 213 | - pcre2=10.40 214 | - pexpect=4.8.0 215 | - pickleshare=0.7.5 216 | - pillow=9.4.0 217 | - pip=23.1.2 218 | - pixman=0.40.0 219 | - platformdirs=3.6.0 220 | - ply=3.11 221 | - pooch=1.7.0 222 | - prompt-toolkit=3.0.38 223 | - prompt_toolkit=3.0.38 224 | - protobuf=4.23.2 225 | - 
psutil=5.9.5 226 | - pthread-stubs=0.4 227 | - ptyprocess=0.7.0 228 | - pulseaudio=16.1 229 | - pulseaudio-client=16.1 230 | - pulseaudio-daemon=16.1 231 | - pure_eval=0.2.2 232 | - pyasn1=0.4.8 233 | - pyasn1-modules=0.2.7 234 | - pycparser=2.21 235 | - pygments=2.15.1 236 | - pyjwt=2.7.0 237 | - pyopenssl=23.2.0 238 | - pyparsing=3.1.0 239 | - pyqt=5.15.7 240 | - pyqt5-sip=12.11.0 241 | - pysocks=1.7.1 242 | - python=3.10.11 243 | - python-dateutil=2.8.2 244 | - python-tzdata=2023.3 245 | - python_abi=3.10 246 | - pytorch=2.0.1 247 | - pytorch-cuda=11.8 248 | - pytorch-lightning=2.0.3 249 | - pytorch-mutex=1.0 250 | - pytz=2023.3.post1 251 | - pyu2f=0.1.5 252 | - pywavelets=1.4.1 253 | - pyyaml=6.0 254 | - pyzmq=25.1.0 255 | - qt-main=5.15.8 256 | - re2=2023.03.02 257 | - readline=8.2 258 | - requests=2.31.0 259 | - requests-oauthlib=1.3.1 260 | - rsa=4.9 261 | - scikit-image=0.20.0 262 | - scipy=1.10.1 263 | - setuptools=67.7.2 264 | - sip=6.7.9 265 | - six=1.16.0 266 | - snappy=1.1.10 267 | - stack_data=0.6.2 268 | - sympy=1.12 269 | - tensorboard=2.13.0 270 | - tensorboard-data-server=0.7.0 271 | - tifffile=2023.4.12 272 | - tk=8.6.12 273 | - toml=0.10.2 274 | - tomli=2.0.1 275 | - toolz=0.12.0 276 | - torchinfo=1.8.0 277 | - torchmetrics=0.11.4 278 | - torchtriton=2.0.0 279 | - torchvision=0.15.2 280 | - tornado=6.3.2 281 | - tqdm=4.65.0 282 | - traitlets=5.9.0 283 | - typing-extensions=4.6.3 284 | - typing_extensions=4.6.3 285 | - tzdata=2023c 286 | - unicodedata2=15.0.0 287 | - urllib3=1.26.15 288 | - wcwidth=0.2.6 289 | - werkzeug=2.3.6 290 | - wheel=0.40.0 291 | - xcb-util=0.4.0 292 | - xcb-util-image=0.4.0 293 | - xcb-util-keysyms=0.4.0 294 | - xcb-util-renderutil=0.3.9 295 | - xcb-util-wm=0.4.1 296 | - xkeyboard-config=2.38 297 | - xorg-kbproto=1.0.7 298 | - xorg-libice=1.1.1 299 | - xorg-libsm=1.2.4 300 | - xorg-libx11=1.8.4 301 | - xorg-libxau=1.0.11 302 | - xorg-libxdmcp=1.1.3 303 | - xorg-libxext=1.3.4 304 | - xorg-libxrender=0.9.10 305 | - xorg-renderproto=0.11.1 306 | - xorg-xextproto=7.3.0 307 | - xorg-xproto=7.0.31 308 | - xz=5.2.6 309 | - yaml=0.2.5 310 | - yarl=1.9.2 311 | - zeromq=4.3.4 312 | - zfp=1.0.0 313 | - zipp=3.15.0 314 | - zlib=1.2.13 315 | - zlib-ng=2.0.7 316 | - zstd=1.5.2 317 | - pip: 318 | - docstring-parser==0.15 319 | - importlib-resources==5.12.0 320 | - jsonargparse==4.21.2 321 | - nvidia-htop==1.0.5 322 | - termcolor==2.3.0 323 | - typeshed-client==2.3.0 324 | -------------------------------------------------------------------------------- /conda_env/env_explicit.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | @EXPLICIT 5 | https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 6 | https://repo.anaconda.com/pkgs/main/linux-64/blas-1.0-mkl.conda 7 | https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.7.22-hbcca054_0.conda 8 | https://conda.anaconda.org/nvidia/linux-64/cuda-cudart-11.8.89-0.tar.bz2 9 | https://conda.anaconda.org/nvidia/linux-64/cuda-cupti-11.8.87-0.tar.bz2 10 | https://conda.anaconda.org/nvidia/linux-64/cuda-nvrtc-11.8.89-0.tar.bz2 11 | https://conda.anaconda.org/nvidia/linux-64/cuda-nvtx-11.8.86-0.tar.bz2 12 | https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2 13 | https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2 14 | 
https://conda.anaconda.org/conda-forge/noarch/font-ttf-source-code-pro-2.038-h77eed37_0.tar.bz2 15 | https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-hab24e00_0.tar.bz2 16 | https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.40-h41732ed_0.conda 17 | https://conda.anaconda.org/nvidia/linux-64/libcublas-11.11.3.6-0.tar.bz2 18 | https://conda.anaconda.org/nvidia/linux-64/libcufft-10.9.0.58-0.tar.bz2 19 | https://conda.anaconda.org/nvidia/linux-64/libcufile-1.6.1.9-0.tar.bz2 20 | https://conda.anaconda.org/nvidia/linux-64/libcurand-10.3.2.106-0.tar.bz2 21 | https://conda.anaconda.org/nvidia/linux-64/libcusolver-11.4.1.48-0.tar.bz2 22 | https://conda.anaconda.org/nvidia/linux-64/libcusparse-11.7.5.86-0.tar.bz2 23 | https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.1.0-h15d22d2_0.conda 24 | https://conda.anaconda.org/nvidia/linux-64/libnpp-11.8.0.86-0.tar.bz2 25 | https://conda.anaconda.org/nvidia/linux-64/libnvjpeg-11.9.0.86-0.tar.bz2 26 | https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.1.0-hfd8a6a1_0.conda 27 | https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.10-3_cp310.conda 28 | https://conda.anaconda.org/pytorch/noarch/pytorch-mutex-1.0-cuda.tar.bz2 29 | https://conda.anaconda.org/conda-forge/noarch/tzdata-2023c-h71feb2d_0.conda 30 | https://conda.anaconda.org/nvidia/linux-64/cuda-libraries-11.8.0-0.tar.bz2 31 | https://conda.anaconda.org/conda-forge/noarch/fonts-conda-forge-1-0.tar.bz2 32 | https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.1.0-h69a702a_0.conda 33 | https://conda.anaconda.org/conda-forge/linux-64/libgomp-13.1.0-he5830b7_0.conda 34 | https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 35 | https://conda.anaconda.org/nvidia/linux-64/cuda-runtime-11.8.0-0.tar.bz2 36 | https://conda.anaconda.org/conda-forge/noarch/fonts-conda-ecosystem-1-0.tar.bz2 37 | https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.1.0-he5830b7_0.conda 38 | https://conda.anaconda.org/pytorch/linux-64/pytorch-cuda-11.8-h7e8668a_5.tar.bz2 39 | https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.8-h166bdaf_0.tar.bz2 40 | https://conda.anaconda.org/conda-forge/linux-64/aom-3.5.0-h27087fc_0.tar.bz2 41 | https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2 42 | https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-h7f98852_4.tar.bz2 43 | https://conda.anaconda.org/conda-forge/linux-64/c-ares-1.19.1-hd590300_0.conda 44 | https://conda.anaconda.org/conda-forge/linux-64/charls-2.4.2-h59595ed_0.conda 45 | https://conda.anaconda.org/conda-forge/linux-64/dav1d-1.2.1-hd590300_0.conda 46 | https://conda.anaconda.org/conda-forge/linux-64/fftw-3.3.10-nompi_hc118613_108.conda 47 | https://conda.anaconda.org/conda-forge/linux-64/gettext-0.21.1-h27087fc_0.tar.bz2 48 | https://conda.anaconda.org/conda-forge/linux-64/giflib-5.2.1-h0b41bf4_3.conda 49 | https://conda.anaconda.org/conda-forge/linux-64/gmp-6.2.1-h58526e2_0.tar.bz2 50 | https://conda.anaconda.org/conda-forge/linux-64/graphite2-1.3.13-h58526e2_1001.tar.bz2 51 | https://conda.anaconda.org/conda-forge/linux-64/gstreamer-orc-0.4.34-hd590300_0.conda 52 | https://conda.anaconda.org/conda-forge/linux-64/icu-70.1-h27087fc_0.tar.bz2 53 | https://conda.anaconda.org/conda-forge/linux-64/jpeg-9e-h0b41bf4_3.conda 54 | https://conda.anaconda.org/conda-forge/linux-64/jxrlib-1.1-h7f98852_2.tar.bz2 55 | https://conda.anaconda.org/conda-forge/linux-64/keyutils-1.6.1-h166bdaf_0.tar.bz2 56 | 
https://conda.anaconda.org/conda-forge/linux-64/lame-3.100-h166bdaf_1003.tar.bz2 57 | https://conda.anaconda.org/conda-forge/linux-64/lerc-4.0.0-h27087fc_0.tar.bz2 58 | https://conda.anaconda.org/conda-forge/linux-64/libabseil-20230125.2-cxx17_h59595ed_2.conda 59 | https://conda.anaconda.org/conda-forge/linux-64/libaec-1.0.6-hcb278e6_1.conda 60 | https://conda.anaconda.org/conda-forge/linux-64/libbrotlicommon-1.0.9-h166bdaf_8.tar.bz2 61 | https://conda.anaconda.org/conda-forge/linux-64/libdb-6.2.32-h9c3ff4c_0.tar.bz2 62 | https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.17-h0b41bf4_0.conda 63 | https://conda.anaconda.org/conda-forge/linux-64/libev-4.33-h516909a_1.tar.bz2 64 | https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.5.0-hcb278e6_1.conda 65 | https://conda.anaconda.org/conda-forge/linux-64/libffi-3.4.2-h7f98852_5.tar.bz2 66 | https://conda.anaconda.org/conda-forge/linux-64/libiconv-1.17-h166bdaf_0.tar.bz2 67 | https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.0-h7f98852_0.tar.bz2 68 | https://conda.anaconda.org/conda-forge/linux-64/libogg-1.3.4-h7f98852_1.tar.bz2 69 | https://conda.anaconda.org/conda-forge/linux-64/libopus-1.3.1-h7f98852_1.tar.bz2 70 | https://conda.anaconda.org/conda-forge/linux-64/libsodium-1.0.18-h36c2ea0_1.tar.bz2 71 | https://conda.anaconda.org/conda-forge/linux-64/libtool-2.4.7-h27087fc_0.conda 72 | https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda 73 | https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.3.0-h0b41bf4_0.conda 74 | https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-hd590300_5.conda 75 | https://conda.anaconda.org/conda-forge/linux-64/libzopfli-1.0.3-h9c3ff4c_0.tar.bz2 76 | https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.9.4-hcb278e6_0.conda 77 | https://conda.anaconda.org/conda-forge/linux-64/mpg123-1.31.3-hcb278e6_0.conda 78 | https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-hcb278e6_0.conda 79 | https://conda.anaconda.org/conda-forge/linux-64/nettle-3.6-he412f7d_0.tar.bz2 80 | https://conda.anaconda.org/conda-forge/linux-64/nspr-4.35-h27087fc_0.conda 81 | https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.4-hd590300_0.conda 82 | https://conda.anaconda.org/conda-forge/linux-64/pixman-0.40.0-h36c2ea0_0.tar.bz2 83 | https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-h36c2ea0_1001.tar.bz2 84 | https://conda.anaconda.org/conda-forge/linux-64/re2-2023.03.02-h8c504da_0.conda 85 | https://conda.anaconda.org/conda-forge/linux-64/snappy-1.1.10-h9fff704_0.conda 86 | https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.38-h0b41bf4_0.conda 87 | https://conda.anaconda.org/conda-forge/linux-64/xorg-kbproto-1.0.7-h7f98852_1002.tar.bz2 88 | https://conda.anaconda.org/conda-forge/linux-64/xorg-libice-1.1.1-hd590300_0.conda 89 | https://conda.anaconda.org/conda-forge/linux-64/xorg-libxau-1.0.11-hd590300_0.conda 90 | https://conda.anaconda.org/conda-forge/linux-64/xorg-libxdmcp-1.1.3-h7f98852_0.tar.bz2 91 | https://conda.anaconda.org/conda-forge/linux-64/xorg-renderproto-0.11.1-h7f98852_1002.tar.bz2 92 | https://conda.anaconda.org/conda-forge/linux-64/xorg-xextproto-7.3.0-h0b41bf4_1003.conda 93 | https://conda.anaconda.org/conda-forge/linux-64/xorg-xproto-7.0.31-h7f98852_1007.tar.bz2 94 | https://conda.anaconda.org/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2 95 | https://conda.anaconda.org/conda-forge/linux-64/yaml-0.2.5-h7f98852_2.tar.bz2 96 | https://conda.anaconda.org/conda-forge/linux-64/zfp-1.0.0-h27087fc_3.tar.bz2 
97 | https://conda.anaconda.org/conda-forge/linux-64/zlib-ng-2.0.7-h0b41bf4_0.conda 98 | https://conda.anaconda.org/conda-forge/linux-64/expat-2.5.0-hcb278e6_1.conda 99 | https://conda.anaconda.org/conda-forge/linux-64/gnutls-3.6.13-h85f3911_1.tar.bz2 100 | https://conda.anaconda.org/conda-forge/linux-64/jack-1.9.22-h11f4161_0.conda 101 | https://conda.anaconda.org/conda-forge/linux-64/libavif-0.11.1-h8182462_2.conda 102 | https://conda.anaconda.org/conda-forge/linux-64/libbrotlidec-1.0.9-h166bdaf_8.tar.bz2 103 | https://conda.anaconda.org/conda-forge/linux-64/libbrotlienc-1.0.9-h166bdaf_8.tar.bz2 104 | https://conda.anaconda.org/conda-forge/linux-64/libcap-2.67-he9d0100_0.conda 105 | https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20191231-he28a2e2_2.tar.bz2 106 | https://conda.anaconda.org/conda-forge/linux-64/libevent-2.1.10-h28343ad_4.tar.bz2 107 | https://conda.anaconda.org/conda-forge/linux-64/libflac-1.4.2-h27087fc_0.tar.bz2 108 | https://conda.anaconda.org/conda-forge/linux-64/libgpg-error-1.47-h71f35ed_0.conda 109 | https://conda.anaconda.org/conda-forge/linux-64/libnghttp2-1.52.0-h61bc06f_0.conda 110 | https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.39-h753d276_0.conda 111 | https://conda.anaconda.org/conda-forge/linux-64/libprotobuf-4.23.2-hd1fb520_5.conda 112 | https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.42.0-h2797004_0.conda 113 | https://conda.anaconda.org/conda-forge/linux-64/libssh2-1.11.0-h0841786_0.conda 114 | https://conda.anaconda.org/conda-forge/linux-64/libvorbis-1.3.7-h9c3ff4c_0.tar.bz2 115 | https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.13-h7f98852_1004.tar.bz2 116 | https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.10.3-hca2bb57_4.conda 117 | https://conda.anaconda.org/conda-forge/linux-64/mpfr-4.2.0-hb012696_0.conda 118 | https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_0.conda 119 | https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.40-hc3806b6_0.tar.bz2 120 | https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda 121 | https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.12-h27826a3_0.tar.bz2 122 | https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.4-h7391055_0.conda 123 | https://conda.anaconda.org/conda-forge/linux-64/zeromq-4.3.4-h9c3ff4c_1.tar.bz2 124 | https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.13-hd590300_5.conda 125 | https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.2-h3eb15da_6.conda 126 | https://conda.anaconda.org/conda-forge/linux-64/blosc-1.21.4-h0f2a231_0.conda 127 | https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.0.9-h166bdaf_8.tar.bz2 128 | https://conda.anaconda.org/conda-forge/linux-64/c-blosc2-2.9.2-hb4ffafa_0.conda 129 | https://conda.anaconda.org/conda-forge/linux-64/freetype-2.12.1-hca18f0e_1.conda 130 | https://repo.anaconda.com/pkgs/main/linux-64/intel-openmp-2022.1.0-h9e868ea_3769.conda 131 | https://conda.anaconda.org/conda-forge/linux-64/krb5-1.20.1-h81ceb04_0.conda 132 | https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.1-h166bdaf_0.tar.bz2 133 | https://conda.anaconda.org/conda-forge/linux-64/libglib-2.76.3-hebfc3b9_0.conda 134 | https://conda.anaconda.org/conda-forge/linux-64/libgrpc-1.55.1-h59456c1_1.conda 135 | https://conda.anaconda.org/conda-forge/linux-64/libllvm15-15.0.7-hadd5161_1.conda 136 | https://conda.anaconda.org/conda-forge/linux-64/libsndfile-1.2.0-hb75c966_0.conda 137 | 
https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.5.0-h6adf6a1_2.conda 138 | https://conda.anaconda.org/conda-forge/linux-64/libudev1-253-h0b41bf4_1.conda 139 | https://conda.anaconda.org/conda-forge/linux-64/libxkbcommon-1.5.0-h79f4944_1.conda 140 | https://conda.anaconda.org/conda-forge/linux-64/mpc-1.3.1-hfe3b2da_0.conda 141 | https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_0.conda 142 | https://conda.anaconda.org/conda-forge/linux-64/nss-3.89-he45b914_0.conda 143 | https://conda.anaconda.org/conda-forge/linux-64/openh264-2.1.1-h780b84a_0.tar.bz2 144 | https://conda.anaconda.org/conda-forge/linux-64/python-3.10.11-he550d4f_0_cpython.conda 145 | https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.0-h516909a_0.tar.bz2 146 | https://conda.anaconda.org/conda-forge/linux-64/xcb-util-keysyms-0.4.0-h516909a_0.tar.bz2 147 | https://conda.anaconda.org/conda-forge/linux-64/xcb-util-renderutil-0.3.9-h166bdaf_0.tar.bz2 148 | https://conda.anaconda.org/conda-forge/linux-64/xcb-util-wm-0.4.1-h516909a_0.tar.bz2 149 | https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.4-h0b41bf4_0.conda 150 | https://conda.anaconda.org/conda-forge/noarch/absl-py-1.4.0-pyhd8ed1ab_0.conda 151 | https://conda.anaconda.org/conda-forge/noarch/attrs-23.1.0-pyh71513ae_1.conda 152 | https://conda.anaconda.org/conda-forge/noarch/backcall-0.2.0-pyh9f0ad1d_0.tar.bz2 153 | https://conda.anaconda.org/conda-forge/noarch/backports-1.0-pyhd8ed1ab_3.conda 154 | https://conda.anaconda.org/conda-forge/noarch/blinker-1.6.2-pyhd8ed1ab_0.conda 155 | https://conda.anaconda.org/conda-forge/linux-64/brotli-1.0.9-h166bdaf_8.tar.bz2 156 | https://conda.anaconda.org/conda-forge/noarch/cachetools-5.3.0-pyhd8ed1ab_0.conda 157 | https://conda.anaconda.org/conda-forge/noarch/certifi-2023.7.22-pyhd8ed1ab_0.conda 158 | https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.1.0-pyhd8ed1ab_0.conda 159 | https://conda.anaconda.org/conda-forge/noarch/click-8.1.3-unix_pyhd8ed1ab_2.tar.bz2 160 | https://conda.anaconda.org/conda-forge/noarch/cloudpickle-2.2.1-pyhd8ed1ab_0.conda 161 | https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2 162 | https://conda.anaconda.org/conda-forge/noarch/contextlib2-21.6.0-pyhd8ed1ab_0.tar.bz2 163 | https://conda.anaconda.org/conda-forge/noarch/cycler-0.11.0-pyhd8ed1ab_0.tar.bz2 164 | https://conda.anaconda.org/conda-forge/linux-64/dbus-1.13.6-h5008d03_3.tar.bz2 165 | https://conda.anaconda.org/conda-forge/linux-64/debugpy-1.6.7-py310heca2aa9_0.conda 166 | https://conda.anaconda.org/conda-forge/noarch/decorator-5.1.1-pyhd8ed1ab_0.tar.bz2 167 | https://conda.anaconda.org/conda-forge/noarch/einops-0.6.1-pyhd8ed1ab_0.conda 168 | https://conda.anaconda.org/conda-forge/noarch/executing-1.2.0-pyhd8ed1ab_0.tar.bz2 169 | https://conda.anaconda.org/pytorch/linux-64/ffmpeg-4.3-hf484d3e_0.tar.bz2 170 | https://conda.anaconda.org/conda-forge/noarch/filelock-3.12.2-pyhd8ed1ab_0.conda 171 | https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda 172 | https://conda.anaconda.org/conda-forge/linux-64/frozenlist-1.3.3-py310h5764c6d_0.tar.bz2 173 | https://conda.anaconda.org/conda-forge/noarch/fsspec-2023.6.0-pyh1a96a4e_0.conda 174 | https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.76.3-hfc55251_0.conda 175 | https://conda.anaconda.org/conda-forge/linux-64/gmpy2-2.1.2-py310h3ec546c_1.tar.bz2 176 | https://conda.anaconda.org/conda-forge/linux-64/grpcio-1.55.1-py310h1b8f574_1.conda 177 | 
https://conda.anaconda.org/conda-forge/noarch/idna-3.4-pyhd8ed1ab_0.tar.bz2 178 | https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.4-py310hbf28c38_1.tar.bz2 179 | https://conda.anaconda.org/conda-forge/noarch/lazy_loader-0.2-pyhd8ed1ab_0.conda 180 | https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.15-hfd0df8a_0.conda 181 | https://conda.anaconda.org/conda-forge/linux-64/libclang13-15.0.7-default_h9986a30_2.conda 182 | https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h36d4200_3.conda 183 | https://conda.anaconda.org/conda-forge/linux-64/libcurl-8.1.2-h409715c_0.conda 184 | https://conda.anaconda.org/conda-forge/linux-64/libpq-15.3-hbcd7760_1.conda 185 | https://conda.anaconda.org/conda-forge/linux-64/libsystemd0-253-h8c4010b_1.conda 186 | https://conda.anaconda.org/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 187 | https://conda.anaconda.org/conda-forge/linux-64/markupsafe-2.1.3-py310h2372a71_0.conda 188 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-2022.1.0-hc2b9512_224.conda 189 | https://conda.anaconda.org/conda-forge/noarch/mpmath-1.3.0-pyhd8ed1ab_0.conda 190 | https://conda.anaconda.org/conda-forge/linux-64/multidict-6.0.4-py310h1fa729e_0.conda 191 | https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2 192 | https://conda.anaconda.org/conda-forge/noarch/nest-asyncio-1.5.6-pyhd8ed1ab_0.tar.bz2 193 | https://conda.anaconda.org/conda-forge/noarch/networkx-3.1-pyhd8ed1ab_0.conda 194 | https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.0-hfec8fc6_2.conda 195 | https://conda.anaconda.org/conda-forge/noarch/packaging-23.1-pyhd8ed1ab_0.conda 196 | https://conda.anaconda.org/conda-forge/noarch/parso-0.8.3-pyhd8ed1ab_0.tar.bz2 197 | https://conda.anaconda.org/conda-forge/noarch/pickleshare-0.7.5-py_1003.tar.bz2 198 | https://conda.anaconda.org/conda-forge/noarch/ply-3.11-py_1.tar.bz2 199 | https://conda.anaconda.org/conda-forge/linux-64/psutil-5.9.5-py310h1fa729e_0.conda 200 | https://conda.anaconda.org/conda-forge/noarch/ptyprocess-0.7.0-pyhd3deb0d_0.tar.bz2 201 | https://conda.anaconda.org/conda-forge/noarch/pure_eval-0.2.2-pyhd8ed1ab_0.tar.bz2 202 | https://repo.anaconda.com/pkgs/main/noarch/pyasn1-0.4.8-py_0.conda 203 | https://conda.anaconda.org/conda-forge/noarch/pycparser-2.21-pyhd8ed1ab_0.tar.bz2 204 | https://conda.anaconda.org/conda-forge/noarch/pygments-2.15.1-pyhd8ed1ab_0.conda 205 | https://conda.anaconda.org/conda-forge/noarch/pyjwt-2.7.0-pyhd8ed1ab_0.conda 206 | https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.0-pyhd8ed1ab_0.conda 207 | https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2 208 | https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2023.3-pyhd8ed1ab_0.conda 209 | https://conda.anaconda.org/conda-forge/noarch/pytz-2023.3.post1-pyhd8ed1ab_0.conda 210 | https://conda.anaconda.org/conda-forge/linux-64/pyyaml-6.0-py310h5764c6d_5.tar.bz2 211 | https://conda.anaconda.org/conda-forge/linux-64/pyzmq-25.1.0-py310h5bbb5d0_0.conda 212 | https://conda.anaconda.org/conda-forge/noarch/setuptools-67.7.2-pyhd8ed1ab_0.conda 213 | https://conda.anaconda.org/conda-forge/noarch/six-1.16.0-pyh6c4a22f_0.tar.bz2 214 | https://conda.anaconda.org/conda-forge/linux-64/tensorboard-data-server-0.7.0-py310h34c0648_0.conda 215 | https://conda.anaconda.org/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 216 | https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2 217 | 
https://conda.anaconda.org/conda-forge/noarch/toolz-0.12.0-pyhd8ed1ab_0.tar.bz2 218 | https://conda.anaconda.org/conda-forge/linux-64/tornado-6.3.2-py310h2372a71_0.conda 219 | https://conda.anaconda.org/conda-forge/noarch/traitlets-5.9.0-pyhd8ed1ab_0.conda 220 | https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.6.3-pyha770c72_0.conda 221 | https://conda.anaconda.org/conda-forge/linux-64/unicodedata2-15.0.0-py310h5764c6d_0.tar.bz2 222 | https://conda.anaconda.org/conda-forge/noarch/wheel-0.40.0-pyhd8ed1ab_0.conda 223 | https://conda.anaconda.org/conda-forge/linux-64/xcb-util-image-0.4.0-h166bdaf_0.tar.bz2 224 | https://conda.anaconda.org/conda-forge/linux-64/xorg-libxext-1.3.4-h0b41bf4_2.conda 225 | https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrender-0.9.10-h7f98852_1003.tar.bz2 226 | https://conda.anaconda.org/conda-forge/noarch/zipp-3.15.0-pyhd8ed1ab_0.conda 227 | https://conda.anaconda.org/conda-forge/noarch/aiosignal-1.3.1-pyhd8ed1ab_0.tar.bz2 228 | https://conda.anaconda.org/conda-forge/noarch/asttokens-2.2.1-pyhd8ed1ab_0.conda 229 | https://conda.anaconda.org/conda-forge/noarch/backports.functools_lru_cache-1.6.4-pyhd8ed1ab_0.tar.bz2 230 | https://conda.anaconda.org/conda-forge/linux-64/brunsli-0.1-h9c3ff4c_0.tar.bz2 231 | https://conda.anaconda.org/conda-forge/linux-64/cairo-1.16.0-ha61ee94_1014.tar.bz2 232 | https://conda.anaconda.org/conda-forge/linux-64/cffi-1.15.1-py310h255011f_3.conda 233 | https://conda.anaconda.org/conda-forge/linux-64/cfitsio-4.2.0-hd9d235c_0.conda 234 | https://conda.anaconda.org/conda-forge/noarch/comm-0.1.3-pyhd8ed1ab_0.conda 235 | https://conda.anaconda.org/conda-forge/linux-64/cytoolz-0.12.0-py310h5764c6d_1.tar.bz2 236 | https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.40.0-py310h2372a71_0.conda 237 | https://conda.anaconda.org/conda-forge/linux-64/glib-2.76.3-hfc55251_0.conda 238 | https://conda.anaconda.org/conda-forge/noarch/importlib-metadata-6.7.0-pyha770c72_0.conda 239 | https://conda.anaconda.org/conda-forge/noarch/jedi-0.18.2-pyhd8ed1ab_0.conda 240 | https://conda.anaconda.org/conda-forge/noarch/jinja2-3.1.2-pyhd8ed1ab_1.tar.bz2 241 | https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-16_linux64_mkl.tar.bz2 242 | https://conda.anaconda.org/conda-forge/linux-64/libclang-15.0.7-default_h7634d5b_2.conda 243 | https://conda.anaconda.org/conda-forge/noarch/lightning-utilities-0.8.0-pyhd8ed1ab_0.conda 244 | https://conda.anaconda.org/conda-forge/noarch/matplotlib-inline-0.1.6-pyhd8ed1ab_0.tar.bz2 245 | https://conda.anaconda.org/conda-forge/noarch/ml-collections-0.1.1-pyhd8ed1ab_0.tar.bz2 246 | https://conda.anaconda.org/conda-forge/noarch/partd-1.4.0-pyhd8ed1ab_0.conda 247 | https://conda.anaconda.org/conda-forge/noarch/pexpect-4.8.0-pyh1a96a4e_2.tar.bz2 248 | https://conda.anaconda.org/conda-forge/linux-64/pillow-9.4.0-py310h023d228_1.conda 249 | https://conda.anaconda.org/conda-forge/noarch/pip-23.1.2-pyhd8ed1ab_0.conda 250 | https://conda.anaconda.org/conda-forge/linux-64/protobuf-4.23.2-py310hb875b13_1.conda 251 | https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-16.1-h5195f5e_3.conda 252 | https://repo.anaconda.com/pkgs/main/noarch/pyasn1-modules-0.2.7-py_0.conda 253 | https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2 254 | https://conda.anaconda.org/conda-forge/noarch/pyu2f-0.1.5-pyhd8ed1ab_0.tar.bz2 255 | https://conda.anaconda.org/conda-forge/noarch/rsa-4.9-pyhd8ed1ab_0.tar.bz2 256 | 
https://conda.anaconda.org/conda-forge/linux-64/sip-6.7.9-py310hc6cd4ac_0.conda 257 | https://conda.anaconda.org/conda-forge/noarch/sympy-1.12-pypyh9d50eac_103.conda 258 | https://conda.anaconda.org/conda-forge/noarch/tqdm-4.65.0-pyhd8ed1ab_1.conda 259 | https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.6.3-hd8ed1ab_0.conda 260 | https://conda.anaconda.org/conda-forge/noarch/werkzeug-2.3.6-pyhd8ed1ab_0.conda 261 | https://conda.anaconda.org/conda-forge/linux-64/yarl-1.9.2-py310h2372a71_0.conda 262 | https://conda.anaconda.org/conda-forge/noarch/async-timeout-4.0.2-pyhd8ed1ab_0.tar.bz2 263 | https://conda.anaconda.org/conda-forge/linux-64/brotlipy-0.7.0-py310h5764c6d_1005.tar.bz2 264 | https://conda.anaconda.org/conda-forge/linux-64/cryptography-41.0.1-py310h75e40e8_0.conda 265 | https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.0-h25f0c4b_2.conda 266 | https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-6.0.0-h8e241bc_0.conda 267 | https://conda.anaconda.org/conda-forge/noarch/importlib_metadata-6.7.0-hd8ed1ab_0.conda 268 | https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-16_linux64_mkl.tar.bz2 269 | https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-16_linux64_mkl.tar.bz2 270 | https://conda.anaconda.org/conda-forge/noarch/markdown-3.4.3-pyhd8ed1ab_0.conda 271 | https://conda.anaconda.org/conda-forge/noarch/platformdirs-3.6.0-pyhd8ed1ab_0.conda 272 | https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-daemon-16.1-ha8d29e2_3.conda 273 | https://conda.anaconda.org/conda-forge/linux-64/pyqt5-sip-12.11.0-py310heca2aa9_3.conda 274 | https://conda.anaconda.org/conda-forge/noarch/stack_data-0.6.2-pyhd8ed1ab_0.conda 275 | https://conda.anaconda.org/conda-forge/noarch/wcwidth-0.2.6-pyhd8ed1ab_0.conda 276 | https://conda.anaconda.org/conda-forge/linux-64/aiohttp-3.8.4-py310h2372a71_1.conda 277 | https://conda.anaconda.org/conda-forge/noarch/dask-core-2023.6.0-pyhd8ed1ab_0.conda 278 | https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.0-h4243ec0_2.conda 279 | https://conda.anaconda.org/conda-forge/linux-64/jupyter_core-5.3.1-py310hff52083_0.conda 280 | https://conda.anaconda.org/conda-forge/linux-64/numpy-1.25.0-py310ha4c1d20_0.conda 281 | https://conda.anaconda.org/conda-forge/noarch/oauthlib-3.2.2-pyhd8ed1ab_0.tar.bz2 282 | https://conda.anaconda.org/conda-forge/noarch/prompt-toolkit-3.0.38-pyha770c72_0.conda 283 | https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-16.1-hcb278e6_3.conda 284 | https://conda.anaconda.org/conda-forge/noarch/pyopenssl-23.2.0-pyhd8ed1ab_1.conda 285 | https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.1.0-py310hd41b1e2_0.conda 286 | https://conda.anaconda.org/conda-forge/linux-64/imagecodecs-2023.1.23-py310ha3ed6a1_0.conda 287 | https://conda.anaconda.org/conda-forge/noarch/imageio-2.31.1-pyh24c5eb1_0.conda 288 | https://conda.anaconda.org/conda-forge/noarch/jupyter_client-8.2.0-pyhd8ed1ab_0.conda 289 | https://conda.anaconda.org/conda-forge/linux-64/pandas-2.1.3-py310hcc13569_0.conda 290 | https://conda.anaconda.org/conda-forge/noarch/prompt_toolkit-3.0.38-hd8ed1ab_0.conda 291 | https://conda.anaconda.org/conda-forge/linux-64/pywavelets-1.4.1-py310h0a54255_0.conda 292 | https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h5d23da1_6.conda 293 | https://conda.anaconda.org/conda-forge/noarch/urllib3-1.26.15-pyhd8ed1ab_0.conda 294 | https://conda.anaconda.org/conda-forge/noarch/ipython-8.14.0-pyh41d4057_0.conda 295 | 
https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.7.1-py310he60537e_0.conda 296 | https://conda.anaconda.org/conda-forge/linux-64/pyqt-5.15.7-py310hab646b1_3.conda 297 | https://conda.anaconda.org/conda-forge/noarch/requests-2.31.0-pyhd8ed1ab_0.conda 298 | https://conda.anaconda.org/conda-forge/noarch/tifffile-2023.4.12-pyhd8ed1ab_0.conda 299 | https://conda.anaconda.org/conda-forge/noarch/google-auth-2.20.0-pyh1a96a4e_0.conda 300 | https://conda.anaconda.org/conda-forge/noarch/ipykernel-6.23.1-pyh210e3f2_0.conda 301 | https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.7.1-py310hff52083_0.conda 302 | https://conda.anaconda.org/conda-forge/noarch/pooch-1.7.0-pyha770c72_3.conda 303 | https://conda.anaconda.org/conda-forge/noarch/requests-oauthlib-1.3.1-pyhd8ed1ab_0.tar.bz2 304 | https://conda.anaconda.org/conda-forge/noarch/google-auth-oauthlib-1.0.0-pyhd8ed1ab_0.conda 305 | https://conda.anaconda.org/conda-forge/linux-64/scipy-1.10.1-py310ha4c1d20_3.conda 306 | https://conda.anaconda.org/conda-forge/linux-64/scikit-image-0.20.0-py310h9b08913_1.conda 307 | https://conda.anaconda.org/conda-forge/noarch/tensorboard-2.13.0-pyhd8ed1ab_0.conda 308 | https://conda.anaconda.org/conda-forge/noarch/torchinfo-1.8.0-pyhd8ed1ab_0.conda 309 | https://conda.anaconda.org/conda-forge/noarch/torchmetrics-0.11.4-pyhd8ed1ab_0.conda 310 | https://conda.anaconda.org/conda-forge/noarch/pytorch-lightning-2.0.3-pyhd8ed1ab_0.conda 311 | https://conda.anaconda.org/pytorch/linux-64/pytorch-2.0.1-py3.10_cuda11.8_cudnn8.7.0_0.tar.bz2 312 | https://conda.anaconda.org/pytorch/linux-64/torchtriton-2.0.0-py310.tar.bz2 313 | https://conda.anaconda.org/pytorch/linux-64/torchvision-0.15.2-py310_cu118.tar.bz2 314 | -------------------------------------------------------------------------------- /conda_env/env_short.yml: -------------------------------------------------------------------------------- 1 | name: rmbd_env 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python=3.10 7 | - pytorch=2.0.1 8 | - pytorch-cuda=11.8 9 | - torchvision 10 | - einops 11 | - matplotlib 12 | - scikit-image 13 | - torchinfo 14 | - tensorboard 15 | - ml-collections 16 | - ipykernel 17 | - ca-certificates 18 | - openssl 19 | - pytorch-lightning=2.0.3 20 | - certifi 21 | - pandas 22 | - pip: 23 | - docstring-parser==0.15 24 | - importlib-resources==5.12.0 25 | - jsonargparse==4.21.2 26 | - nvidia-htop==1.0.5 27 | - termcolor==2.3.0 28 | - typeshed-client==2.3.0 29 | -------------------------------------------------------------------------------- /configs/data/ant_fspl_floor.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: false 19 | ant_gain_top: false 20 | ant_gain_floor_dist: true 21 | ant_gain_top_dist: true 22 | ant_gain_slices: false 23 | ant_gain_los_slices: false 24 | img: false 25 | z_step: 4 26 | z_max: 32 27 | thresh: 0.2 28 | augmentation: true 29 | batch_size: 32 30 | shuffle: true 31 | num_workers: 8 32 | pin_memory: true -------------------------------------------------------------------------------- 
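# Usage note (illustrative addition, not a file from the repository): data configs
# like ant_fspl_floor.yaml above are passed to the CLI described in the README, e.g.
#   python main_cli.py fit --model=LitRadioUNet --config=configs/data/ant_fspl_floor.yaml
# Individual init_args can also be overridden with flags; the override below is an
# assumed jsonargparse-style example, not a command taken from the repository:
#   python main_cli.py fit --model=LitRadioUNet --config=configs/data/ant_fspl_floor.yaml --data.init_args.batch_size=16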
/configs/data/ant_fspl_floor_top.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/ant_fspl_slices_4m.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: false 19 | ant_gain_top: false 20 | ant_gain_slices: false 21 | ant_gain_slices_dist: true 22 | ant_gain_los_slices: false 23 | img: false 24 | z_step: 4 25 | z_max: 32 26 | thresh: 0.2 27 | augmentation: true 28 | batch_size: 32 29 | shuffle: true 30 | num_workers: 8 31 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/ant_gain_floor.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: false 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/ant_slices_1m.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: false 19 | ant_gain_top: false 20 | ant_gain_slices: true 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 1 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | 
shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/ant_slices_4m.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: false 19 | ant_gain_top: false 20 | ant_gain_slices: true 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_cyl.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: true 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_cyl_los_binary.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: true 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: true 14 | los_top: true 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 
| z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_cyl_los_min.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: true 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: true 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_eucl.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: true 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_eucl_los_binary.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: true 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: true 14 | los_top: true 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_eucl_los_min.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: true 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: true 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 
| ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_ga.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: true 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_ga_los_binary.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: true 13 | los_floor: true 14 | los_top: true 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_ga_los_min.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: true 13 | los_floor: false 14 | los_top: false 15 | los_z_min: true 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_los_binary.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: true 14 | los_top: true 15 | los_z_min: false 16 | los_z_min_rel: 
false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_los_min.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: true 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_slices.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: false 19 | ant_gain_top: false 20 | ant_gain_slices: true 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_sph.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: true 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_sph_los_binary.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: true 12 | coords_ga: 
false 13 | los_floor: true 14 | los_top: true 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/base_sph_los_min.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: true 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: true 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: true 18 | ant_gain_floor: true 19 | ant_gain_top: true 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/img.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: false 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: false 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: true 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/img_ga_azi_dist.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | azimuth: true 8 | dist2d: true 9 | ndsms: false 10 | ndsm_all: false 11 | coords_euclidian: false 12 | coords_cylindrical: false 13 | coords_spherical: false 14 | coords_ga: true 15 | los_floor: false 16 | los_top: false 17 | los_z_min: false 18 | los_z_min_rel: false 19 | los_theta_max: false 20 | ant_gain_floor: true 21 | ant_gain_top: false 22 | ant_gain_slices: false 23 | ant_gain_los_slices: false 24 | img: true 25 | z_step: 4 26 | z_max: 32 27 | thresh: 0.2 28 | augmentation: true 29 | batch_size: 32 30 | shuffle: true 31 | num_workers: 8 32 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/img_ndsm.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: false 8 | 
ndsm_all: true 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: false 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: true 23 | z_step: 4 24 | z_max: 32 25 | thresh: 0.2 26 | augmentation: true 27 | batch_size: 32 28 | shuffle: true 29 | num_workers: 8 30 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/img_ndsm_ga_azi_dist.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | azimuth: true 8 | dist2d: true 9 | ndsms: false 10 | ndsm_all: true 11 | coords_euclidian: false 12 | coords_cylindrical: false 13 | coords_spherical: false 14 | coords_ga: true 15 | los_floor: false 16 | los_top: false 17 | los_z_min: false 18 | los_z_min_rel: false 19 | los_theta_max: false 20 | ant_gain_floor: true 21 | ant_gain_top: false 22 | ant_gain_slices: false 23 | ant_gain_los_slices: false 24 | img: true 25 | z_step: 4 26 | z_max: 32 27 | thresh: 0.2 28 | augmentation: true 29 | batch_size: 32 30 | shuffle: true 31 | num_workers: 8 32 | pin_memory: true -------------------------------------------------------------------------------- /configs/data/img_rgb.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: lib.pl_datamodule.LitRM_directional 3 | init_args: 4 | dataset_path: dataset 5 | id_file: splits_0.1_0.1_seeds_1_2_ant_all.json 6 | tx_one_hot: true 7 | ndsms: false 8 | ndsm_all: false 9 | coords_euclidian: false 10 | coords_cylindrical: false 11 | coords_spherical: false 12 | coords_ga: false 13 | los_floor: false 14 | los_top: false 15 | los_z_min: false 16 | los_z_min_rel: false 17 | los_theta_max: false 18 | ant_gain_floor: true 19 | ant_gain_top: false 20 | ant_gain_slices: false 21 | ant_gain_los_slices: false 22 | img: false 23 | img_rgb: true 24 | z_step: 4 25 | z_max: 32 26 | thresh: 0.2 27 | augmentation: true 28 | batch_size: 32 29 | shuffle: true 30 | num_workers: 8 31 | pin_memory: true -------------------------------------------------------------------------------- /configs/training/lower_lr.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | init_args: 3 | lr: 0.00005 4 | scheduler_params: 5 | patience: 8 6 | cb_early_stopping: 7 | patience: 20 -------------------------------------------------------------------------------- /lib/PMNet/PMNet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | From https://github.com/abman23/PMNet/tree/PMNet , we have been using the version from 4.7.2023 (the repository has been updated after we ran our experiments) 3 | Lee et al - PMNet: Robust Pathloss Map Prediction via Supervised Learning, December 2023, Proceedings of IEEE Global Communications Conference (GLOBECOM) 4 | We have added a few options (e.g. a varying number of in_ch) and fixed a mistake in _Stem, otherwise the model is the same as proposed by the authors.
5 | 6 | License: 7 | MIT License 8 | 9 | Copyright (c) 2023 Juhyung Lee 10 | 11 | Permission is hereby granted, free of charge, to any person obtaining a copy 12 | of this software and associated documentation files (the "Software"), to deal 13 | in the Software without restriction, including without limitation the rights 14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | copies of the Software, and to permit persons to whom the Software is 16 | furnished to do so, subject to the following conditions: 17 | 18 | The above copyright notice and this permission notice shall be included in all 19 | copies or substantial portions of the Software. 20 | 21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 27 | SOFTWARE. 28 | ''' 29 | 30 | 31 | from __future__ import absolute_import, print_function 32 | 33 | from collections import OrderedDict 34 | 35 | import torch 36 | import torch.nn as nn 37 | import torch.nn.functional as F 38 | # from torchvision import models 39 | from ..dcn import DeformableConv2d 40 | 41 | _BATCH_NORM = nn.BatchNorm2d 42 | 43 | _BOTTLENECK_EXPANSION = 4 44 | 45 | # Conv, Batchnorm, Relu layers, basic building block. 46 | class _ConvBnReLU(nn.Sequential): 47 | 48 | BATCH_NORM = _BATCH_NORM 49 | 50 | def __init__( 51 | self, in_ch, out_ch, kernel_size, stride, padding, dilation, relu=True, bn_eps : float = 1e-05, dcn=False 52 | ): 53 | super(_ConvBnReLU, self).__init__() 54 | if dcn and dilation > 1: 55 | self.add_module( 56 | "conv", 57 | DeformableConv2d( 58 | in_ch, out_ch, kernel_size, stride, kernel_size//2, bias=False 59 | ), 60 | ) 61 | else: 62 | self.add_module( 63 | "conv", 64 | nn.Conv2d( 65 | in_ch, out_ch, kernel_size, stride, padding, dilation, bias=False 66 | ), 67 | ) 68 | self.add_module("bn", _BATCH_NORM(out_ch, eps=bn_eps, momentum=1 - 0.999)) 69 | 70 | if relu: 71 | self.add_module("relu", nn.ReLU()) 72 | 73 | # Bottleneck layer constructed from _ConvBnReLU blocks, building block for Res layers 74 | class _Bottleneck(nn.Module): 75 | 76 | def __init__(self, in_ch, out_ch, stride, dilation, downsample, bn_eps : float = 1e-05, dcn=False): 77 | super(_Bottleneck, self).__init__() 78 | mid_ch = out_ch // _BOTTLENECK_EXPANSION 79 | self.reduce = _ConvBnReLU(in_ch, mid_ch, 1, stride, 0, 1, True, bn_eps=bn_eps) 80 | self.conv3x3 = _ConvBnReLU(mid_ch, mid_ch, 3, 1, dilation, dilation, True, bn_eps=bn_eps, dcn=dcn) 81 | self.increase = _ConvBnReLU(mid_ch, out_ch, 1, 1, 0, 1, False, bn_eps=bn_eps) 82 | self.shortcut = ( 83 | _ConvBnReLU(in_ch, out_ch, 1, stride, 0, 1, False, bn_eps=bn_eps) 84 | if downsample 85 | else nn.Identity() 86 | ) 87 | 88 | def forward(self, x): 89 | h = self.reduce(x) 90 | h = self.conv3x3(h) 91 | h = self.increase(h) 92 | h += self.shortcut(x) 93 | return F.relu(h) 94 | 95 | # Res layer used to construct the encoder 96 | class _ResLayer(nn.Sequential): 97 | 98 | def __init__(self, n_layers, in_ch, out_ch, stride, dilation, multi_grids=None, bn_eps=1e-5, dcn=False): 99 | super(_ResLayer, self).__init__() 100 | 101 | if multi_grids is None: 102 | multi_grids = [1 for _ in
range(n_layers)] 103 | else: 104 | assert n_layers == len(multi_grids) 105 | 106 | # Downsampling is only in the first block 107 | for i in range(n_layers): 108 | self.add_module( 109 | "block{}".format(i + 1), 110 | _Bottleneck( 111 | in_ch=(in_ch if i == 0 else out_ch), 112 | out_ch=out_ch, 113 | stride=(stride if i == 0 else 1), 114 | dilation=dilation * multi_grids[i], 115 | downsample=(True if i == 0 else False), 116 | bn_eps=bn_eps, 117 | dcn=dcn 118 | ), 119 | ) 120 | 121 | # Stem layer is the initial interfacing layer 122 | class _Stem(nn.Sequential): 123 | """ 124 | The 1st conv layer. 125 | Note that the max pooling is different from both MSRA and FAIR ResNet. 126 | """ 127 | 128 | def __init__(self, out_ch, in_ch = 2, ceil_mode=True, bn_eps=1e-5): 129 | super(_Stem, self).__init__() 130 | self.add_module("conv1", _ConvBnReLU(in_ch, out_ch, 7, 2, 3, 1, bn_eps=bn_eps)) 131 | '''First argument for MaxPool2d in the original implementation is in_ch, which should be a mistake (we have up to 15 in channels in some configurations...). 132 | We set it to 2 instead, since with the dataset used by the authors of PMNet, this should be the usual value of in_ch.''' 133 | self.add_module("pool", nn.MaxPool2d(2, 2, 0, ceil_mode=ceil_mode)) 134 | 135 | class _ImagePool(nn.Module): 136 | def __init__(self, in_ch, out_ch, bn_eps=1e-5): 137 | super().__init__() 138 | self.pool = nn.AdaptiveAvgPool2d(1) 139 | self.conv = _ConvBnReLU(in_ch, out_ch, 1, 1, 0, 1, bn_eps=bn_eps) 140 | 141 | def forward(self, x): 142 | _, _, H, W = x.shape 143 | h = self.pool(x) 144 | h = self.conv(h) 145 | h = F.interpolate(h, size=(H, W), mode="bilinear", align_corners=False) 146 | return h 147 | 148 | # Atrous spatial pyramid pooling 149 | class _ASPP(nn.Module): 150 | 151 | def __init__(self, in_ch, out_ch, rates, bn_eps=1e-5, dcn=False): 152 | super(_ASPP, self).__init__() 153 | self.stages = nn.Module() 154 | self.stages.add_module("c0", _ConvBnReLU(in_ch, out_ch, 1, 1, 0, 1, bn_eps=bn_eps)) 155 | for i, rate in enumerate(rates): 156 | self.stages.add_module( 157 | "c{}".format(i + 1), 158 | _ConvBnReLU(in_ch, out_ch, 3, 1, padding=rate, dilation=rate, bn_eps=bn_eps, dcn=dcn), 159 | ) 160 | self.stages.add_module("imagepool", _ImagePool(in_ch, out_ch)) 161 | 162 | def forward(self, x): 163 | return torch.cat([stage(x) for stage in self.stages.children()], dim=1) 164 | 165 | # Decoder layer constructed using these 2 blocks 166 | def ConRu(in_channels, out_channels, kernel, padding): 167 | return nn.Sequential( 168 | nn.Conv2d(in_channels, out_channels, kernel, padding=padding), 169 | nn.ReLU(inplace=True) 170 | ) 171 | 172 | def ConRuT(in_channels, out_channels, kernel, padding, output_padding): 173 | return nn.Sequential( 174 | nn.ConvTranspose2d(in_channels, out_channels, kernel, stride=2, padding=padding, output_padding=output_padding), 175 | nn.ReLU(inplace=True) 176 | ) 177 | 178 | class PMNet(nn.Module): 179 | 180 | def __init__( 181 | self, 182 | in_ch : int, 183 | n_blocks : list, 184 | atrous_rates : list, 185 | multi_grids : list, 186 | output_stride : int, 187 | ceil_mode : bool = True, 188 | output_padding=(0, 0), 189 | bn_eps : float = 1e-05, 190 | dcn = False 191 | ): 192 | 193 | super(PMNet, self).__init__() 194 | 195 | if output_stride == 8: 196 | s = [1, 2, 1, 1] 197 | d = [1, 1, 2, 4] 198 | elif output_stride == 16: 199 | s = [1, 2, 2, 1] 200 | d = [1, 1, 1, 2] 201 | else: 202 | raise ValueError(f'output_stride={output_stride}, but only 8, 16 allowed') 203 | 204 | # Encoder 205 | ch = [64 * 2 ** p
for p in range(6)] 206 | self.layer1 = _Stem(ch[0], in_ch=in_ch, ceil_mode=ceil_mode, bn_eps=bn_eps) 207 | self.layer2 = _ResLayer(n_blocks[0], ch[0], ch[2], s[0], d[0], bn_eps=bn_eps) 208 | self.layer3 = _ResLayer(n_blocks[1], ch[2], ch[3], s[1], d[1], bn_eps=bn_eps) 209 | self.layer4 = _ResLayer(n_blocks[2], ch[3], ch[3], s[2], d[2], bn_eps=bn_eps) 210 | self.layer5 = _ResLayer(n_blocks[3], ch[3], ch[4], s[3], d[3], multi_grids, bn_eps=bn_eps, dcn=dcn) 211 | self.aspp = _ASPP(ch[4], 256, atrous_rates, bn_eps=bn_eps, dcn=dcn) 212 | concat_ch = 256 * (len(atrous_rates) + 2) 213 | self.add_module("fc1", _ConvBnReLU(concat_ch, 512, 1, 1, 0, 1, bn_eps=bn_eps)) 214 | self.reduce = _ConvBnReLU(256, 256, 1, 1, 0, 1, bn_eps=bn_eps) 215 | 216 | # Decoder 217 | self.conv_up5 = ConRu(512, 512, 3, 1) 218 | if output_stride==16: 219 | self.conv_up4 = ConRuT(512+512, 512, 3, 1, output_padding=output_padding[0]) 220 | elif output_stride==8: 221 | self.conv_up4 = ConRu(512+512, 512, 3, 1) 222 | self.conv_up3 = ConRuT(512+512, 256, 3, 1, output_padding=output_padding[1]) 223 | self.conv_up2 = ConRu(256+256, 256, 3, 1) 224 | self.conv_up1 = ConRu(256+256, 256, 3, 1) 225 | 226 | self.conv_up0 = ConRu(256+64, 128, 3, 1) 227 | self.conv_up00 = nn.Sequential( 228 | nn.Conv2d(128+in_ch, 64, kernel_size=3, padding=1), 229 | nn.BatchNorm2d(64, eps=bn_eps), 230 | nn.ReLU(), 231 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 232 | nn.BatchNorm2d(64, eps=bn_eps), 233 | nn.ReLU(), 234 | nn.Conv2d(64, 1, kernel_size=3, padding=1)) 235 | 236 | def forward(self, x): 237 | # Encoder 238 | x1 = self.layer1(x) 239 | x2 = self.layer2(x1) 240 | x3 = self.reduce(x2) 241 | x4 = self.layer3(x3) 242 | x5 = self.layer4(x4) 243 | x6 = self.layer5(x5) 244 | x7 = self.aspp(x6) 245 | x8 = self.fc1(x7) 246 | 247 | # Decoder 248 | xup5 = self.conv_up5(x8) 249 | xup5 = torch.cat([xup5, x5], dim=1) 250 | xup4 = self.conv_up4(xup5) 251 | xup4 = torch.cat([xup4, x4], dim=1) 252 | xup3 = self.conv_up3(xup4) 253 | xup3 = torch.cat([xup3, x3], dim=1) 254 | xup2 = self.conv_up2(xup3) 255 | xup2 = torch.cat([xup2, x2], dim=1) 256 | xup1 = self.conv_up1(xup2) 257 | xup1 = torch.cat([xup1, x1], dim=1) 258 | xup0 = self.conv_up0(xup1) 259 | 260 | xup0 = F.interpolate(xup0, size=x.shape[2:], mode="bilinear", align_corners=False) 261 | xup0 = torch.cat([xup0, x], dim=1) 262 | xup00 = self.conv_up00(xup0) 263 | 264 | return xup00 265 | 266 | if __name__=="__main__": 267 | m = PMNet(in_ch=2, # PMNet takes in_ch, not n_classes; 2 channels match the random test input below 268 | n_blocks=[3, 3, 27, 3], 269 | atrous_rates=[6, 12, 18], 270 | multi_grids=[1, 2, 4], 271 | output_stride=16,) 272 | 273 | B = 4 274 | H = 256 275 | 276 | input = torch.randn(B, 2, H, H) 277 | output = m(input) 278 | print(output.shape) -------------------------------------------------------------------------------- /lib/RadioUNet/RadioUNet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adjusted version of RadioUNet - we mainly use the first net (below RadioWNet); in order to use both, one has to use the logic in pl_lightningmodule. 3 | R. Levie, Ç. Yapar, G. Kutyniok and G. Caire, "RadioUNet: Fast Radio Map Estimation With Convolutional Neural Networks," in IEEE Transactions on Wireless Communications, vol. 20, no. 6, pp. 4001-4015, June 2021, doi: 10.1109/TWC.2021.3054977.
4 | https://github.com/RonLevie/RadioUNet 5 | 6 | MIT License 7 | 8 | Copyright (c) 2019 Ron Levie 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | ''' 28 | import torch 29 | import torch.nn as nn 30 | 31 | def convrelu(in_channels, out_channels, kernel, padding, pool): 32 | return nn.Sequential( 33 | nn.Conv2d(in_channels, out_channels, kernel, padding=padding), 34 | #In conv, the dimension of the output, if the input is H,W, is 35 | # H+2*padding-kernel +1 36 | nn.ReLU(inplace=True), 37 | nn.MaxPool2d(pool, stride=pool, padding=0, dilation=1, return_indices=False, ceil_mode=False) 38 | #pooling takes Height H and width W to (H-pool)/pool+1 = H/pool, and floor. Same for W. 39 | #altogether, the output size is (H+2*padding-kernel +1)/pool. 
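#for example, layer0 below uses kernel=5, padding=2, pool=2, so H maps to (H+2*2-5+1)/2 = H/2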
40 | ) 41 | 42 | def convreluT(in_channels, out_channels, kernel, padding): 43 | return nn.Sequential( 44 | nn.ConvTranspose2d(in_channels, out_channels, kernel, stride=2, padding=padding), 45 | nn.ReLU(inplace=True) 46 | #input is H X W, output is (H-1)*2 - 2*padding + kernel 47 | ) 48 | 49 | 50 | ### name is a bit misleading, this is now just the first UNet, second one defined below 51 | class RadioWNet(nn.Module): 52 | 53 | def __init__(self,inputs=2): 54 | super().__init__() 55 | 56 | self.inputs=inputs 57 | 58 | if inputs<=3: 59 | self.layer00 = convrelu(inputs, 6, 3, 1,1) 60 | self.layer0 = convrelu(6, 40, 5, 2,2) 61 | else: 62 | self.layer00 = convrelu(inputs, 10, 3, 1,1) 63 | self.layer0 = convrelu(10, 40, 5, 2,2) 64 | 65 | self.layer1 = convrelu(40, 50, 5, 2,2) 66 | self.layer10 = convrelu(50, 60, 5, 2,1) 67 | self.layer2 = convrelu(60, 100, 5, 2,2) 68 | self.layer20 = convrelu(100, 100, 3, 1,1) 69 | self.layer3 = convrelu(100, 150, 5, 2,2) 70 | self.layer4 =convrelu(150, 300, 5, 2,2) 71 | self.layer5 =convrelu(300, 500, 5, 2,2) 72 | 73 | self.conv_up5 =convreluT(500, 300, 4, 1) 74 | self.conv_up4 = convreluT(300+300, 150, 4, 1) 75 | self.conv_up3 = convreluT(150 + 150, 100, 4, 1) 76 | self.conv_up20 = convrelu(100 + 100, 100, 3, 1, 1) 77 | self.conv_up2 = convreluT(100 + 100, 60, 6, 2) 78 | self.conv_up10 = convrelu(60 + 60, 50, 5, 2, 1) 79 | self.conv_up1 = convreluT(50 + 50, 40, 6, 2) 80 | self.conv_up0 = convreluT(40 + 40, 20, 6, 2) 81 | if inputs<=3: 82 | self.conv_up00 = convrelu(20+6+inputs, 20, 5, 2,1) 83 | 84 | else: 85 | self.conv_up00 = convrelu(20+10+inputs, 20, 5, 2,1) 86 | 87 | self.conv_up000 = convrelu(20+inputs, 1, 5, 2,1) 88 | 89 | 90 | 91 | def forward(self, input): 92 | 93 | input0=input 94 | 95 | layer00 = self.layer00(input0) 96 | layer0 = self.layer0(layer00) 97 | layer1 = self.layer1(layer0) 98 | layer10 = self.layer10(layer1) 99 | layer2 = self.layer2(layer10) 100 | layer20 = self.layer20(layer2) 101 | layer3 = self.layer3(layer20) 102 | layer4 = self.layer4(layer3) 103 | layer5 = self.layer5(layer4) 104 | 105 | layer4u = self.conv_up5(layer5) 106 | layer4u = torch.cat([layer4u, layer4], dim=1) 107 | layer3u = self.conv_up4(layer4u) 108 | layer3u = torch.cat([layer3u, layer3], dim=1) 109 | layer20u = self.conv_up3(layer3u) 110 | layer20u = torch.cat([layer20u, layer20], dim=1) 111 | layer2u = self.conv_up20(layer20u) 112 | layer2u = torch.cat([layer2u, layer2], dim=1) 113 | layer10u = self.conv_up2(layer2u) 114 | layer10u = torch.cat([layer10u, layer10], dim=1) 115 | layer1u = self.conv_up10(layer10u) 116 | layer1u = torch.cat([layer1u, layer1], dim=1) 117 | layer0u = self.conv_up1(layer1u) 118 | layer0u = torch.cat([layer0u, layer0], dim=1) 119 | layer00u = self.conv_up0(layer0u) 120 | layer00u = torch.cat([layer00u, layer00], dim=1) 121 | layer00u = torch.cat([layer00u,input0], dim=1) 122 | layer000u = self.conv_up00(layer00u) 123 | layer000u = torch.cat([layer000u,input0], dim=1) 124 | output1 = self.conv_up000(layer000u) 125 | 126 | return output1 127 | 128 | class RadioWNet2(nn.Module): 129 | 130 | def __init__(self,inputs=2): 131 | super().__init__() 132 | 133 | self.inputs=inputs 134 | 135 | self.Wlayer00 = convrelu(inputs, 20, 3, 1,1) 136 | self.Wlayer0 = convrelu(20, 30, 5, 2,2) 137 | self.Wlayer1 = convrelu(30, 40, 5, 2,2) 138 | self.Wlayer10 = convrelu(40, 50, 5, 2,1) 139 | self.Wlayer2 = convrelu(50, 60, 5, 2,2) 140 | self.Wlayer20 = convrelu(60, 70, 3, 1,1) 141 | self.Wlayer3 = convrelu(70, 90, 5, 2,2) 142 | self.Wlayer4 =convrelu(90, 110, 5, 
2,2) 143 | self.Wlayer5 =convrelu(110, 150, 5, 2,2) 144 | 145 | self.Wconv_up5 =convreluT(150, 110, 4, 1) 146 | self.Wconv_up4 = convreluT(110+110, 90, 4, 1) 147 | self.Wconv_up3 = convreluT(90 + 90, 70, 4, 1) 148 | self.Wconv_up20 = convrelu(70 + 70, 60, 3, 1, 1) 149 | self.Wconv_up2 = convreluT(60 + 60, 50, 6, 2) 150 | self.Wconv_up10 = convrelu(50 + 50, 40, 5, 2, 1) 151 | self.Wconv_up1 = convreluT(40 + 40, 30, 6, 2) 152 | self.Wconv_up0 = convreluT(30 + 30, 20, 6, 2) 153 | self.Wconv_up00 = convrelu(20+20+inputs, 20, 5, 2,1) 154 | self.Wconv_up000 = convrelu(20+inputs, 1, 5, 2,1) 155 | 156 | def forward(self, input): 157 | 158 | Winput=input 159 | 160 | Wlayer00 = self.Wlayer00(Winput) 161 | Wlayer0 = self.Wlayer0(Wlayer00) 162 | Wlayer1 = self.Wlayer1(Wlayer0) 163 | Wlayer10 = self.Wlayer10(Wlayer1) 164 | Wlayer2 = self.Wlayer2(Wlayer10) 165 | Wlayer20 = self.Wlayer20(Wlayer2) 166 | Wlayer3 = self.Wlayer3(Wlayer20) 167 | Wlayer4 = self.Wlayer4(Wlayer3) 168 | Wlayer5 = self.Wlayer5(Wlayer4) 169 | 170 | Wlayer4u = self.Wconv_up5(Wlayer5) 171 | Wlayer4u = torch.cat([Wlayer4u, Wlayer4], dim=1) 172 | Wlayer3u = self.Wconv_up4(Wlayer4u) 173 | Wlayer3u = torch.cat([Wlayer3u, Wlayer3], dim=1) 174 | Wlayer20u = self.Wconv_up3(Wlayer3u) 175 | Wlayer20u = torch.cat([Wlayer20u, Wlayer20], dim=1) 176 | Wlayer2u = self.Wconv_up20(Wlayer20u) 177 | Wlayer2u = torch.cat([Wlayer2u, Wlayer2], dim=1) 178 | Wlayer10u = self.Wconv_up2(Wlayer2u) 179 | Wlayer10u = torch.cat([Wlayer10u, Wlayer10], dim=1) 180 | Wlayer1u = self.Wconv_up10(Wlayer10u) 181 | Wlayer1u = torch.cat([Wlayer1u, Wlayer1], dim=1) 182 | Wlayer0u = self.Wconv_up1(Wlayer1u) 183 | Wlayer0u = torch.cat([Wlayer0u, Wlayer0], dim=1) 184 | Wlayer00u = self.Wconv_up0(Wlayer0u) 185 | Wlayer00u = torch.cat([Wlayer00u, Wlayer00], dim=1) 186 | Wlayer00u = torch.cat([Wlayer00u,Winput], dim=1) 187 | Wlayer000u = self.Wconv_up00(Wlayer00u) 188 | Wlayer000u = torch.cat([Wlayer000u,Winput], dim=1) 189 | output2 = self.Wconv_up000(Wlayer000u) 190 | 191 | return output2 192 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fabja19/RML/aaa88503fcdd0810ac4544b942fee2548d01f866/lib/__init__.py -------------------------------------------------------------------------------- /lib/dcn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This is taken from https://github.com/developer0hye/PyTorch-Deformable-Convolution-v2 . 3 | 4 | License: 5 | MIT License 6 | 7 | Copyright (c) 2021 Yonghye Kwon 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | 27 | Citation: 28 | Kwon, Y. (2022). PyTorch-Deformable-Convolution-v2 (Version 1.0.0) [Computer software]. https://github.com/developer0hye/PyTorch-Deformable-Convolution-v2 29 | ''' 30 | 31 | 32 | import torch 33 | import torchvision.ops 34 | from torch import nn 35 | import numpy as np 36 | 37 | class DeformableConv2d(nn.Module): 38 | def __init__(self, 39 | in_channels, 40 | out_channels, 41 | kernel_size=3, 42 | stride=1, 43 | padding=1, 44 | bias=False): 45 | 46 | super(DeformableConv2d, self).__init__() 47 | 48 | assert type(kernel_size) == tuple or type(kernel_size) == int 49 | 50 | kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size) 51 | self.stride = stride if type(stride) == tuple else (stride, stride) 52 | self.padding = padding 53 | 54 | self.offset_conv = nn.Conv2d(in_channels, 55 | 2 * kernel_size[0] * kernel_size[1], 56 | kernel_size=kernel_size, 57 | stride=stride, 58 | padding=self.padding, 59 | bias=True) 60 | 61 | nn.init.constant_(self.offset_conv.weight, 0.) 62 | nn.init.constant_(self.offset_conv.bias, 0.) 63 | 64 | self.modulator_conv = nn.Conv2d(in_channels, 65 | 1 * kernel_size[0] * kernel_size[1], 66 | kernel_size=kernel_size, 67 | stride=stride, 68 | padding=self.padding, 69 | bias=True) 70 | 71 | nn.init.constant_(self.modulator_conv.weight, 0.) 72 | nn.init.constant_(self.modulator_conv.bias, 0.) 73 | 74 | self.regular_conv = nn.Conv2d(in_channels=in_channels, 75 | out_channels=out_channels, 76 | kernel_size=kernel_size, 77 | stride=stride, 78 | padding=self.padding, 79 | bias=bias) 80 | 81 | def forward(self, x): 82 | #h, w = x.shape[2:] 83 | #max_offset = max(h, w)/4. 84 | 85 | offset = self.offset_conv(x)#.clamp(-max_offset, max_offset) 86 | modulator = 2. * torch.sigmoid(self.modulator_conv(x)) 87 | 88 | x = torchvision.ops.deform_conv2d(input=x, 89 | offset=offset, 90 | weight=self.regular_conv.weight, 91 | bias=self.regular_conv.bias, 92 | padding=self.padding, 93 | mask=modulator, 94 | stride=self.stride, 95 | ) 96 | return x 97 | 98 | -------------------------------------------------------------------------------- /lib/pl_callbacks.py: -------------------------------------------------------------------------------- 1 | '''Our custom callbacks.''' 2 | import torch 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.callbacks import Callback 5 | from torchinfo import summary 6 | from pathlib import Path 7 | 8 | class SummaryCustom(Callback): 9 | ''' 10 | This callback generates a model summary based on torchinfo, prints it to the shell and saves it to the log dir of the trainer if possible. 11 | The summary generated by torchinfo is much more detailed than the one generated by PL. 12 | Also add hparams of the dataset so we can get (almost) all information about the simulation from the log file. 
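    A minimal usage sketch (the argument values shown are simply this class's defaults):

        trainer = pl.Trainer(callbacks=[SummaryCustom(depth=5, batch_size=32)])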
13 | ''' 14 | def __init__(self, depth=5, batch_size=32): 15 | self.depth = depth 16 | self.batch_size = batch_size 17 | 18 | def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 19 | try: 20 | input_size = (self.batch_size, pl_module.hparams['in_ch'], 256, 256) 21 | except (KeyError, ValueError) as e: 22 | print(e) 23 | return 24 | device = torch.device('cuda') if torch.cuda.is_available() and isinstance(trainer.accelerator, pl.accelerators.cuda.CUDAAccelerator) else torch.device('cpu') 25 | 26 | model_summary = summary(pl_module, input_size=input_size, depth=self.depth, device=device, verbose=0) # use the configured depth rather than a hard-coded value 27 | strPrint = ( 28 | f'\nSaving to {trainer.log_dir}' 29 | f'\n{pl_module}\n' 30 | f'\n{pl_module.hparams_initial}\n' 31 | f'\n{model_summary}' 32 | f'\n(for input_size={input_size}, with AMP significantly less memory)' 33 | ) 34 | 35 | try: 36 | strPrint += f'\nData:\t{trainer.datamodule}\n\n{trainer.datamodule.hparams}' 37 | except Exception as e: 38 | print(e) 39 | 40 | print(strPrint) 41 | 42 | ### in fast_dev_run, the trainer has no log_dir 43 | if hasattr(trainer, 'log_dir') and trainer.log_dir is not None: 44 | with open(Path(trainer.log_dir) / "log_file.txt", "a", encoding='utf-8') as f: 45 | print(strPrint, file=f) 46 | 47 | def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 48 | print( 49 | f'\nSaving to {trainer.log_dir}\n' 50 | f'loaded {pl_module}' 51 | ) 52 | 53 | class AdjustThreads(Callback): 54 | '''Adjust threads used by PyTorch (for use on a cluster)''' 55 | def __init__(self, t : None | int | str = 'auto', verb : bool = False) -> None: 56 | ''' 57 | Set the number of threads to a fixed value if t is an int, derive it automatically from os.sched_getaffinity if 'auto', or skip if None 58 | ''' 59 | from os import sched_getaffinity 60 | if t is None or t=='None': 61 | from torch import get_num_threads, get_num_interop_threads 62 | if verb: 63 | print(f'\n\nnot adjusting threads (using {get_num_threads()}, {get_num_interop_threads()} out of {len(sched_getaffinity(0))})\n\n') 64 | return 65 | from torch import set_num_threads, set_num_interop_threads 66 | if t == 'auto': 67 | t = len(sched_getaffinity(0)) 68 | else: 69 | pass ### t is already an explicit thread count 70 | set_num_threads(t) 71 | set_num_interop_threads(t) 72 | if verb: 73 | print(f'\n\nadjusted threads to {t}\n\n') 74 | 75 | class Precision(Callback): 76 | '''Set float32 matmul precision and TF32 usage for convolutions.''' 77 | def __init__(self, matmul_precision : str = 'medium', conv_tf32 : bool = True) -> None: 78 | '''Allows to set matmul precision and TF32 for convs from the CLI, should not have an effect with AMP''' 79 | torch.set_float32_matmul_precision(precision=matmul_precision) 80 | torch.backends.cudnn.allow_tf32 = conv_tf32 81 | -------------------------------------------------------------------------------- /lib/pl_datamodule.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pathlib import Path 3 | import json 4 | import numpy as np 5 | import torch 6 | from matplotlib import pyplot as plt 7 | from . import torch_datasets 8 | 9 | 10 | class LitRM(pl.LightningDataModule): 11 | ''' 12 | Abstract DataModule, to be subclassed for specific datasets. This specifies common parameters and methods once, to avoid duplicating them in every subclass.
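    Subclassing sketch (LitRM_custom is a hypothetical name; LitRM_directional below is the real example) - a subclass fixes dataset_class, forwards the remaining arguments and sets its name after the super init:

        class LitRM_custom(LitRM):
            def __init__(self, **kwargs):
                super().__init__(dataset_class=torch_datasets.RM, **kwargs)
                self.name = 'LitRM_custom'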
13 | ''' 14 | def __init__(self, 15 | dataset_class : torch_datasets.RM, 16 | dataset_path : str | Path, 17 | id_file : str, 18 | augmentation : bool = True, 19 | ###### params for dataloader 20 | batch_size : int = 32, 21 | shuffle : bool = True, 22 | num_workers : int = 8, 23 | pin_memory : bool = True, 24 | **dataset_params 25 | ): 26 | '''Store the dataloader parameters, build the dataset parameter dict and read the train/val/test split from id_file.''' 27 | 28 | super().__init__() 29 | 30 | self.dataset_class = dataset_class 31 | 32 | self.loader_params = dict( 33 | batch_size=batch_size, 34 | shuffle=shuffle, 35 | num_workers=num_workers, 36 | pin_memory=pin_memory, 37 | ) 38 | 39 | self.dataset_params = { 40 | 'dataset_path' : Path(dataset_path), 41 | 'augmentation' : augmentation, 42 | **dataset_params 43 | } 44 | 45 | with open(Path(dataset_path) / id_file, 'r') as f: 46 | id_data = json.load(f) 47 | 48 | self.train_ids = id_data['train'] 49 | self.val_ids = id_data['val'] 50 | self.test_ids = id_data['test'] 51 | 52 | self.in_ch = dataset_class.get_in_ch(**self.dataset_params) 53 | self.PL_scale = dataset_class.get_PL_scale(**self.dataset_params) 54 | 55 | # to be overridden by subclasses after calling super init 56 | if not hasattr(self, 'name'): 57 | self.name = 'LitRM' 58 | 59 | print(f"prepared dataset {dataset_class}\t{len(self.train_ids)} train / {len(self.val_ids)} val / {len(self.test_ids)} test ids") 60 | 61 | def setup(self, stage: str=None) -> None: 62 | ''' 63 | Instantiate the datasets for the requested stage; augmentation is disabled outside of fit, and the test/predict datasets additionally return sample ids. 64 | ''' 65 | if stage=="fit" or stage is None: 66 | self.train_set = self.dataset_class(ids=self.train_ids, **self.dataset_params) 67 | self.val_set = self.dataset_class(ids=self.val_ids, **self.dataset_params) 68 | elif stage=="validate": 69 | self.val_set = self.dataset_class(ids=self.val_ids, **{**self.dataset_params, 'augmentation' : False}) 70 | elif stage=="test": 71 | self.test_set = self.dataset_class(ids=self.test_ids, **{**self.dataset_params, 'augmentation' : False, "get_ids" : True, "test" : True}) 72 | elif stage=="predict": 73 | self.predict_set = self.dataset_class(ids=self.test_ids,**{**self.dataset_params, 'augmentation' : False, "get_ids" : True}) 74 | else: 75 | raise NotImplementedError(f"stage={stage}") 76 | 77 | def train_dataloader(self): 78 | return torch.utils.data.DataLoader(self.train_set, **self.loader_params) 79 | 80 | def val_dataloader(self): 81 | ### shuffle the validation set so we see different outputs in TB every epoch 82 | return torch.utils.data.DataLoader(self.val_set, **{**self.loader_params, "shuffle" : True}) 83 | 84 | def test_dataloader(self, shuffle=False): 85 | return torch.utils.data.DataLoader(self.test_set, **{**self.loader_params, "shuffle" : shuffle}) 86 | 87 | def predict_dataloader(self): 88 | return torch.utils.data.DataLoader(self.predict_set, **{**self.loader_params, "shuffle" : False}) 89 | 90 | def __repr__(self): 91 | return self.name 92 | 93 | class LitRM_directional(LitRM): 94 | ''' 95 | Directional radio map dataset in lightning, on top of torch_dataset_directional.RMD and subclassing LitRM. 96 | All of the important logic is implemented in the torch class and LitRM, but this class allows us to easily track the used parameters, set defaults and have a common base with other datasets.
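    A direct-use sketch (the argument values shown are just this class's defaults, and the dataset is expected under ./dataset):

        dm = LitRM_directional(dataset_path='./dataset', batch_size=32)
        dm.setup('fit')
        train_loader = dm.train_dataloader()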
97 | ''' 98 | def __init__(self, 99 | dataset_path : str | Path = Path('./dataset'), 100 | id_file : str = 'splits_0.1_0.1_seeds_1_2_ant_all.json', 101 | ### inputs 102 | tx_one_hot : bool = True, 103 | ndsms : bool = True, 104 | ndsm_all : bool = False, 105 | dist2d : bool = False, 106 | dist2d_log : bool = False, 107 | coords_euclidian : bool = False, 108 | coords_cylindrical : bool = False, 109 | coords_spherical : bool = False, 110 | coords_ga : bool = False, 111 | los_floor : bool = False, 112 | los_top : bool = False, 113 | los_z_min : bool = False, 114 | los_z_min_rel : bool = False, 115 | los_theta_max : bool = False, 116 | ant_gain_floor : bool = True, 117 | ant_gain_top : bool = True, 118 | ant_gain_slices : bool = False, 119 | ant_gain_los_slices : bool = False, 120 | ant_gain_floor_dist : bool = False, 121 | ant_gain_top_dist : bool = False, 122 | ant_gain_slices_dist : bool = False, 123 | ant_gain_los_slices_dist : bool = False, 124 | img : bool = False, 125 | img_rgb : bool = False, 126 | elevation : bool = False, 127 | azimuth : bool = False, 128 | ### values for scaling and cutting height values/path loss 129 | z_step : int = 4, 130 | z_max : int = 32, 131 | thresh : float | None = 0.2, 132 | ### params for dataloader, augmentation 133 | **kwargs 134 | ): 135 | ''' 136 | Inputs that can be requested by passing True: 137 | tx_one_hot - tensor with Tx height as value in Tx position, 0 elsewhere 138 | ndsms - separate height maps for buildings and vegetation 139 | ndsm_all - height map of buildings and vegetation together (unclassified) 140 | dist2d - 2D distance of each pixel to Tx 141 | dist2d_log - log10(dist2d) 142 | coords_euclidian - euclidian coordinate system (Tx perspective) 143 | coords_cylindrical - cylindrical coordinate system (Tx perspective) 144 | coords_spherical - spherical coordinate system (Tx perspective) 145 | coords_ga - grid anchor (see https://ieeexplore.ieee.org/document/9753644) 146 | los_floor - binary LoS information for floor/ground in each pixel 147 | los_top - binary LoS information for building top in each pixel 148 | los_z_min - minimum z-value visible from Tx in each pixel 149 | los_z_min_rel - minimum z-value visible from Tx in each pixel, minus Tx height 150 | los_theta_max - maximum theta-value visible from Tx in each pixel (spherical coordinates) 151 | ant_gain_floor - antenna gain projected onto the ground 152 | ant_gain_top - antenna gain projected onto the building top 153 | ant_gain_slices - antenna gain projected onto planes parallel to the ground according to z_step, z_max 154 | ant_gain_los_slices - antenna gain projected onto planes parallel to the ground according to z_step, z_max, additionally 0 if no LoS 155 | ant_gain_X_dist - gain in dB - 2*log_10(dist2d), corresponding to free space path loss 156 | img - aerial image (RGBI) 157 | img_rgb - aerial image (RGB) 158 | elevation - tilt/elevation angle (spherical coordinates) 159 | azimuth - azimuth angle (spherical/cylindrical coordinates) 160 | 161 | z_step, z_max - generate slices (gain) at heights 0, z_step, 2*z_step, ..., z_max - z_step 162 | thresh - cut off lower part of the dB range from radio map 163 | ''' 164 | self.save_hyperparameters() 165 | 166 | super().__init__( 167 | dataset_class=torch_datasets.RM_directional, 168 | dataset_path=dataset_path, 169 | id_file=id_file, 170 | tx_one_hot = tx_one_hot, 171 | ndsms = ndsms, 172 | ndsm_all = ndsm_all, 173 | dist2d = dist2d, 174 | dist2d_log = dist2d_log, 175 | coords_euclidian = coords_euclidian, 176 | coords_cylindrical =
coords_cylindrical, 177 | coords_spherical = coords_spherical, 178 | coords_ga = coords_ga, 179 | los_floor = los_floor, 180 | los_top = los_top, 181 | los_z_min = los_z_min, 182 | los_z_min_rel = los_z_min_rel, 183 | los_theta_max = los_theta_max, 184 | ant_gain_floor = ant_gain_floor, 185 | ant_gain_top = ant_gain_top, 186 | ant_gain_slices = ant_gain_slices, 187 | ant_gain_los_slices = ant_gain_los_slices, 188 | ant_gain_floor_dist = ant_gain_floor_dist, 189 | ant_gain_top_dist = ant_gain_top_dist, 190 | ant_gain_slices_dist = ant_gain_slices_dist, 191 | ant_gain_los_slices_dist = ant_gain_los_slices_dist, 192 | img = img, 193 | img_rgb = img_rgb, 194 | azimuth=azimuth, 195 | elevation=elevation, 196 | z_step = z_step, 197 | z_max = z_max, 198 | thresh=thresh, 199 | **kwargs 200 | ) 201 | self.name = 'LitRM_directional' 202 | 203 | def __repr__(self) -> str: 204 | return self.name 205 | 206 | # for colorbars of the same size as the image 207 | def colorbar(mappable): 208 | from mpl_toolkits.axes_grid1 import make_axes_locatable 209 | import matplotlib.pyplot as plt 210 | last_axes = plt.gca() 211 | ax = mappable.axes 212 | fig = ax.figure 213 | divider = make_axes_locatable(ax) 214 | cax = divider.append_axes("right", size="5%", pad=0.05) 215 | cbar = fig.colorbar(mappable, cax=cax) 216 | plt.sca(last_axes) 217 | return cbar 218 | 219 | 220 | 221 | ''' 222 | Test 223 | ''' 224 | if __name__ == "__main__": 225 | print("testing LitRMD") 226 | litrm = LitRM_directional() 227 | litrm.prepare_data() 228 | print("prepared data") 229 | litrm.setup("fit") 230 | print("setup data stage fit") 231 | litrm.val_dataloader() 232 | t = litrm.train_dataloader() 233 | print("generated train, val loaders") 234 | litrm.setup("test") 235 | print("setup stage test") 236 | litrm.test_dataloader() 237 | print("generated test loader") 238 | 239 | inputs, target = next(iter(t)) 240 | print(f'first batch in train loader:\ninputs: {type(inputs), inputs.shape, inputs.dtype}\ntarget: {type(target), target.shape, target.dtype}\nsaving to {Path("./test_dataset.png")}') 241 | try: 242 | from .pl_lightningmodule import show_samples 243 | show_samples(inputs, target=target, path_save=Path("./test_dataset.png")) 244 | except ImportError as e: 245 | print(e, '.pl_lightningmodule not found?') -------------------------------------------------------------------------------- /lib/pl_lightningmodule.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | from typing import Any 3 | import pytorch_lightning as pl 4 | from pathlib import Path 5 | import torch 6 | from torch import nn 7 | import time 8 | from lib.torch_layers import Modular, MSE_img, L1_img 9 | from numpy import ceil 10 | from skimage import io 11 | import numpy as np 12 | 13 | class LitCNN(pl.LightningModule): 14 | ''' 15 | Class defining the training, testing etc logic around the models. 16 | Takes an arbitrary nn.Module defining the architecture, for specific models we can subclass LitCNN and pass the corresponding model. 
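    A construction sketch (the conv layer is only an illustrative stand-in for a real architecture; its 5 input channels match the example_input_array below):

        lit = LitCNN(model=nn.Conv2d(5, 1, kernel_size=3, padding=1), lr=1e-4)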
17 | ''' 18 | def __init__(self, 19 | model : nn.Module, 20 | PL_scale : float = 100, 21 | lr : float = 1e-4, 22 | optimizer : str = "Adam", 23 | optimizer_params : dict = {}, 24 | scheduler : str | None = 'ReduceLROnPlateau', 25 | scheduler_params : dict[str, int | float | bool | str] = dict( 26 | threshold=1e-4, 27 | patience=4, 28 | verbose=True, 29 | threshold_mode='abs' 30 | ), 31 | scheduler_interval : str = "epoch", 32 | scheduler_frequency : int = 1, 33 | loss_fn : str = 'MSELoss', 34 | channels_last : bool = True, 35 | model_weights : str | None = None, 36 | previous_model : str | None = None, 37 | train_previous_model : bool = False, 38 | **kwargs_tracking 39 | ): 40 | ''' 41 | Inputs: 42 | model - the torch model to be used 43 | PL_scale - based on the dataset, only to convert grayscale loss/error to dB 44 | lr - base learning rate 45 | optimizer - optimizer name 46 | optimizer_params - parameters for initialization of optimizer 47 | scheduler - scheduler name 48 | scheduler_params - parameters for initialization of scheduler 49 | scheduler_interval, scheduler_frequency - define how PL handles the scheduler (change e.g. for annealing) 50 | loss_fn - function for calculating the loss, a subclass of nn.Module, given as a string so that the CLI and logging work properly 51 | channels_last - use channels_last tensor layout 52 | model_weights - path to checkpoint to load only weights but nothing else saved in the checkpoint (e.g. optimizer states) 53 | previous_model - log directory of an existing model that shall be used before the current one (curriculum training as in RadioUNet) 54 | train_previous_model - whether to train the previous_model as well or to freeze it 55 | **kwargs_tracking - not used, only to track other parameters (e.g. from the dataset) with tensorboard 56 | ''' 57 | super().__init__() 58 | ######### example input for logging the model graph/summary; the 5 channels match the datamodule defaults 59 | self.example_input_array = torch.zeros((32, 5, 256, 256)) 60 | ######### 61 | if previous_model is not None: 62 | print(f'LitCNN previous: {previous_model}') 63 | self.save_hyperparameters(ignore=['model']) 64 | self.model = model 65 | if model_weights is not None: 66 | weights = torch.load(Path(model_weights))['state_dict'] 67 | ### remove the "model."
in the beginning 68 | self.model.load_state_dict({k[6:] : v for k, v in weights.items()}) 69 | self.model_weights = model_weights 70 | self.PL_scale = PL_scale 71 | self.lr = lr 72 | self.optimizer = optimizer 73 | self.optimizer_params = optimizer_params 74 | self.scheduler = scheduler 75 | self.scheduler_params = scheduler_params 76 | self.scheduler_interval = scheduler_interval 77 | self.scheduler_frequency = scheduler_frequency 78 | self.loss_fn = getattr(torch.nn, loss_fn)() 79 | self.channels_last = channels_last 80 | 81 | if previous_model is None: 82 | self.previous = previous_model 83 | else: 84 | import yaml 85 | sub_dir = list(Path(previous_model).iterdir()) 86 | assert len(sub_dir)==1, f'run_dir of {previous_model} could not be determined uniquely, found sub_dirs {sub_dir}' 87 | sub_dir = sub_dir[0] 88 | ckpt1 = list((sub_dir / 'checkpoints').glob('*loss*.ckpt')) 89 | assert len(ckpt1) == 1, f'checkpoint for model1 ({previous_model}) not (uniquely) determined, found {ckpt1}' 90 | ckpt1 = ckpt1[0] 91 | config = sub_dir / 'config.yaml' 92 | assert config.is_file(), f'config for model1 ({previous_model})not existing/not a file' 93 | with open(config, 'r') as f: 94 | params = yaml.safe_load(f) 95 | model1_class = LitModelDict[params['model']['class_path'].split('.')[-1]] 96 | self.previous = model1_class.load_from_checkpoint(ckpt1) 97 | print(f'loaded {model1_class} from {ckpt1} using {config}') 98 | if not train_previous_model: 99 | self.previous.freeze() 100 | 101 | ### print unused params to the log/terminal to allow recognizing errors, in particular arguments not understood due to typos, and to pass e.g. params from the dataset to Tensorboard 102 | if kwargs_tracking: 103 | print(f"\nunused params in LitCNN: {kwargs_tracking}\n") 104 | 105 | self.train_losses = [] 106 | self.val_losses = [] 107 | self.metrics = { 108 | ### may be used to generate more metrics during testing 109 | 'mse' : MSE_img(reduction='none'), 110 | # 'l1' : L1_img(reduction='none'), 111 | } 112 | 113 | self.test_losses = { 114 | k : [] for k in self.metrics.keys() 115 | } 116 | ### for testing, also store id and magnitude (squared) of each sample for later analysis, also number of positive pixels 117 | self.test_magnitudes = { 118 | k : [] for k in self.metrics.keys() 119 | } 120 | self.pos_pixels = [] 121 | self.test_ids = [] 122 | 123 | def __repr__(self): 124 | if hasattr(self, "name"): 125 | return self.name 126 | else: 127 | return "LitCNN" 128 | 129 | def forward(self, x : Any) -> Any: 130 | if self.previous is not None: 131 | x = torch.cat([x, self.previous(x)], dim=-3) 132 | return self.model(x) 133 | 134 | def training_step(self, batch, batch_idx) -> torch.Tensor: 135 | 136 | inputs, target = batch 137 | output = self(inputs) 138 | loss = self.loss_fn(output, target) 139 | 140 | self.log("loss_train_it", loss.detach().cpu(), batch_size=target.shape[0], on_step=True, prog_bar=True) 141 | self.log("loss_train_avg", loss.detach().cpu(), batch_size=target.shape[0], on_epoch=True, on_step=False) 142 | 143 | self.train_losses.append(loss.detach().cpu()) 144 | 145 | return loss 146 | 147 | def validation_step(self, batch, batch_idx) -> torch.Tensor: 148 | inputs, target = batch 149 | output = self(inputs) 150 | loss = self.loss_fn(output, target).detach().cpu() 151 | 152 | self.log("loss_val_avg", loss, batch_size=target.shape[0], on_epoch=True, prog_bar=True) 153 | self.val_losses.append(loss) 154 | 155 | if batch_idx==0: 156 | ### save examples to TB log 157 | fig_list = [] 158 | for m in 
range(min(inputs.shape[0], 20)): 159 | fig_list.append(show_samples(batch, m, output)) 160 | self.logger.experiment.add_figure(f"val_samples", fig_list, global_step=self.current_epoch) 161 | plt.close('all') 162 | 163 | return loss 164 | 165 | def test_step(self, batch, batch_idx) -> None: 166 | if len(batch)==3: 167 | inputs, target, map_id = batch 168 | masks = None 169 | elif len(batch)==4: 170 | inputs, target, masks, map_id = batch 171 | ### we don't know masks beforehand, add them to metrics in first batch 172 | if batch_idx==0: 173 | for m in masks.keys(): 174 | for me in self.metrics.keys(): 175 | self.test_losses[f'{me}_mask_{m}'] = [] 176 | self.test_magnitudes[f'{me}_mask_{m}'] = [] 177 | 178 | output = self(inputs) 179 | 180 | for k, v in self.metrics.items(): 181 | self.test_losses[k].append(v(output, target).cpu()) 182 | self.test_magnitudes[k].append(v(target, torch.zeros_like(target)).cpu()) 183 | if masks is not None: 184 | for km, m in masks.items(): 185 | self.test_losses[f'{k}_mask_{km}'].append(v(output * m,target * m).cpu()) 186 | self.test_magnitudes[f'{k}_mask_{km}'].append(v(target * m, torch.zeros_like(target)).cpu()) 187 | 188 | self.pos_pixels.append(torch.squeeze(torch.sum(target > 0, dim=(-1,-2)))) 189 | self.test_ids.append(['_'.join([map_id[i][j] for i in range(len(map_id))]) for j in range(len(map_id[0]))]) 190 | 191 | def predict_step(self, batch, batch_idx): 192 | inputs, target, map_id = batch 193 | output = self(inputs) 194 | ### save prediction and loss for each sample 195 | for b in range(inputs.shape[0]): 196 | map_id_str = '' 197 | for i in map_id: 198 | map_id_str += f'{i[b]}_' 199 | io.imsave(self.predict_dir / f'{map_id_str[:-1]}.png', np.clip((255 * torch.squeeze(output[b,:]).cpu().numpy()), 0, 255).astype(np.uint8), check_contrast=False) 200 | 201 | def on_train_epoch_start(self) -> None: 202 | self.train_losses = [] 203 | 204 | def on_validation_epoch_start(self) -> None: 205 | self.val_losses = [] 206 | 207 | def on_test_epoch_start(self) -> None: 208 | self.test_losses = { 209 | k : [] for k in self.test_losses.keys() 210 | } 211 | self.test_magnitudes = { 212 | k : [] for k in self.test_magnitudes.keys() 213 | } 214 | self.pos_pixels = [] 215 | self.test_ids = [] 216 | 217 | def on_train_epoch_end(self) -> None: 218 | '''Add average losses of this epoch to the log file and check whether val loss improved.''' 219 | ### in case we continue training from a checkpoint, this hook gets triggered before training actually starts again 220 | if len(self.train_losses)==0: 221 | return 222 | avg_loss_train = float(torch.mean(torch.stack(self.train_losses))) 223 | avg_loss_val = float(torch.mean(torch.stack(self.val_losses))) if len(self.val_losses) > 0 else 1e5 224 | 225 | 226 | ### we are using "loss_val_avg" to track the avg loss over all epochs 227 | ### hp_metric is instead used to track the best avg loss so far 228 | if avg_loss_val < self.best_val[0]: 229 | self.best_val = (avg_loss_val, self.current_epoch) 230 | self.log('hp_metric', self.best_val[0]) 231 | 232 | if self.trainer.log_dir is not None: 233 | with open(Path(self.trainer.log_dir) / "log_file.txt", "a", encoding='utf-8') as f: 234 | f.write(f'Epoch\t{self.current_epoch}/{self.trainer.max_epochs}\t{time.strftime("%d.%m.%y-%H:%M:%S")}\t\ttrain: {avg_loss_train:.7f}\t\tval: {avg_loss_val:.7f} (best: {self.best_val[0]:.7f} in ep. 
{self.best_val[1]})\t\tlr: {self.trainer.optimizers[0].param_groups[0]["lr"]:.7f}\n') 235 | 236 | def on_test_epoch_end(self) -> None: 237 | ''' 238 | Add average losses of this epoch to the log file and TB. 239 | ''' 240 | torch.set_printoptions(precision=6) 241 | save_dir = Path(self.trainer.log_dir) 242 | 243 | import pandas as pd 244 | columns = ['id'] 245 | for k in self.test_losses.keys(): 246 | columns.append(k) 247 | columns.append(f'{k}_magnitude') 248 | # columns.append(f'{k}_normalized') 249 | columns.append('pos_pixels') 250 | df = pd.DataFrame(columns=columns) 251 | 252 | ### loop over all batches, store results per sample in df, collect losses and normalized losses for averaging 253 | n_samples = 0 254 | losses_acc = { 255 | k : 0 for k in self.test_losses.keys() 256 | } 257 | losses_norm_acc = { 258 | k : 0 for k in self.test_losses.keys() 259 | } 260 | for i in range(len(self.test_ids)): 261 | ids = self.test_ids[i] 262 | pos_pixels = self.pos_pixels[i] 263 | losses = { 264 | k : self.test_losses[k][i] for k in self.test_losses.keys() 265 | } 266 | magnitudes = { 267 | k : self.test_magnitudes[k][i] for k in self.test_losses.keys() 268 | } 269 | for j in range(len(ids)): 270 | row = [ids[j]] 271 | for k in self.test_losses.keys(): 272 | 273 | row.append(losses[k][j].item()) 274 | row.append(magnitudes[k][j].item()) 275 | # row.append(losses[k][j].item() / (magnitudes[k][j].item() + 1e-4)) 276 | losses_acc[k] += losses[k][j].item() 277 | # losses_norm_acc[k] += losses[k][j].item() / (magnitudes[k][j].item() + 1e-4) 278 | row.append(pos_pixels[j].item()) 279 | df = pd.concat([df, pd.DataFrame([row], columns=columns)]) 280 | 281 | n_samples += len(ids) 282 | df.to_csv(save_dir / "errors_per_sample.csv", index=False) 283 | 284 | with open(save_dir / "log_file_test.txt", "a", encoding='utf-8') as f: 285 | f.write(f'Test \t{time.strftime("%d.%m.%y-%H:%M:%S")}\nAverage losses/metrics: ({n_samples} samples)\n\n') 286 | ### calculate averages, save to test_log 287 | for k, v in losses_acc.items(): 288 | message = f'{k}:\t{v / n_samples}\n' 289 | print(message) 290 | f.write(message) 291 | 292 | def configure_optimizers(self): 293 | if hasattr(torch.optim, self.optimizer): 294 | opt_class = getattr(torch.optim, self.optimizer) 295 | opt = opt_class(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.lr, **self.optimizer_params) 296 | if self.scheduler is None: 297 | return {'optimizer' : opt} 298 | elif hasattr(torch.optim.lr_scheduler, self.scheduler): 299 | sched_class = getattr(torch.optim.lr_scheduler, self.scheduler) 300 | sched = sched_class(opt, **self.scheduler_params) 301 | return { 302 | "optimizer" : opt, 303 | "lr_scheduler" : { 304 | "scheduler" : sched, 305 | "interval" : self.scheduler_interval, 306 | "frequency": self.scheduler_frequency, 307 | "monitor": "loss_val_avg", 308 | "strict": True, 309 | "name": "lr" 310 | } 311 | } 312 | else: 313 | raise ValueError(f"scheduler {self.scheduler} not found") 314 | else: 315 | raise ValueError(f"optimizer {self.optimizer} not found") 316 | 317 | def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: torch.optim.Optimizer) -> None: 318 | optimizer.zero_grad(set_to_none=True) 319 | 320 | def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> tuple: 321 | ### can we do a nicer check here? 
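### channels_last (NHWC) memory layout can speed up convolutions on recent GPUs;
### only 4D image-like tensors are converted, other batch entries (ids, mask dicts, ...) pass through unchanged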
322 | if self.channels_last: 323 | ### Tensor.to(memory_format=...) returns a new tensor, so rebuild the batch instead of rebinding the loop variable 324 | batch = [inp.to(memory_format=torch.channels_last) if isinstance(inp, torch.Tensor) and inp.dim() > 3 else inp 325 | for inp in batch] 326 | 327 | return batch 328 | 329 | def on_test_start(self) -> None: 330 | if self.channels_last: 331 | self = self.to(memory_format=torch.channels_last) 332 | 333 | def on_validation_start(self) -> None: 334 | if self.channels_last: 335 | self = self.to(memory_format=torch.channels_last) 336 | 337 | def on_predict_epoch_start(self) -> None: 338 | self.predict_dir = Path(self.trainer.logger.log_dir).parent / 'predictions' 339 | self.predict_dir.mkdir(exist_ok=True) 340 | self.start_time = time.time() 341 | 342 | def on_predict_epoch_end(self) -> None: 343 | time_diff = time.time() - self.start_time 344 | days = int(time_diff // (24 * 60 * 60)) 345 | left = time_diff % (24 * 60 * 60) 346 | hours = int(left // (60 * 60)) 347 | left = left % (60 * 60) 348 | mins = int(left // 60) 349 | secs = int(left % 60) 350 | message = f"prediction ended at {time.strftime('%y%m%d-%H:%M:%S')} after {days} days, {hours}:{mins}:{secs}" 351 | if self.trainer.log_dir is not None: ##could be None in fast_dev_run 352 | with open(Path(self.trainer.log_dir) / "log_file_predict.txt", "a", encoding='utf-8') as f: 353 | f.write(message) 354 | print(message) 355 | 356 | def on_fit_start(self): 357 | ### store best val loss and corresponding epoch 358 | self.best_val = (1e5, -1) 359 | 360 | ### store all the losses for a single train, validation epoch to calculate average afterwards 361 | ### reset them here in case we load an already trained model 362 | self.train_losses = [] 363 | self.val_losses = [] 364 | 365 | self.start_time = time.time() 366 | 367 | if self.channels_last: 368 | self = self.to(memory_format=torch.channels_last) 369 | 370 | message = f"\n\ntraining started at {time.strftime('%y%m%d-%H:%M:%S')}\n\n" 371 | if self.model_weights is not None: 372 | message = f'\nLoaded weights from {self.model_weights}\n' + message 373 | if self.trainer.log_dir is not None: ##could be None in fast_dev_run 374 | with open(Path(self.trainer.log_dir) / "log_file.txt", "a", encoding='utf-8') as f: 375 | f.write(message) 376 | print(message) 377 | print(f'Optimizers: {self.optimizers()}\tSchedulers: {self.lr_schedulers()}') 378 | 379 | def on_fit_end(self): 380 | time_diff = time.time() - self.start_time 381 | days = int(time_diff // (24 * 60 * 60)) 382 | left = time_diff % (24 * 60 * 60) 383 | hours = int(left // (60 * 60)) 384 | left = left % (60 * 60) 385 | mins = int(left // 60) 386 | secs = int(left % 60) 387 | message = f"training ended at {time.strftime('%y%m%d-%H:%M:%S')} after {days} days, {hours}:{mins}:{secs}\nbest loss:\t{self.best_val[0]} in epoch {self.best_val[1]}" 388 | if self.trainer.log_dir is not None: ##could be None in fast_dev_run 389 | with open(Path(self.trainer.log_dir) / "log_file.txt", "a", encoding='utf-8') as f: 390 | f.write(message) 391 | print(message) 392 | 393 | class LitUNet(LitCNN): 394 | 'Basic UNet with several options for adjustments, e.g.
residual connections (ResNet-like), Dropout, different activation, down and up sampling options (see torch_layers)' 395 | def __init__(self, 396 | in_ch : int, 397 | out_ch : int = 1, 398 | inBlock : tuple | list = ('nConvBlocks', {}), 399 | encoderBlock : tuple | list = ('nConvBlocks', {}), 400 | skipBlock : tuple | list = ('Identity', {}), 401 | bottleneck : tuple | list | None = None, 402 | decoderBlock : tuple | list = ('nConvBlocks', {}), 403 | outBlock : tuple | list = ('nConvBlocks', {'res' : False}), 404 | num_layers : int = 2, 405 | channel : int | list = 32, 406 | depth : int | None = 5, 407 | activation : str = 'ReLU', 408 | kernel_size : int = 3, 409 | batchnorm : bool = True, 410 | bn_eps : float = 1e-05, 411 | dropout : tuple | float = 0, 412 | img_size : int = 256, 413 | res : bool = True, 414 | params_down : dict[str,int|str] = {}, 415 | params_up : dict[str,int|str] = {}, 416 | ### other params to be passed to LitCNN super class 417 | previous_model : str = None, 418 | **kwargs, 419 | ): 420 | 421 | self.name = 'LitUNet' 422 | self.save_hyperparameters() 423 | 424 | super().__init__( 425 | model = Modular( 426 | in_ch=in_ch + 1 * (previous_model is not None), 427 | out_ch=out_ch, 428 | inBlock=inBlock, 429 | encoderBlock=encoderBlock, 430 | skipBlock=skipBlock, 431 | bottleneck=bottleneck, 432 | decoderBlock=decoderBlock, 433 | outBlock=outBlock, 434 | num_layers=num_layers, 435 | channel=channel, 436 | depth=depth, 437 | activation=activation, 438 | kernel_size=kernel_size, 439 | batchnorm=batchnorm, 440 | bn_eps=bn_eps, 441 | dropout=dropout, 442 | img_size=img_size, 443 | params_down=params_down, 444 | params_up=params_up, 445 | res=res), 446 | previous_model=previous_model, 447 | **kwargs 448 | ) 449 | 450 | class LitUNet_ViT(LitUNet): 451 | '''UNet with ViT from TransUNet in the bottleneck, not well tested''' 452 | def __init__(self, 453 | in_ch : int = 3, 454 | hidden_size : int = 768, 455 | grid_size : int = 16, 456 | num_layers : int = 12, 457 | nhead : int = 12, 458 | dropout : float = 0.1, 459 | attn_bias : bool = False, 460 | dim_feedforward : int | None = None, 461 | previous_model : str = None, 462 | **kwargs 463 | ): 464 | super().__init__( 465 | bottleneck=('ViT', dict( 466 | hidden_size=hidden_size, 467 | grid_size=grid_size, 468 | num_layers=num_layers, 469 | nhead=nhead, 470 | dropout=dropout, 471 | attn_bias=attn_bias, 472 | dim_feedforward=dim_feedforward 473 | )), 474 | in_ch=in_ch + 1 * (previous_model is not None), 475 | previous_model=previous_model, 476 | **kwargs) 477 | self.name = 'LitUNet_ViT' 478 | 479 | class LitUNetDCN(LitCNN): 480 | '''UNet with deformable convolutions''' 481 | def __init__(self, 482 | in_ch, 483 | lr = 1e-4, 484 | params_down={}, 485 | params_up={}, 486 | dropout=0, 487 | channel=32, 488 | depth=5, 489 | img_size=256, 490 | batchnorm=True, 491 | res=True, 492 | batchnorm_dcn=True, 493 | deactivate_last_res=False, 494 | previous_model : str = None, 495 | skip_first=True, 496 | **kwargs): 497 | self.name = 'LitUNetDCN' 498 | outBlock = [('nConvBlocks', {})] if not deactivate_last_res else [('nConvBlocks', {'res' : False})] 499 | super().__init__( 500 | model = Modular( 501 | in_ch=in_ch + 1 * (previous_model is not None), 502 | out_ch=1, 503 | channel=channel, 504 | depth=depth, 505 | res=res, 506 | batchnorm=batchnorm, 507 | inBlock=[('nConvBlocks', {})], 508 | outBlock=outBlock, 509 | decoderBlock=[('nConvBlocks', {}), ('DCNBlock', dict(batchnorm=batchnorm_dcn))], 510 | bottleneck=[('nConvBlocks', {}), ('DCNBlock', 
dict(batchnorm=batchnorm_dcn))],
511 |                 encoderBlock=[('convBlock', {}), ('DCNBlock', dict(batchnorm=batchnorm_dcn))],
512 |                 skipBlock=('Identity', {}),
513 |                 params_down=params_down,
514 |                 params_up=params_up,
515 |                 dropout=dropout,
516 |                 img_size=img_size,
517 |                 skip_first=skip_first
518 |             ),
519 |             previous_model=previous_model,
520 |             lr=lr,
521 |             **kwargs)
522 | 
523 | class LitUNet_DCN_old2(LitUNetDCN):
524 |     '''Legacy alias for LitUNetDCN (for loading old checkpoints etc.)'''
525 |     def __init__(self, *args, **kwargs):
526 |         super().__init__(*args, **kwargs)
527 | 
528 | '''
529 | models from other authors
530 | '''
531 | class LitPMNet(LitCNN):
532 |     def __init__(self,
533 |                  in_ch : int = 3,
534 |                  n_blocks : list = [3, 3, 27, 3],
535 |                  atrous_rates : list = [6, 12, 18],
536 |                  multi_grids : list | None = [1, 2, 4],
537 |                  output_stride : int = 8,
538 |                  ceil_mode : bool = True,
539 |                  output_padding : tuple[int, int] = (1, 1),
540 |                  bn_eps : float = 1e-05,
541 |                  dcn = False,
542 |                  previous_model : str | None = None,
543 |                  **kwargs
544 |                  ):
545 |         from .PMNet.PMNet import PMNet
546 |         ### workaround: multi_grids must provide one entry per block of the last stage
547 |         assert multi_grids is None or len(multi_grids)==n_blocks[-1]
548 |         model = PMNet(in_ch=in_ch + 1 * (previous_model is not None), n_blocks=n_blocks, atrous_rates=atrous_rates, multi_grids=multi_grids, output_stride=output_stride, ceil_mode=ceil_mode, output_padding=output_padding, bn_eps=bn_eps, dcn=dcn)
549 |         self.name = "LitPMNet"
550 |         self.save_hyperparameters()
551 |         super().__init__(model=model, previous_model=previous_model, **kwargs)
552 | 
553 | class LitRadioUNet(LitCNN):
554 |     def __init__(self,
555 |                  in_ch : int = 3,
556 |                  previous_model : str | None = None,
557 |                  **kwargs
558 |                  ):
559 |         from .RadioUNet.RadioUNet import RadioWNet
560 |         model = RadioWNet(inputs=in_ch + 1 * (previous_model is not None))
561 |         self.name = "LitRadioUNet"
562 |         self.save_hyperparameters()
563 |         super().__init__(model=model, previous_model=previous_model, **kwargs)
564 | 
565 | class LitRadioUNet2(LitCNN):
566 |     def __init__(self,
567 |                  in_ch : int = 3,
568 |                  previous_model : str | None = None,
569 |                  **kwargs
570 |                  ):
571 |         from .RadioUNet.RadioUNet import RadioWNet2
572 |         model = RadioWNet2(inputs=in_ch + 1 * (previous_model is not None))
573 |         self.name = "LitRadioUNet2"
574 |         self.save_hyperparameters()
575 |         super().__init__(model=model, previous_model=previous_model, **kwargs)
576 | 
577 | ### register all our LitModels in this dict to link strings to class names
578 | ### this is needed if we want to train more than one model in a curriculum (as in RadioUNet)
579 | LitModelDict = {
580 |     'LitUNet' : LitUNet,
581 |     'LitUNet_ViT' : LitUNet_ViT,
582 |     'LitPMNet' : LitPMNet,
583 |     'LitUNetDCN' : LitUNetDCN, 'LitUNet_DCN_old2' : LitUNet_DCN_old2,   ### current name and legacy alias
584 |     'LitRadioUNet' : LitRadioUNet,
585 |     'LitRadioUNet2' : LitRadioUNet2,
586 | }
587 | 
588 | 
589 | # for colorbars of the same size as the image
590 | def colorbar(mappable):
591 |     from mpl_toolkits.axes_grid1 import make_axes_locatable
592 |     import matplotlib.pyplot as plt
593 |     last_axes = plt.gca()
594 |     ax = mappable.axes
595 |     fig = ax.figure
596 |     divider = make_axes_locatable(ax)
597 |     cax = divider.append_axes("right", size="5%", pad=0.05)
598 |     cbar = fig.colorbar(mappable, cax=cax)
599 |     plt.sca(last_axes)
600 |     return cbar
601 | 
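### --- usage sketch (illustrative, not part of the original module) ---
### the registry lets configs refer to a LitModel purely by its string name, e.g.
### when previous_model points at an earlier curriculum stage; the checkpoint path
### below is a placeholder assumption.
def _demo_load_registered_model(name: str = 'LitPMNet', ckpt_path: str = 'logs/some_run/checkpoints/last.ckpt'):
    cls = LitModelDict[name]
    ### load_from_checkpoint is the standard PyTorch Lightning loader
    return cls.load_from_checkpoint(ckpt_path)

602 | def show_samples(
603 |         batch : tuple,
604 |         batch_id : int = 0,
605 |         output : None | torch.Tensor = None,
606 |         filename : None | str | Path = None
607 |     ):
608 |     inputs, target = batch[0][batch_id,:].detach().cpu().to(torch.float32),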
torch.squeeze(batch[1][batch_id,:]).detach().cpu().to(torch.float32)
609 |     n_inp = inputs.shape[0]
610 |     cols = 5
611 |     rows = int(ceil((n_inp + 1 + (output is not None)) / cols))
612 |     fig = plt.figure(figsize=(4*cols, 3*rows))
613 |     fig.add_subplot(rows, cols, 1)
614 |     plt.imshow(target, vmin=0, vmax=1)
615 |     plt.colorbar()
616 |     plt.title('target')
617 |     b = 2
618 |     if output is not None:
619 |         output = torch.squeeze(output[batch_id,:].detach().cpu().to(torch.float32))
620 |         fig.add_subplot(rows, cols, 2)
621 |         plt.imshow(output)
622 |         plt.colorbar()
623 |         plt.title('output')
624 |         b += 1
625 |     for i in range(n_inp):
626 |         fig.add_subplot(rows, cols, i + b)
627 |         ### inputs are either in the range [0,1] or [-1,1] or [-1,0]
628 |         if torch.all(inputs[i,:] >= 0):
629 |             vmin = 0
630 |         else:
631 |             vmin = -1
632 |         if torch.all(inputs[i,:] <= 0):
633 |             vmax = 0
634 |         else:
635 |             vmax = 1
636 |         plt.imshow(inputs[i,:], vmin=vmin, vmax=vmax)
637 |         plt.colorbar()
638 |         plt.title(f'input {i}')
639 |     if filename is not None:
640 |         plt.savefig(filename)
641 |     return fig
--------------------------------------------------------------------------------
/lib/torch_datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pathlib import Path
3 | import numpy as np
4 | import json
5 | from skimage import io
6 | from operator import itemgetter
7 | from . import utils_coords
8 | from torchvision import disable_beta_transforms_warning
9 | disable_beta_transforms_warning()
10 | import torchvision.transforms.v2 as T
11 | 
12 | class RM(torch.utils.data.Dataset):
13 |     '''
14 |     Generic dataset class for Radio Maps. Defines some logic we use for all different datasets and some methods that have to be defined by subclasses.
15 |     '''
16 |     def __init__(self,
17 |                  ids : list,
18 |                  dataset_path : str | Path,
19 |                  augmentation : bool,
20 |                  get_ids : bool = False
21 |                  ) -> None:
22 |         '''
23 |         Inputs:
24 |         ids - list of ids, ids will be passed to __getsample__ of subclass
25 |         dataset_path - where to find the dataset directory
26 |         augmentation - whether to use random flips and rotations
27 |         get_ids - whether to get the sample id with each sample (used in inference)
28 |         '''
29 |         self.ids = ids
30 |         self.dataset_path = Path(dataset_path)
31 |         assert self.dataset_path.is_dir(), f"Please check the dataset path ({dataset_path}): not a directory."
32 | 
33 |         self.t1 = T.ToImageTensor()
34 |         if augmentation:
35 |             self.transforms = T.Compose([
36 |                 ### this combination of transforms allows us to obtain exactly the 8 ways of flipping and rotating that keep the image shape
37 |                 # T.ToImageTensor(),
38 |                 T.ConvertImageDtype(torch.float32),
39 |                 T.RandomHorizontalFlip(p=0.5),
40 |                 T.RandomVerticalFlip(p=0.5),
41 |                 T.RandomApply([T.RandomRotation((90, 90))], p=0.5),
42 |             ])
43 |         else:
44 |             self.transforms = T.Compose([
45 |                 # T.ToImageTensor(),
46 |                 T.ConvertImageDtype(torch.float32),
47 |             ])
48 |         self.get_ids = get_ids
49 | 
50 |     def get_in_ch(*args, **kwargs) -> int:
51 |         '''
52 |         Dummy method.
53 |         '''
54 |         raise NotImplementedError('subclasses of RM have to implement get_in_ch')
55 | 
56 |     def get_PL_scale(*args, **kwargs):
57 |         '''
58 |         Dummy method.
59 |         '''
60 |         raise NotImplementedError('subclasses of RM have to implement get_PL_scale')
61 | 
62 |     def __len__(self) -> int:
63 |         '''
64 |         Total number of samples.
65 |         '''
66 |         return len(self.ids)
67 | 
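    ### --- illustrative subclass sketch (comments only, not original code) ---
    ### a concrete dataset only has to provide __getsample__ plus the two class
    ### helpers; the file name pattern 'rm_<id>.png' is an assumption for the example:
    ###
    ###     class RM_minimal(RM):
    ###         def get_in_ch(*args, **kwargs):
    ###             return 1
    ###         def get_PL_scale(*args, **kwargs):
    ###             return 1.0
    ###         def __getsample__(self, curr_id):
    ###             rm = io.imread(self.dataset_path / f'rm_{curr_id}.png') / 255
    ###             return rm[None, ...], rm[None, ...]

68 |     def __getsample__(self, idx : int) -> tuple | list:
69 |         '''
70 |         Dummy method to be implemented by subclasses.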
This does the job usually done by __getitem__.
71 |         '''
72 |         raise NotImplementedError('subclasses of RM have to implement __getsample__')
73 | 
74 |     def __getitem__(self, idx : int) -> tuple | list:
75 |         '''
76 |         Calls the method __getsample__ that all subclasses have to implement. Applies augmentation and, if requested, adds ids for inference.
77 |         '''
78 |         curr_id = self.ids[idx]
79 |         sample = list(self.__getsample__(curr_id))
80 |         for i in range(len(sample)):
81 |             ### ToImageTensor doesn't work on dicts, so convert the values individually
82 |             if isinstance(sample[i], dict):
83 |                 sample[i] = {k : self.t1(v) for k, v in sample[i].items()}
84 |             else:
85 |                 sample[i] = self.t1(sample[i])
86 |         sample = self.transforms(sample)
87 | 
88 |         if self.get_ids:
89 |             return *sample, curr_id
90 |         else:
91 |             return sample
92 | 
93 | class RM_directional(RM):
94 |     '''
95 |     PyTorch dataset containing radio maps, corresponding gis layers (nDSMs of buildings and vegetation) and transmitter positions and parameters (gain, direction).
96 |     Additional information based on this data, like coordinate systems and line-of-sight maps, can also be generated.
97 |     '''
98 |     def __init__(self,
99 |                  dataset_path : str | Path,
100 |                  ### inputs
101 |                  tx_one_hot : bool,
102 |                  ndsms : bool,
103 |                  ndsm_all : bool,
104 |                  dist2d : bool,
105 |                  dist2d_log : bool,
106 |                  coords_euclidian : bool,
107 |                  coords_cylindrical : bool,
108 |                  coords_spherical : bool,
109 |                  coords_ga : bool,
110 |                  los_floor : bool,
111 |                  los_top : bool,
112 |                  los_z_min : bool,
113 |                  los_z_min_rel : bool,
114 |                  los_theta_max : bool,
115 |                  ant_gain_floor : bool,
116 |                  ant_gain_top : bool,
117 |                  ant_gain_slices : bool,
118 |                  ant_gain_los_slices : bool,
119 |                  ant_gain_floor_dist : bool,
120 |                  ant_gain_top_dist : bool,
121 |                  ant_gain_slices_dist : bool,
122 |                  ant_gain_los_slices_dist : bool,
123 |                  img : bool,
124 |                  img_rgb : bool,
125 |                  elevation : bool,
126 |                  azimuth : bool,
127 |                  ### values for scaling and cutting height values
128 |                  z_step : int,
129 |                  z_max : int,
130 |                  thresh : float,
131 |                  ### activate to pass masks when testing (allows calculating the loss in LoS/non-LoS areas)
132 |                  test : bool = False,
133 |                  ### kwargs for super class
134 |                  **kwargs
135 |                  ) -> None:
136 |         ''' Arguments:
137 |         dataset_path - string or path; Path to dataset main folder.
138 | 
139 |         inputs - Several inputs can be turned on by setting the corresponding bool to True; all of these are one or several 256 x 256 tensors according to the map
140 |             tx_one_hot - 1 in tx position, 0 else
141 |             ndsms - nDSMs of buildings and vegetation separately
142 |             ndsm_all - one nDSM not separated by classes (mainly used to see how nDSM + image without class labels performs, as this information might be easier to obtain)
143 |             dist2d - distance in x-y plane from Tx to each point
144 |             dist2d_log - log_10(dist2d), corresponding to the effect of the distance in free space path loss in dB (the linear transform should be learned by the network)
145 |             coords_X - coordinates of the floor, building top and vegetation top for a coordinate system with center in the Tx position
146 |                 e.g.
coords_cylindrical produces five channels: phi, r, and the z-value for each of floor/building/vegetation
147 |                 GA is grid anchor as proposed in the RadioTrans paper (https://ieeexplore.ieee.org/document/9753644)
148 |             los_floor/top - 1/0 for LoS/non-LoS, both for the top of buildings and the floor
149 |             los_z_min - minimal z-value for visibility (from the floor, as in nDSMs, GA)
150 |             los_z_min_rel - minimal z-value for visibility, relative to tx height (euclidian/cylindrical coordinates)
151 |             los_theta_max - maximal theta value for visibility, relative to tx position (spherical coordinates)
152 |             ant_gain_floor/top - antenna gain in dB projected onto the floor of the map / building top
153 |             ant_gain_slices - antenna gain projected onto horizontal planes, according to z_step, z_max (see below)
154 |             ant_gain_los_slices - like ant_gain_slices, but additionally contains LoS information, gain in voxels with no LoS is set to 0
155 |             ant_gain_X_dist - gain in dB - 2*log_10(dist2d), corresponding to free space path loss
156 |             img - aerial image (RGBI)
157 |             img_rgb - aerial image, only RGB channels
158 |             elevation - tilt/elevation angle (spherical coordinates)
159 |             azimuth - azimuth angle (spherical/cylindrical coordinates)
160 |         z_step - int; This is only used for ant_gain_slices and determines how many meters in z-direction each slice corresponds to. The maximum height considered in the
161 |             dataset is 32m, so with e.g. z_step=4 we produce 8 slices for the heights 0-4, 4-8, ..., 28-32
162 |         z_max - int; Maximum height value relevant for this dataset. This is used to cut off gis layers above this value and rescale height values linearly from [0, z_max] to [0, 1].
163 |         thresh - float; cut-off for the lower part of the radio map (rescaled path gain below this value is set to 0)
164 |         test - bool; pass True to receive masks for LoS and areas directly illuminated by Tx for additional metrics
165 |         The output of the get_item method is a tuple of two tensors, the first is all inputs stacked along the channel dimension, the second one is the target radio map.
166 |         All inputs and the target are normalized and, in the case of heights, cut to values in the interval [-1,1].
167 |         '''
168 | 
169 |         self.antenna_gains = {}
170 |         for ant_file in (Path(dataset_path) / 'antenna_patterns').glob('pattern_*.npy'):   ### wrap in Path so a plain string works here too
171 |             try:
172 |                 ant_id = int(ant_file.stem.split('_')[-1])
173 |             except ValueError as e:
174 |                 # print(e)
175 |                 continue
176 |             self.antenna_gains[ant_id] = np.load(ant_file)
177 | 
178 |         self.tx_one_hot = tx_one_hot
179 |         self.ndsms = ndsms
180 |         self.ndsm_all = ndsm_all
181 |         self.dist2d = dist2d
182 |         self.dist2d_log = dist2d_log
183 |         self.coords_euclidian = coords_euclidian
184 |         self.coords_cylindrical = coords_cylindrical
185 |         self.coords_spherical = coords_spherical
186 |         self.coords_ga = coords_ga
187 |         self.los_floor = los_floor
188 |         self.los_top = los_top
189 |         self.los_z_min = los_z_min
190 |         self.los_z_min_rel = los_z_min_rel
191 |         self.los_theta_max = los_theta_max
192 |         self.ant_gain_floor = ant_gain_floor
193 |         self.ant_gain_top = ant_gain_top
194 |         self.ant_gain_slices = ant_gain_slices
195 |         self.ant_gain_los_slices = ant_gain_los_slices
196 |         self.ant_gain_floor_dist = ant_gain_floor_dist
197 |         self.ant_gain_top_dist = ant_gain_top_dist
198 |         self.ant_gain_slices_dist = ant_gain_slices_dist
199 |         self.ant_gain_los_slices_dist = ant_gain_los_slices_dist
200 |         self.img = img
201 |         self.img_rgb = img_rgb
202 |         self.azimuth = azimuth
203 |         self.elevation = elevation
204 | 
205 |         self.z_step = z_step
206 |         self.z_max = z_max
207 |         self.thresh = thresh  ### threshold for cutting off the lowest path-gain values;
may be 0 (no cut-off)
208 | 
209 |         self.test = test
210 | 
211 |         super().__init__(dataset_path=dataset_path, **kwargs)
212 | 
213 |     def get_in_ch(
214 |             tx_one_hot=False,
215 |             ndsms=False,
216 |             ndsm_all=False,
217 |             dist2d=False,
218 |             dist2d_log=False,
219 |             coords_euclidian=False,
220 |             coords_cylindrical=False,
221 |             coords_spherical=False,
222 |             coords_ga=False,
223 |             los_floor=False,
224 |             los_top=False,
225 |             los_z_min=False,
226 |             los_z_min_rel=False,
227 |             los_theta_max=False,
228 |             ant_gain_floor=False,
229 |             ant_gain_top=False,
230 |             ant_gain_slices=False,
231 |             ant_gain_los_slices=False,
232 |             ant_gain_floor_dist=False,
233 |             ant_gain_top_dist=False,
234 |             ant_gain_slices_dist=False,
235 |             ant_gain_los_slices_dist=False,
236 |             img=False,
237 |             img_rgb=False,
238 |             z_max=32,
239 |             z_step=4,
240 |             #######
241 |             azimuth=False,
242 |             elevation=False,
243 |             #######
244 |             ### catch all arguments not needed for calculating the channel number
245 |             *args, **kwargs) -> int:
246 |         '''
247 |         Calculates the number of in channels for the network for the given data configuration.
248 |         Because of the way argument linking works in the CLI, this has to be available in the __init__ of the DataModule, therefore we cannot work with the self attributes here.
249 |         *args, **kwargs don't do anything, these are just defined so we can throw in the whole configuration of the dataset at once.
250 | 
251 |         Output : int; number of channels of the input tensor
252 |         '''
253 |         return tx_one_hot + dist2d + dist2d_log + los_floor + los_top + los_z_min + los_z_min_rel + los_theta_max + ant_gain_floor + ant_gain_top + ant_gain_floor_dist + ant_gain_top_dist + azimuth \
254 |             + 2 * ndsms + ndsm_all \
255 |             + 3 * (elevation + img_rgb) \
256 |             + 4 * img \
257 |             + 5 * (coords_euclidian + coords_cylindrical + coords_ga) \
258 |             + 7 * coords_spherical \
259 |             + int(z_max // z_step) * (ant_gain_slices + ant_gain_los_slices + ant_gain_slices_dist + ant_gain_los_slices_dist)
260 | 
261 |     def get_PL_scale(dataset_path : str | Path = 'dataset', **kwargs):
262 |         '''
263 |         Reads the PL scale from the file containing information about the PL threshold and max PL.
264 |         '''
265 |         # print(dataset_path)   ### debug output
266 |         with open(Path(dataset_path) / 'max_power.json', 'r') as f:
267 |             power_data = json.load(f)
268 |         return float(power_data['pg_max']) - float(power_data['pg_trnc'])
269 | 
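    ### --- worked example (illustrative, not original code) ---
    ### for an assumed configuration with tx_one_hot=True, ndsms=True, img=True and
    ### ant_gain_slices=True (z_max=32, z_step=4), get_in_ch returns
    ###     1 (tx_one_hot) + 2 (ndsms) + 4 (img RGBI) + 32//4 = 8 (gain slices) = 15
    ### channels, which is what the first conv layer of the model has to accept.

270 |     def __getsample__(self, curr_id : tuple):
271 |         '''
272 |         The function loads all data for the current idx and generates requested inputs.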
273 | Output: 274 | (inputs, target) - tuple of two tensors; first is all inputs stacked along the channel dimension, second one is the target radio map 275 | if test is passed in init, additionally the tuple contains a dict of names and masks 276 | ''' 277 | if self.test: 278 | los_floor = None 279 | ant_gain_floor = None 280 | 281 | with open(self.dataset_path / "tx_antennas" / "{}_{}_{}_txparams.json".format(*curr_id[:3]), 'r') as f: 282 | tx_antenna_params = json.load(f)[curr_id[3]] 283 | tx_coords, tx_phi, tx_theta, tx_antenna = itemgetter('tx_coords', 'phi', 'theta', 'antenna')(tx_antenna_params) 284 | 285 | tx_antenna_pattern = self.antenna_gains[tx_antenna] 286 | 287 | nbuild = torch.minimum(torch.tensor(io.imread(self.dataset_path / "gis" / "nbuildings_{}_{}_{}.png".format(*curr_id[:3])), dtype=torch.float32), torch.tensor(self.z_max)) 288 | nveg = torch.minimum(torch.tensor(io.imread(self.dataset_path / "gis" / "nveg_{}_{}_{}.png".format(*curr_id[:3])), dtype=torch.float32), torch.tensor(self.z_max)) 289 | 290 | target = torch.tensor(io.imread(self.dataset_path / "path_gain" / "pg_{}_{}_{}_{}.png".format(*curr_id)), dtype=torch.float32) / 255 291 | target = torch.maximum((torch.reshape(target, (1, *target.shape))- self.thresh) / (1 - self.thresh), torch.tensor([0])) 292 | 293 | inputs = [] 294 | 295 | ### always generate spherical coordinates, these are needed for generating gain and LoS input as well 296 | dist_3d_build, dist_3d_veg, dist_3d_floor, theta_build, theta_veg, theta_floor, phi = utils_coords.spherical_coords(nbuild=nbuild, nveg=nveg, tx_coords=tx_coords, phi_base=tx_phi, theta_base=tx_theta) 297 | 298 | if self.los_floor or self.los_top or self.los_theta_max or self.los_z_min or self.los_z_min_rel or self.ant_gain_los_slices or self.test: 299 | los_theta_max = torch.minimum(torch.tensor(np.load(self.dataset_path / "los" / "theta_max_{}_{}_{}_{}.npy".format(*curr_id)), dtype=torch.float32), theta_floor) 300 | 301 | ### generate all requested input tensors, normalize and add them to inputs list 302 | if self.tx_one_hot: 303 | tx_one_hot_tens = torch.zeros_like(nbuild) 304 | tx_one_hot_tens[tx_coords[0], tx_coords[1]] = tx_coords[2] / self.z_max 305 | inputs.append(tx_one_hot_tens) 306 | if self.ndsms: 307 | inputs.append(nbuild / self.z_max) 308 | inputs.append(nveg / self.z_max) 309 | if self.ndsm_all: 310 | inputs.append(torch.where(nbuild > 0, nbuild, nveg) / self.z_max) 311 | if self.dist2d: 312 | inputs.append(utils_coords.dist_2d(tx_coords=tx_coords) / (np.sqrt(2)*255)) 313 | if self.dist2d_log: 314 | dist2d = utils_coords.dist_2d(tx_coords=tx_coords) 315 | ### we change the value at the Tx position, as log(0) doesn't make sense 316 | ### in this position, there will always be a building anyways and hence 0 path loss, and the usual path loss formula with log(dist) doesn't hold here 317 | dist2d[dist2d==0] = 1 318 | inputs.append(torch.log10(dist2d) / np.log10(np.sqrt(2)*255)) 319 | if self.coords_euclidian: 320 | xc, yc, zc_build, zc_veg, zc_floor = utils_coords.euclidian_coords(nbuild=nbuild, nveg=nveg, tx_coords=tx_coords) 321 | inputs.extend([xc / 255, yc / 255, zc_build / self.z_max, zc_veg / self.z_max, zc_floor / self.z_max]) 322 | if self.coords_cylindrical: 323 | dist2d, phi_cyl, zc_build, zc_veg, zc_floor = utils_coords.cylindrical_coords(nbuild=nbuild, nveg=nveg, tx_coords=tx_coords, phi_base=tx_phi) 324 | inputs.extend([dist2d / (np.sqrt(2)*255), phi_cyl / torch.pi, zc_build / self.z_max, zc_veg / self.z_max, zc_floor / self.z_max]) 325 | if 
self.coords_spherical: 326 | max_dist_3d = np.sqrt(255**2 + 255**2 + self.z_max**2) 327 | inputs.extend([dist_3d_build / max_dist_3d, dist_3d_veg / max_dist_3d, dist_3d_floor / max_dist_3d, theta_build / torch.pi, theta_veg / torch.pi, theta_floor / torch.pi, phi / torch.pi]) 328 | if self.coords_ga: 329 | xt, yt, zt, xs, ys = utils_coords.GA_coords(tx_coords=tx_coords) 330 | inputs.extend([xt / 256, yt / 256, zt / self.z_max, xs / 256, ys / 256]) 331 | ### for binary LoS, add a small constant to los_theta_max, otherwise we have small holes/distortions in the LoS maps due to rounding errors 332 | if self.los_floor: 333 | los_floor = (los_theta_max + 1e-4 >= theta_floor).to(dtype=torch.float32) 334 | inputs.append(los_floor) 335 | if self.los_top: 336 | inputs.append((los_theta_max + 1e-4 >= theta_build).to(dtype=torch.float32)) 337 | if self.los_z_min: 338 | los_z_min = utils_coords.get_heights(theta=los_theta_max, dist_2d=utils_coords.dist_2d(tx_coords=tx_coords), tx_z=tx_coords[2], theta_base=tx_theta, z_max=self.z_max) 339 | inputs.append(los_z_min / self.z_max) 340 | if self.los_z_min_rel: 341 | los_z_min_rel = utils_coords.get_heights(theta=los_theta_max, dist_2d=utils_coords.dist_2d(tx_coords=tx_coords), tx_z=tx_coords[2], theta_base=tx_theta, z_max=self.z_max) - tx_coords[2] 342 | inputs.append(los_z_min_rel / self.z_max) 343 | if self.los_theta_max: 344 | inputs.append(los_theta_max / torch.pi) 345 | if self.ant_gain_floor: 346 | ant_gain_floor = utils_coords.project_gain(phi=phi, theta=theta_floor, gain_array=tx_antenna_pattern, normalize=True) 347 | inputs.append(ant_gain_floor) 348 | if self.ant_gain_top: 349 | inputs.append(utils_coords.project_gain(phi=phi, theta=theta_build, gain_array=tx_antenna_pattern, normalize=True)) 350 | if self.ant_gain_slices: 351 | theta3d, _ = utils_coords.spherical_slices(tx_coords=tx_coords, theta_base=tx_theta, z_max=self.z_max, z_step=self.z_step) 352 | inputs.extend([utils_coords.project_gain(phi=phi, theta=theta3d[k,:], gain_array=tx_antenna_pattern, normalize=True) for k in range(theta3d.shape[0])]) 353 | if self.ant_gain_los_slices: 354 | theta3d, _ = utils_coords.spherical_slices(tx_coords=tx_coords, theta_base=tx_theta, z_max=self.z_max, z_step=self.z_step) 355 | inputs.extend([utils_coords.project_gain(phi=phi, theta=theta3d[k,:], gain_array=tx_antenna_pattern, theta_max=los_theta_max, normalize=True) for k in range(theta3d.shape[0])]) 356 | if self.ant_gain_floor_dist: 357 | inputs.append(utils_coords.project_gain(phi=phi, theta=theta_floor, gain_array=tx_antenna_pattern, normalize=True, dist=utils_coords.dist_3d(tx_coords))) 358 | if self.ant_gain_top_dist: 359 | inputs.append(utils_coords.project_gain(phi=phi, theta=theta_build, gain_array=tx_antenna_pattern, normalize=True, dist=utils_coords.dist_3d(tx_coords))) 360 | if self.ant_gain_slices_dist: 361 | theta3d, dist3d = utils_coords.spherical_slices(tx_coords=tx_coords, theta_base=tx_theta, z_max=self.z_max, z_step=self.z_step) 362 | inputs.extend([utils_coords.project_gain(phi=phi, theta=theta3d[k,:], gain_array=tx_antenna_pattern, normalize=True, dist=dist3d[k,:], z_max=self.z_max) for k in range(theta3d.shape[0])]) 363 | if self.ant_gain_los_slices_dist: 364 | theta3d, dist3d = utils_coords.spherical_slices(tx_coords=tx_coords, theta_base=tx_theta, z_max=self.z_max, z_step=self.z_step) 365 | inputs.extend([utils_coords.project_gain(phi=phi, theta=theta3d[k,:], gain_array=tx_antenna_pattern, theta_max=los_theta_max, normalize=True, dist=dist3d[k,:], z_max=self.z_max) for k in 
range(theta3d.shape[0])])
366 |         if self.azimuth:
367 |             inputs.append(phi)
368 |         if self.elevation:
369 |             inputs.extend([theta_build / torch.pi, theta_floor / torch.pi, theta_veg / torch.pi])
370 |         if self.img:
371 |             img_arr = io.imread(self.dataset_path / "img" / "img_{}_{}_{}.tif".format(*curr_id[:3]))
372 |             for i in range(img_arr.shape[-1]):
373 |                 inputs.append(torch.tensor(img_arr[:,:,i], dtype=torch.float32) / 255)
374 |         if self.img_rgb:
375 |             img_arr = io.imread(self.dataset_path / "img" / "img_{}_{}_{}.tif".format(*curr_id[:3]))
376 |             for i in range(3):
377 |                 inputs.append(torch.tensor(img_arr[:,:,i], dtype=torch.float32) / 255)
378 | 
379 |         inputs = torch.stack(inputs, dim=0)
380 | 
381 |         if self.test:
382 |             if ant_gain_floor is None:
383 |                 ant_gain_floor = utils_coords.project_gain(phi=phi, theta=theta_floor, gain_array=tx_antenna_pattern, normalize=True)
384 |             if los_floor is None:
385 |                 los_floor = (los_theta_max + 1e-4 >= theta_floor).to(dtype=torch.float32)
386 |             return inputs, target, {'gain_los_floor' : 1.0*((los_floor * ant_gain_floor) == 0), 'gain_floor' : 1.0*(ant_gain_floor == 0)}
387 |         else:
388 |             return inputs, target
389 | 
--------------------------------------------------------------------------------
/lib/torch_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import copy
4 | import numpy as np
5 | from .dcn import DeformableConv2d
6 | 
7 | '''Standard building blocks for CNN'''
8 | 
9 | def get_padding(mode="same", kernel_size=3, dilation=1):
10 |     '''returns the value for padding and potentially output_padding to maintain the input size or increase/reduce it by a factor of 2'''
11 |     # same: applying a normal conv, retaining dimensions, assuming stride==1
12 |     if mode=="same":
13 |         if dilation%2==1 and kernel_size%2==0:
14 |             raise ValueError(f"invalid padding parameters: mode {mode}, kernel_size {kernel_size}, dilation {dilation}")
15 |         return int(0.5 * dilation * (kernel_size - 1))
16 |     # down sampling with stride 2:
17 |     elif mode=="down":
18 |         return int(dilation * (kernel_size - 1) // 2)
19 |     # up sampling with stride 2 and convTranspose
20 |     elif mode=="up":
21 |         output_padding = 0 if dilation * (kernel_size - 1) % 2 == 1 else 1
22 |         padding = int((dilation * (kernel_size - 1) + output_padding - 1) // 2)
23 |         return padding, output_padding
24 |     else:
25 |         raise ValueError(f"get_padding got mode={mode}")
26 | 
27 | class nConvBlocks(nn.Module):
28 |     '''repeated convBlocks, see below'''
29 |     def __init__(self, in_ch, num_layers=2, out_ch=None, mid_factor=None, kernel_size=3, stride=1, dilation=1, dropout=0, batchnorm=True, bn_eps : float = 1e-05, activation='ReLU', res=False, **kwargs) -> None:
30 |         super().__init__()
31 | 
32 |         if out_ch is None:
33 |             out_ch = in_ch
34 |         if mid_factor is None:
35 |             mid_ch = out_ch
36 |         else:
37 |             mid_ch = mid_factor * in_ch
38 | 
39 |         # for __repr__ function
40 |         self.repr = f"nConvBlocks, in_ch={in_ch}, num_layers={num_layers}, mid_ch={mid_ch}, out_ch={out_ch}, kernel_size={kernel_size}, stride={stride}, dilation={dilation}, res={res}"
41 | 
42 |         mod_list = nn.ModuleList()
43 |         for i in range(num_layers):
44 |             in_ch_here = in_ch if i==0 else mid_ch
45 |             out_ch_here = out_ch if i==num_layers-1 else mid_ch
46 |             mod_list.append(convBlock(in_ch=in_ch_here, out_ch=out_ch_here, kernel_size=kernel_size, stride=stride, dilation=dilation, dropout=dropout, batchnorm=batchnorm,
47 |                                       bn_eps=bn_eps, activation=activation))
48 | 
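        ### --- illustrative note (not original code) on the channel pattern: for
        ### num_layers=2, in_ch=32, out_ch=64 (mid_factor=None) the two convBlocks
        ### map 32 -> 64 -> 64; with res=True the 1x1 ch_proj below maps the input
        ### 32 -> 64 so that the residual addition in forward() is well-defined.
49 |         self.mod_list = 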
nn.Sequential(*mod_list) 50 | ### if we use residual skip connection and in_ch doesn't match out_ch, we need to adjust 51 | self.ch_proj = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=False) if res and in_ch != out_ch else nn.Identity() 52 | 53 | ### store some of the parameters needed in the forward pass 54 | self.res=res 55 | 56 | def forward(self, x): 57 | y = self.mod_list(x) 58 | if self.res: 59 | residual = self.ch_proj(x) 60 | return residual + y 61 | else: 62 | return y 63 | 64 | def __repr__(self): 65 | return self.repr 66 | 67 | class convBlock(nn.Sequential): 68 | def __init__(self, in_ch, out_ch=None, kernel_size=3, stride=1, dilation=1, dropout=0, batchnorm=True, bn_eps : float = 1e-05, activation='ReLU', **kwargs) -> None: 69 | super().__init__() 70 | 71 | padd = get_padding(mode="same", kernel_size=kernel_size, dilation=dilation) 72 | 73 | act = getattr(nn, activation) 74 | 75 | if out_ch is None: 76 | out_ch = in_ch 77 | 78 | mod_list = nn.ModuleList([nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padd, bias=not batchnorm, dilation=dilation)]) 79 | if batchnorm: 80 | mod_list.append(nn.BatchNorm2d(out_ch, eps=bn_eps)) 81 | mod_list.append(act()) 82 | if dropout > 0: 83 | mod_list.append(nn.Dropout2d(dropout)) 84 | 85 | # for __repr__ function 86 | self.repr = f"convBlock, in_ch={in_ch}, out_ch={out_ch}, kernel_size={kernel_size}, stride={stride}, dilation={dilation}" 87 | 88 | super().__init__(*mod_list) 89 | 90 | def __repr__(self): 91 | return self.repr 92 | 93 | class resNeXtBlock(nn.Module): 94 | def __init__(self, in_ch, out_ch=None, mid_ch=None, groups=32, mode="same", kernel_size=3, dilation=1, dropout=0, batchnorm=False, \ 95 | bn_eps : float = 1e-05, activation='ReLU', down_params=None, up_params=None, **kwargs): 96 | super().__init__() 97 | 98 | # standard 99 | if mid_ch is None: 100 | # choose mid_ch about in_ch/2 but divisible by 32 101 | mid_ch = int((in_ch // (2*groups)) * groups) 102 | if mid_ch == 0: 103 | mid_ch = groups 104 | if out_ch is None: 105 | out_ch = in_ch 106 | 107 | # repr 108 | self.repr = f"resNeXtBlock, in_ch={in_ch}, mid_ch={mid_ch}, out_ch={out_ch}" 109 | 110 | act = getattr(nn, activation) 111 | 112 | if mode=="same": 113 | stride = 1 114 | padd = get_padding(mode=mode, kernel_size=kernel_size, dilation=dilation) 115 | if in_ch != out_ch: 116 | self.transform_skip = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=False) 117 | else: 118 | self.transform_skip = nn.Identity() 119 | elif mode=="up": 120 | if up_params is not None: 121 | up_p = copy.deepcopy(up_params) 122 | else: 123 | up_p = dict() 124 | stride = 2 125 | padd, output_padding = get_padding(mode=mode, kernel_size=kernel_size, dilation=dilation) 126 | up_p["in_ch"] = in_ch 127 | up_p["out_ch"] = out_ch 128 | self.transform_skip = up(**up_p) 129 | elif mode=="down": 130 | if down_params is not None: 131 | down_p = copy.deepcopy(down_params) 132 | else: 133 | down_p = dict() 134 | stride = 2 135 | padd = get_padding(mode=mode, kernel_size=kernel_size, dilation=dilation) 136 | down_p["in_ch"] = in_ch 137 | down_p["out_ch"] = out_ch 138 | self.transform_skip = down(**down_p) 139 | else: 140 | raise NotImplementedError(f"mode {mode} in resNeXtBlock") 141 | 142 | self.mod_list = nn.Sequential(nn.Conv2d(in_ch, mid_ch, kernel_size=1, stride=1, padding=0)) 143 | 144 | if mode in ["same", "down"]: 145 | self.mod_list.append(nn.Conv2d(mid_ch, mid_ch, kernel_size=kernel_size, stride=stride, groups=groups, padding=padd, 
dilation=dilation, bias=not batchnorm))
146 |         else:
147 |             self.mod_list.append(nn.ConvTranspose2d(mid_ch, mid_ch, kernel_size=kernel_size, stride=stride, groups=groups, padding=padd, output_padding=output_padding, bias=not batchnorm))
148 | 
149 |         if batchnorm:
150 |             self.mod_list.append(nn.BatchNorm2d(mid_ch, eps=bn_eps))
151 |         self.mod_list.append(nn.Conv2d(mid_ch, out_ch, kernel_size=1, stride=1, padding=0))
152 |         self.mod_list.append(act())
153 |         if dropout > 0:
154 |             self.mod_list.append(nn.Dropout2d(dropout))
155 | 
156 |     def forward(self, x):
157 |         y = self.mod_list(x)
158 | 
159 |         residual = self.transform_skip(x)
160 |         z = y + residual
161 |         return z
162 | 
163 |     def __repr__(self):
164 |         return self.repr
165 | 
166 | class dilationBlock(nn.Module):
167 |     ##### inspired by https://ieeexplore.ieee.org/document/9653079, puts blocks of the same type (e.g. conv, ResNeXt) in parallel with different dilation values
168 |     def __init__(self, in_ch, out_ch=None, dilations=(1,2,3,4), block="resNeXtBlock", **block_params) -> None:
169 |         super().__init__()
170 |         if out_ch is None:
171 |             out_ch = in_ch
172 |         # repr
173 |         self.repr = f"dilationBlock, in_ch={in_ch}, out_ch={out_ch}"
174 | 
175 |         ### copy the parameters for the block to not change them in place
176 |         block_params_copy = copy.deepcopy(block_params)
177 |         block_params_copy["in_ch"], block_params_copy["out_ch"] = in_ch, out_ch
178 |         self.mod_list = nn.ModuleList()
179 | 
180 |         try:
181 |             layer = globals()[block]
182 |         except KeyError:   ### fall back to torch.nn for block types not defined in this module
183 |             layer = getattr(nn, block)
184 | 
185 |         for dilation in dilations:
186 |             self.mod_list.append(layer(**block_params_copy, dilation=dilation))
187 |         ### 1x1 conv to go back to out_ch
188 |         self.channel_conv = nn.Conv2d(len(dilations)*out_ch, out_ch, kernel_size=1, stride=1, padding=0)
189 | 
190 |     def forward(self, x):
191 |         out = []
192 |         for layer in self.mod_list:
193 |             out.append(layer(x))
194 |         ### concat all outputs along the channel dimension and bring the number of channels back to out_ch
195 |         return self.channel_conv(torch.cat(out, dim=1))
196 | 
197 |     def __repr__(self):
198 |         return self.repr
199 | 
200 | def down(in_ch, sampling_down='max', kernel_size_down=2, out_ch=None, dilation_down=1, **kwargs):
201 |     if out_ch is None:
202 |         out_ch = in_ch
203 |     padd = get_padding(mode="down", kernel_size=kernel_size_down, dilation=dilation_down)
204 | 
205 |     if sampling_down=='max':
206 |         layer = nn.MaxPool2d(kernel_size_down, padding=padd, dilation=dilation_down)
207 |     elif sampling_down=='conv':
208 |         # strided conv
209 |         layer = nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size_down, stride=2, padding=padd, dilation=dilation_down),
210 |                               nn.ReLU())
211 |     elif sampling_down=="seg":
212 |         layer = nn.MaxPool2d(kernel_size_down, padding=padd, dilation=dilation_down, return_indices=True)
213 | 
214 |     elif sampling_down=='avg':
215 |         layer = nn.AvgPool2d(kernel_size_down, padding=padd)
216 |     else:
217 |         raise ValueError(f'sampling={sampling_down}')
218 |     return layer
219 | 
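### --- usage sketch (illustrative, not part of the original module) ---
### the factory returns a ready-made downsampling layer; the shapes in the
### comment assume the default kernel_size_down=2:
def _demo_down_factory():
    layer = down(in_ch=32, sampling_down='conv', out_ch=64)
    x = torch.randn(1, 32, 64, 64)
    return layer(x)   # -> shape (1, 64, 32, 32)

220 | def up(in_ch, sampling='conv', kernel_size_up=2, out_ch=None, dilation_up=1, **kwargs):
221 |     if out_ch is None:
222 |         out_ch = in_ch
223 | 
224 |     if sampling in ['bilinear', 'nearest']:
225 |         layer = nn.Upsample(scale_factor=kernel_size_up, mode=sampling)
226 |     elif sampling=='conv':
227 |         padd, output_padd = get_padding(mode="up", kernel_size=kernel_size_up, dilation=dilation_up)
228 |         layer = nn.Sequential(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_size_up, stride=kernel_size_up, padding=padd, output_padding=output_padd, dilation=dilation_up),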
229 | nn.ReLU()) 230 | elif sampling=="seg": 231 | layer = nn.MaxUnpool2d(kernel_size=kernel_size_up) 232 | else: 233 | raise NotImplementedError("wrong upsampling ", sampling) 234 | return layer 235 | 236 | class DCNBlock(nn.Sequential): 237 | def __init__(self, in_ch, out_ch=None, kernel_size=3, stride=1, dropout=0, batchnorm=True, bn_eps : float = 1e-05, activation='ReLU', **kwargs) -> None: 238 | super().__init__() 239 | 240 | padd = get_padding(mode="same", kernel_size=kernel_size, dilation=1) 241 | 242 | act = getattr(nn, activation) 243 | 244 | if out_ch is None: 245 | out_ch = in_ch 246 | 247 | mod_list = nn.Sequential(DeformableConv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padd, bias=not batchnorm)) 248 | if batchnorm: 249 | mod_list.append(nn.BatchNorm2d(out_ch, eps=bn_eps)) 250 | mod_list.append(act()) 251 | if dropout > 0: 252 | mod_list.append(nn.Dropout2d(dropout)) 253 | 254 | # for __repr__ function 255 | self.repr = f"DCNBlock, in_ch={in_ch}, out_ch={out_ch}, kernel_size={kernel_size}, stride={stride}" 256 | 257 | super().__init__(mod_list) 258 | 259 | def __repr__(self): 260 | return self.repr 261 | 262 | class nDCNBlocks(nn.Module): 263 | '''''' 264 | def __init__(self, in_ch, num_layers=2, out_ch=None, mid_factor=None, kernel_size=3, stride=1, dropout=0, batchnorm=True, bn_eps : float = 1e-05, activation='ReLU', res=False, **kwargs) -> None: 265 | super().__init__() 266 | 267 | if out_ch is None: 268 | out_ch = in_ch 269 | if mid_factor is None: 270 | mid_ch = out_ch 271 | else: 272 | mid_ch = mid_factor * in_ch 273 | 274 | # for __repr__ function 275 | self.repr = f"nDCNBlocks, in_ch={in_ch}, num_layers={num_layers}, mid_ch={mid_ch}, out_ch={out_ch}, kernel_size={kernel_size}, stride={stride}, res={res}" 276 | 277 | mod_list = nn.Sequential() 278 | for i in range(num_layers): 279 | in_ch_here = in_ch if i==0 else mid_ch 280 | out_ch_here = out_ch if i==num_layers-1 else mid_ch 281 | mod_list.append(DCNBlock(in_ch=in_ch_here, out_ch=out_ch_here, kernel_size=kernel_size, stride=stride, dropout=dropout, batchnorm=batchnorm, 282 | bn_eps=bn_eps, activation=activation)) 283 | 284 | self.mod_list = nn.Sequential(*mod_list) 285 | ### if we use residual skip connection and in_ch doesn't match out_ch, we need to adjust 286 | self.ch_proj = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=False) if res and in_ch != out_ch else nn.Identity() 287 | 288 | ### store some of the parameters needed in the forward pass 289 | self.res=res 290 | 291 | 292 | def forward(self, x): 293 | y = self.mod_list(x) 294 | if self.res: 295 | residual = self.ch_proj(x) 296 | return residual + y 297 | else: 298 | return y 299 | 300 | def __repr__(self): 301 | return self.repr 302 | 303 | '''Building blocks for UNet-like CNN (encoder, decoder, skip-connections)''' 304 | 305 | class Modular(nn.Module): 306 | ''' 307 | general CNN with encoder-(bottleneck)-decoder plus skip connections structure 308 | encoder handles skip connections and operations on them, decoder connects them 309 | by default, each encoder block doubles the number of channels and each decoder block halves them, the bottleneck uses 2*channels[0] as mid_ch if this is possible. 
310 |     img_size has to be given for ViT
311 |     blocks are defined as lists of tuples; each tuple consists of a string referring to a Module defined by us or in torch and the parameters given to the Module
312 |     blocks in the encoder and decoder are repeated according to the number of channels
313 |     skip connections are added from the inBlock and each of the repeated encoder blocks to the decoder blocks and the outBlock, where their output gets concatenated
314 | 
315 |     some common params used in several architectures that we try to pass to all layers where they make sense (/inBlock/outBlock):
316 |     activation, last_activation, kernel_size, kernel_size_in, kernel_size_out, batchnorm, dropout
317 |     these are overridden by parameters explicitly given in the definitions of the blocks
318 | 
319 |     '''
320 |     def __init__(self,
321 |                  in_ch : int,
322 |                  out_ch : int,
323 |                  inBlock : list | tuple,
324 |                  encoderBlock : list | tuple,
325 |                  skipBlock : list | tuple,
326 |                  bottleneck : list | tuple | None,
327 |                  decoderBlock : list | tuple,
328 |                  outBlock : list | tuple,
329 |                  channel : int | list[int],
330 |                  depth : int | None,
331 |                  dropout : tuple | float | None,
332 |                  img_size : int,
333 |                  params_down : dict,
334 |                  params_up : dict,
335 |                  skip_first=False,
336 |                  ### other params like stride...
337 |                  **kwargs
338 |                  ):
339 | 
340 |         super().__init__()
341 |         img_size_orig = img_size
342 |         dropout = (dropout, dropout) if isinstance(dropout, (float, int)) else dropout
343 |         if isinstance(channel, int):
344 |             channels = [channel * 2**i for i in range(depth + 1)]
345 |             ### adjust channels in the encoder in case they are lower than in_ch to the next power of 2
346 |             channel_min = min([c for c in channels if c >= in_ch])
347 |             channels_enc = [max(c, channel_min) for c in channels]
348 |         else:
349 |             ### depth ignored in this case
350 |             channels = channel
351 |             channels_enc = channels
352 |         channels_dec = [c1 + c2 for c1, c2 in zip(channels_enc, channels)]
353 |         # print(f'Modular got channel={channel}\ncalculated:\tchannels={channels}, channels_enc={channels_enc}, channels_dec={channels_dec}')
354 |         # channels_dec.append(channels_enc[-1])
355 |         self.inBlock, f = Block(inBlock, in_ch=in_ch, out_ch=channels_enc[0], img_size=img_size, **kwargs)
356 |         img_size *= f
357 |         self.encoder = Encoder(encoderBlock, skipBlock=skipBlock, channels=channels_enc, img_size=img_size, dropout=dropout[0], params_down=params_down, skip_first=skip_first, **kwargs)
358 |         img_size *= self.encoder.factor
359 |         self.bottleneck, f = Block(bottleneck, in_ch=channels_enc[-1], out_ch=channels_enc[-1], img_size=img_size, dropout=dropout[0], **kwargs)
360 |         img_size *= f
361 |         self.decoder = Decoder(decoderBlock, channels_in=channels_dec, channels_out=channels, img_size=img_size, dropout=dropout[1], params_up=params_up, skip_first=skip_first, **kwargs)
362 |         img_size *= self.decoder.factor
363 |         self.outBlock, f = Block(outBlock, in_ch=channels_dec[0], out_ch=out_ch, img_size=img_size, **kwargs)
364 |         img_size *= f
365 | 
366 |         assert img_size==img_size_orig, f"img_size={img_size_orig} given but {img_size} after all layers"
367 | 
368 |     def forward(self, x) -> torch.Tensor:
369 |         x = self.inBlock(x)
370 |         x, skips = self.encoder(x)
371 |         if self.bottleneck is not None:
372 |             x = self.bottleneck(x)
373 |         x = self.decoder(x, skips)
374 |         x = self.outBlock(x)
375 |         return x
376 | 
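    ### --- illustrative configuration (comments only, not original code) ---
    ### the ('BlockName', params) tuples mirror the LitUNet defaults; names
    ### resolve through get_layer_dict(), and in_ch=11 is an assumed value:
    ###
    ###     Modular(in_ch=11, out_ch=1,
    ###             inBlock=('nConvBlocks', {}), encoderBlock=('nConvBlocks', {}),
    ###             skipBlock=('Identity', {}), bottleneck=None,
    ###             decoderBlock=('nConvBlocks', {}), outBlock=('nConvBlocks', {'res': False}),
    ###             channel=32, depth=5, dropout=0, img_size=256,
    ###             params_down={}, params_up={})

377 |     def print_layers(self):
378 |         return (f'inBlock:\t{self.inBlock}'
379 |                 f'encoder:\t{self.encoder}'
380 |                 f'bottleneck:\t{self.bottleneck}'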
f'decoder:\t{self.decoder}' 382 | f'outBlock:\t{self.outBlock}' 383 | ) 384 | 385 | class Encoder(nn.Module): 386 | ''' 387 | encoder consists of repeated downsampling plus arbitrary block (e.g. nConvBlocks) given as encoderBlock 388 | 389 | additionally, a specific block can be given for the values passed on the skip connections (default identity) 390 | ''' 391 | def __init__( 392 | self, 393 | encoderBlock : list[tuple] | tuple, 394 | channels : list[int], 395 | skipBlock : tuple | list, 396 | img_size : int, 397 | params_down : dict, 398 | skip_first=True, 399 | **kwargs 400 | ) -> nn.Module: 401 | 402 | super().__init__() 403 | 404 | self.factor = 1 405 | 406 | self.encBlocks = torch.nn.ModuleList() 407 | self.downBlocks = torch.nn.ModuleList() 408 | self.skipBlocks = torch.nn.ModuleList() 409 | self.skip_first = skip_first 410 | 411 | for i in range(len(channels) - 1): 412 | ### skip 413 | b, _ = Block(skipBlock, in_ch=channels[i], out_ch=channels[i], img_size=img_size, **kwargs) 414 | self.skipBlocks.append(b) 415 | 416 | ### some arbitrary block BEFORE downsampling 417 | b, f = Block(encoderBlock, in_ch=channels[i], out_ch=channels[i+1], img_size=img_size, **kwargs) 418 | self.factor *= f 419 | img_size *= f 420 | self.encBlocks.append(b) 421 | 422 | ### down sampling, this one doesn't receive the general **kwargs (to avoid setting kernel_size and so on only meant for other layers) 423 | b, f = Block(('down', params_down), in_ch=channels[i+1], out_ch=channels[i+1], img_size=img_size) 424 | self.factor *= f 425 | img_size *= f 426 | self.downBlocks.append(b) 427 | 428 | b, _ = Block(skipBlock, in_ch=channels[-1], out_ch=channels[-1], img_size=img_size, **kwargs) 429 | self.skipBlocks.append(b) 430 | 431 | def forward( 432 | self, 433 | x : torch.Tensor 434 | ) -> tuple[torch.Tensor, list]: 435 | skips = [] 436 | # print(f'encoder in: {x.shape}') 437 | for i in range(len(self.encBlocks)): 438 | if self.skip_first or i==0: 439 | skips.append(self.skipBlocks[i](x)) 440 | x = self.encBlocks[i](x) 441 | if not self.skip_first: 442 | skips.append(self.skipBlocks[i](x)) 443 | x = self.downBlocks[i](x) 444 | # print(f'after encoder {i}: {x.shape} and skip {skips[-1].shape}') 445 | if self.skip_first: 446 | skips.append(self.skipBlocks[-1](x)) 447 | # print(f'skips: {[s.shape for s in skips]}') 448 | return x, skips 449 | 450 | def get_features( 451 | self, 452 | x : torch.Tensor 453 | ) -> tuple[torch.Tensor, list, list]: 454 | if not self.skip_first: 455 | raise NotImplementedError() 456 | skips = [self.skipBlocks[0](x)] 457 | features = [] 458 | for i in range(len(self.encBlocks)): 459 | x = self.encBlocks[i](x) 460 | features.append(x.detach().clone()) 461 | skips.append(self.skipBlocks[i+1](x)) 462 | return x, skips, features 463 | 464 | class Decoder(nn.Module): 465 | def __init__( 466 | self, 467 | decoderBlock : tuple | list, 468 | channels_in : list[int], 469 | channels_out : list[int], 470 | img_size : int, 471 | params_up : dict, 472 | skip_first = True, 473 | **kwargs 474 | ) -> nn.Module: 475 | super().__init__() 476 | 477 | self.factor = 1 478 | 479 | self.decBlocks = torch.nn.ModuleList() 480 | self.upBlocks = torch.nn.ModuleList() 481 | self.skip_first = skip_first 482 | 483 | for i in range(len(channels_in) - 1): 484 | b, f = Block(decoderBlock, in_ch=channels_in[i+1], out_ch=channels_out[i], img_size=img_size, **kwargs) 485 | self.factor *= f 486 | img_size *= f 487 | self.decBlocks.append(b) 488 | 489 | b, f = Block(('up', params_up), in_ch=channels_out[i] if skip_first 
else channels_out[i+1], out_ch=channels_out[i] if skip_first else channels_out[i+1], img_size=img_size, **kwargs)
490 |             self.factor *= f
491 |             img_size *= f
492 |             self.upBlocks.append(b)
493 | 
494 |     def forward(
495 |             self,
496 |             x : torch.Tensor,
497 |             skips : list,
498 |             ) -> torch.Tensor:
499 |         # print(f'decoder: len(skips)={len(skips)}, x.shape={x.shape}')
500 |         for i in range(len(self.decBlocks) - 1, -1, -1):
501 |             if self.skip_first:
502 |                 x = torch.cat([x, skips[i+1]], dim=1)
503 |                 x = self.decBlocks[i](x)
504 |                 x = self.upBlocks[i](x)
505 |             else:
506 |                 x = self.upBlocks[i](x)
507 |                 # print(f'after up: {x.shape}')
508 |                 x = torch.cat([x, skips[i+1]], dim=1)
509 |                 x = self.decBlocks[i](x)
510 |             # print(f'after decoder i={i}, x: {x.shape}')
511 |         # print(f'last step decoder cat {x.shape}, {skips[0].shape}')
512 |         x = torch.cat([x, skips[0]], dim=1)
513 |         return x
514 | 
515 |     def get_features(
516 |             self,
517 |             x : torch.Tensor,
518 |             skips : list,
519 |             ) -> tuple[torch.Tensor, list]:
520 |         features = []
521 |         for i in range(len(self.decBlocks) - 1, -1, -1):
522 |             x = torch.cat([x, skips[i+1]], dim=1)
523 |             x = self.decBlocks[i](x)
524 |             features.append(x.detach().clone())
525 |         x = torch.cat([x, skips[0]], dim=1)
526 |         return x, features
527 | 
528 | def Block(
529 |         mods_args : list | tuple | None,
530 |         in_ch : int | None = None,
531 |         out_ch : int | None = None,
532 |         **kwargs
533 |         ) -> tuple[nn.Module, int]:
534 |     '''
535 |     construct a block from a tuple of an nn.Module (or a function returning an nn.Module), its arguments, and optionally the number of repetitions of the block
536 |     returns the factor by which the block changes the spatial size as well (needed for ViT)
537 |     alternatively, give a list of such tuples, then Block will call itself recursively and construct a sequential module
538 |     in_ch is only applied to the first block in the list, afterwards we always use out_ch; this has to be specified for layers requiring in_ch
539 |     **kwargs can be used to overwrite values in arguments, this is useful for specific layers (e.g.
inBlock, bottleneck) that should otherwise use the basic arguments given inside the tuple
540 |     '''
541 |     if mods_args is None:
542 |         return None, 1
543 |     elif isinstance(mods_args, list):
544 |         b, factor = Block(mods_args=mods_args[0], in_ch=in_ch, out_ch=out_ch, **kwargs)
545 |         modList = nn.Sequential(b)
546 |         for i in range(1, len(mods_args)):
547 |             b, f = Block(mods_args=mods_args[i], in_ch=out_ch, out_ch=out_ch, **kwargs)
548 |             factor *= f
549 |             if b is not None:
550 |                 modList.append(b)
551 |         return modList, factor
552 |     else:
553 |         modName, arguments = mods_args
554 |         ### check the special case of 0 repetitions for nConvBlocks, Transformer layers and so on; arguments is a dict, so test for the key (hasattr would always be False here)
555 |         if arguments.get('num_layers', None) == 0:
556 |             return None, 1
557 |         ### calculate how the spatial size changes for ViT
558 |         if 'down' in modName or 'Pool' in modName or 'down' in arguments.values():
559 |             if 'kernel_size' in arguments:
560 |                 factor = 1 / arguments['kernel_size']
561 |             else:
562 |                 factor = 1 / 2
563 |             ### remove the standard kernel size, stride, dilation from the parameters; these should only be applied to other layers and explicitly given for up/down scaling layers if wanted
564 |             for arg in ['kernel_size', 'stride', 'dilation']:
565 |                 if arg in kwargs.keys():
566 |                     del kwargs[arg]
567 |         elif 'up' in modName or 'Unpool' in modName or 'up' in arguments.values():
568 |             if 'kernel_size' in arguments:
569 |                 factor = arguments['kernel_size']
570 |             else:
571 |                 factor = 2
572 |             for arg in ['kernel_size', 'stride', 'dilation']:
573 |                 if arg in kwargs.keys():
574 |                     del kwargs[arg]
575 |         else:
576 |             factor = 1
577 |         module = get_layer_dict()[modName](in_ch=in_ch, out_ch=out_ch, **{**kwargs, **arguments})
578 | 
579 |     return module, factor
580 | 
581 | #### metrics
582 | 
583 | class MSE_img(nn.Module):
584 |     '''MSE with "proper" reduction for images (always average over spatial dimensions, 'reduction' argument for batches and channels)'''
585 |     def __init__(self, reduction : str = 'mean') -> None:
586 |         super().__init__()
587 |         self.reduction = reduction
588 | 
589 |     def forward(self, prediction : torch.Tensor, target : torch.Tensor):
590 |         loss = torch.mean((prediction - target)**2, dim=(-1,-2))
591 |         if self.reduction == 'mean':
592 |             return torch.mean(loss)
593 |         elif self.reduction == 'none':
594 |             ### remove the potentially useless channel dimension
595 |             return torch.squeeze(loss)
596 |         elif self.reduction == 'sum':
597 |             return torch.sum(loss)
598 |         else:
599 |             raise NotImplementedError(f'reduction={self.reduction}')
600 | 
601 | class L1_img(nn.Module):
602 |     '''L1 with "proper" reduction for images (always average over spatial dimensions, 'reduction' argument for batches and channels)'''
603 |     def __init__(self, reduction : str = 'mean') -> None:
604 |         super().__init__()
605 |         self.reduction = reduction
606 | 
607 |     def forward(self, prediction : torch.Tensor, target : torch.Tensor):
608 |         loss = torch.mean(torch.abs(prediction - target), dim=(-1,-2))
609 |         if self.reduction == 'mean':
610 |             return torch.mean(loss)
611 |         elif self.reduction == 'none':
612 |             ### remove the potentially useless channel dimension
613 |             return torch.squeeze(loss)
614 |         elif self.reduction == 'sum':
615 |             return torch.sum(loss)
616 |         else:
617 |             raise NotImplementedError(f'reduction={self.reduction}')
618 | 
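### --- usage sketch (illustrative, not part of the original module) ---
### with reduction='none' the metrics return one value per sample, which is how
### per-sample test losses (as in errors_per_sample.csv) can be produced.
def _demo_img_metrics():
    pred, target = torch.rand(4, 1, 256, 256), torch.rand(4, 1, 256, 256)
    per_sample = MSE_img(reduction='none')(pred, target)   # shape (4,)
    batch_mean = L1_img(reduction='mean')(pred, target)    # scalar
    return per_sample, batch_mean

619 | '''
620 | Mostly copied from TransUNet, with minor changes
621 | https://arxiv.org/abs/2102.04306
622 | https://github.com/Beckschen/TransUNet
623 | 
624 |                                  Apache License
625 |                            Version 2.0, January 2004
626 |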
http://www.apache.org/licenses/ 627 | 628 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 629 | 630 | 1. Definitions. 631 | 632 | "License" shall mean the terms and conditions for use, reproduction, 633 | and distribution as defined by Sections 1 through 9 of this document. 634 | 635 | "Licensor" shall mean the copyright owner or entity authorized by 636 | the copyright owner that is granting the License. 637 | 638 | "Legal Entity" shall mean the union of the acting entity and all 639 | other entities that control, are controlled by, or are under common 640 | control with that entity. For the purposes of this definition, 641 | "control" means (i) the power, direct or indirect, to cause the 642 | direction or management of such entity, whether by contract or 643 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 644 | outstanding shares, or (iii) beneficial ownership of such entity. 645 | 646 | "You" (or "Your") shall mean an individual or Legal Entity 647 | exercising permissions granted by this License. 648 | 649 | "Source" form shall mean the preferred form for making modifications, 650 | including but not limited to software source code, documentation 651 | source, and configuration files. 652 | 653 | "Object" form shall mean any form resulting from mechanical 654 | transformation or translation of a Source form, including but 655 | not limited to compiled object code, generated documentation, 656 | and conversions to other media types. 657 | 658 | "Work" shall mean the work of authorship, whether in Source or 659 | Object form, made available under the License, as indicated by a 660 | copyright notice that is included in or attached to the work 661 | (an example is provided in the Appendix below). 662 | 663 | "Derivative Works" shall mean any work, whether in Source or Object 664 | form, that is based on (or derived from) the Work and for which the 665 | editorial revisions, annotations, elaborations, or other modifications 666 | represent, as a whole, an original work of authorship. For the purposes 667 | of this License, Derivative Works shall not include works that remain 668 | separable from, or merely link (or bind by name) to the interfaces of, 669 | the Work and Derivative Works thereof. 670 | 671 | "Contribution" shall mean any work of authorship, including 672 | the original version of the Work and any modifications or additions 673 | to that Work or Derivative Works thereof, that is intentionally 674 | submitted to Licensor for inclusion in the Work by the copyright owner 675 | or by an individual or Legal Entity authorized to submit on behalf of 676 | the copyright owner. For the purposes of this definition, "submitted" 677 | means any form of electronic, verbal, or written communication sent 678 | to the Licensor or its representatives, including but not limited to 679 | communication on electronic mailing lists, source code control systems, 680 | and issue tracking systems that are managed by, or on behalf of, the 681 | Licensor for the purpose of discussing and improving the Work, but 682 | excluding communication that is conspicuously marked or otherwise 683 | designated in writing by the copyright owner as "Not a Contribution." 684 | 685 | "Contributor" shall mean Licensor and any individual or Legal Entity 686 | on behalf of whom a Contribution has been received by Licensor and 687 | subsequently incorporated within the Work. 688 | 689 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 690 | this License, each Contributor hereby grants to You a perpetual, 691 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 692 | copyright license to reproduce, prepare Derivative Works of, 693 | publicly display, publicly perform, sublicense, and distribute the 694 | Work and such Derivative Works in Source or Object form. 695 | 696 | 3. Grant of Patent License. Subject to the terms and conditions of 697 | this License, each Contributor hereby grants to You a perpetual, 698 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 699 | (except as stated in this section) patent license to make, have made, 700 | use, offer to sell, sell, import, and otherwise transfer the Work, 701 | where such license applies only to those patent claims licensable 702 | by such Contributor that are necessarily infringed by their 703 | Contribution(s) alone or by combination of their Contribution(s) 704 | with the Work to which such Contribution(s) was submitted. If You 705 | institute patent litigation against any entity (including a 706 | cross-claim or counterclaim in a lawsuit) alleging that the Work 707 | or a Contribution incorporated within the Work constitutes direct 708 | or contributory patent infringement, then any patent licenses 709 | granted to You under this License for that Work shall terminate 710 | as of the date such litigation is filed. 711 | 712 | 4. Redistribution. You may reproduce and distribute copies of the 713 | Work or Derivative Works thereof in any medium, with or without 714 | modifications, and in Source or Object form, provided that You 715 | meet the following conditions: 716 | 717 | (a) You must give any other recipients of the Work or 718 | Derivative Works a copy of this License; and 719 | 720 | (b) You must cause any modified files to carry prominent notices 721 | stating that You changed the files; and 722 | 723 | (c) You must retain, in the Source form of any Derivative Works 724 | that You distribute, all copyright, patent, trademark, and 725 | attribution notices from the Source form of the Work, 726 | excluding those notices that do not pertain to any part of 727 | the Derivative Works; and 728 | 729 | (d) If the Work includes a "NOTICE" text file as part of its 730 | distribution, then any Derivative Works that You distribute must 731 | include a readable copy of the attribution notices contained 732 | within such NOTICE file, excluding those notices that do not 733 | pertain to any part of the Derivative Works, in at least one 734 | of the following places: within a NOTICE text file distributed 735 | as part of the Derivative Works; within the Source form or 736 | documentation, if provided along with the Derivative Works; or, 737 | within a display generated by the Derivative Works, if and 738 | wherever such third-party notices normally appear. The contents 739 | of the NOTICE file are for informational purposes only and 740 | do not modify the License. You may add Your own attribution 741 | notices within Derivative Works that You distribute, alongside 742 | or as an addendum to the NOTICE text from the Work, provided 743 | that such additional attribution notices cannot be construed 744 | as modifying the License. 
745 | 746 | You may add Your own copyright statement to Your modifications and 747 | may provide additional or different license terms and conditions 748 | for use, reproduction, or distribution of Your modifications, or 749 | for any such Derivative Works as a whole, provided Your use, 750 | reproduction, and distribution of the Work otherwise complies with 751 | the conditions stated in this License. 752 | 753 | 5. Submission of Contributions. Unless You explicitly state otherwise, 754 | any Contribution intentionally submitted for inclusion in the Work 755 | by You to the Licensor shall be under the terms and conditions of 756 | this License, without any additional terms or conditions. 757 | Notwithstanding the above, nothing herein shall supersede or modify 758 | the terms of any separate license agreement you may have executed 759 | with Licensor regarding such Contributions. 760 | 761 | 6. Trademarks. This License does not grant permission to use the trade 762 | names, trademarks, service marks, or product names of the Licensor, 763 | except as required for reasonable and customary use in describing the 764 | origin of the Work and reproducing the content of the NOTICE file. 765 | 766 | 7. Disclaimer of Warranty. Unless required by applicable law or 767 | agreed to in writing, Licensor provides the Work (and each 768 | Contributor provides its Contributions) on an "AS IS" BASIS, 769 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 770 | implied, including, without limitation, any warranties or conditions 771 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 772 | PARTICULAR PURPOSE. You are solely responsible for determining the 773 | appropriateness of using or redistributing the Work and assume any 774 | risks associated with Your exercise of permissions under this License. 775 | 776 | 8. Limitation of Liability. In no event and under no legal theory, 777 | whether in tort (including negligence), contract, or otherwise, 778 | unless required by applicable law (such as deliberate and grossly 779 | negligent acts) or agreed to in writing, shall any Contributor be 780 | liable to You for damages, including any direct, indirect, special, 781 | incidental, or consequential damages of any character arising as a 782 | result of this License or out of the use or inability to use the 783 | Work (including but not limited to damages for loss of goodwill, 784 | work stoppage, computer failure or malfunction, or any and all 785 | other commercial damages or losses), even if such Contributor 786 | has been advised of the possibility of such damages. 787 | 788 | 9. Accepting Warranty or Additional Liability. While redistributing 789 | the Work or Derivative Works thereof, You may choose to offer, 790 | and charge a fee for, acceptance of support, warranty, indemnity, 791 | or other liability obligations and/or rights consistent with this 792 | License. However, in accepting such obligations, You may act only 793 | on Your own behalf and on Your sole responsibility, not on behalf 794 | of any other Contributor, and only if You agree to indemnify, 795 | defend, and hold each Contributor harmless for any liability 796 | incurred by, or claims asserted against, such Contributor by reason 797 | of your accepting any such warranty or additional liability. 798 | 799 | END OF TERMS AND CONDITIONS 800 | 801 | APPENDIX: How to apply the Apache License to your work. 
802 | 
803 | To apply the Apache License to your work, attach the following
804 | boilerplate notice, with the fields enclosed by brackets "[]"
805 | replaced with your own identifying information. (Don't include
806 | the brackets!) The text should be enclosed in the appropriate
807 | comment syntax for the file format. We also recommend that a
808 | file or class name and description of purpose be included on the
809 | same "printed page" as the copyright notice for easier
810 | identification within third-party archives.
811 | 
812 | Copyright [yyyy] [name of copyright owner]
813 | 
814 | Licensed under the Apache License, Version 2.0 (the "License");
815 | you may not use this file except in compliance with the License.
816 | You may obtain a copy of the License at
817 | 
818 | http://www.apache.org/licenses/LICENSE-2.0
819 | 
820 | Unless required by applicable law or agreed to in writing, software
821 | distributed under the License is distributed on an "AS IS" BASIS,
822 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
823 | See the License for the specific language governing permissions and
824 | limitations under the License.
825 | 
826 | '''
827 | 
828 | class Embedder(nn.Module):
829 |     def __init__(self, in_ch, img_size, hidden_size, grid_size=16, positional_embedding="zero", dropout=0):
830 |         ### adapted from TransUNet
831 |         ### img_size: actual spatial size at the current stage; update when creating the blocks to account for downsampling before the transformer block
832 |         ### only for square images so far
833 |         super().__init__()
834 |         ### calculate shapes; clamp grid_size first, so the divisibility check uses the effective value
835 |         if img_size < grid_size:
836 |             print(f"\n\nViT embedder got grid_size={grid_size}, but img_size={img_size}. Reducing grid_size to img_size.")
837 |             grid_size = img_size
838 |         assert img_size%grid_size==0, f"Embedder: img_size {img_size} must be divisible by grid_size {grid_size}"
839 |         patch_size = int(img_size // grid_size)
840 |         n_patches = grid_size**2
841 | 
842 |         self.repr = f"embedder, in_ch={in_ch}, img_size={img_size}, hidden_size={hidden_size}, grid_size={grid_size}"
843 |         ### define layers
844 |         ### TBD: change this for lower layers, e.g. when kernel_size > 3, use >= two convs instead
845 |         self.patch_embeddings = nn.Conv2d( in_ch,
846 |                                            out_channels=hidden_size,
847 |                                            kernel_size=patch_size,
848 |                                            stride=patch_size)
849 | 
850 |         if positional_embedding=="zero":
851 |             ### more TBD
852 |             ### can we model spatial relations here by a graph and have a small graph CNN learn position embeddings?
853 |             self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, hidden_size))
854 |         elif positional_embedding is None:
855 |             self.register_buffer("position_embeddings", torch.zeros(1, n_patches, hidden_size), persistent=False)  # buffer instead of a bare tensor, so it follows the module to the right device; non-persistent to keep checkpoints unchanged
856 |         # elif positional_embedding=="coords":
857 |             ### can we use coordinates here in any way?
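            ### one hedged idea, not implemented here: project normalized (x, y) patch-center
            ### coordinates through a small nn.Linear(2, hidden_size) and use the result as the
            ### position embedding; Tx-relative coordinates would keep this consistent with the
            ### coordinate inputs built in lib/utils_coords.py. Left as a sketch/note only.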
858 |         else:
859 |             raise NotImplementedError("other position embeddings TBD")
860 |         self.do = nn.Dropout(dropout)
861 |         # print(f"embedder: {img_size}=img_size, hidden_size={hidden_size}, grid_size={grid_size}, calculated patch_size={patch_size}, n_patches={n_patches}")
862 | 
863 |     def forward(self, x):
864 |         ### as in TransUNet
865 |         x = self.patch_embeddings(x)
866 |         x = x.flatten(2)
867 |         x = x.transpose(-1, -2)
868 |         embeddings = x + self.position_embeddings
869 |         embeddings = self.do(embeddings)
870 |         return embeddings
871 | 
872 |     def __repr__(self):
873 |         return self.repr
874 | 
875 | class Mlp(nn.Module):
876 |     def __init__(self, hidden_size, dim_feedforward=None, dropout=0):
877 |         super(Mlp, self).__init__()
878 |         if dim_feedforward is None:
879 |             dim_feedforward = 4 * hidden_size
880 |         self.fc1 = nn.Linear(hidden_size, dim_feedforward)
881 |         self.fc2 = nn.Linear(dim_feedforward, hidden_size)
882 |         self.act_fn = nn.functional.gelu
883 |         self.dropout = nn.Dropout(dropout)
884 | 
885 |         self._init_weights()
886 | 
887 |         self.repr = f"Mlp, hidden_size={hidden_size}, dim_feedforward={dim_feedforward}"
888 | 
889 |     def _init_weights(self):
890 |         nn.init.xavier_uniform_(self.fc1.weight)
891 |         nn.init.xavier_uniform_(self.fc2.weight)
892 |         nn.init.normal_(self.fc1.bias, std=1e-6)
893 |         nn.init.normal_(self.fc2.bias, std=1e-6)
894 | 
895 |     def forward(self, x):
896 |         x = self.fc1(x)
897 |         x = self.act_fn(x)
898 |         x = self.dropout(x)
899 |         x = self.fc2(x)
900 |         x = self.dropout(x)
901 |         return x
902 | 
903 |     def __repr__(self):
904 |         return self.repr
905 | 
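### the block below implements standard multi-head scaled dot-product attention,
### Attention(Q, K, V) = softmax(Q K^T / sqrt(d_head)) V with d_head = hidden_size / nhead,
### matching the TransUNet implementation this file adapts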
906 | class selfAttentionBlock(nn.Module):
907 |     def __init__(self, hidden_size, nhead=12, dropout=0, bias=True, vis = False):
908 |         super().__init__()
909 |         assert hidden_size % nhead == 0, f'choose hidden_size={hidden_size} divisible by nhead={nhead}'
910 |         self.vis = vis  # store the flag (it was hard-coded to False, so attention weights were never returned)
911 | 
912 |         self.num_attention_heads = nhead
913 |         self.attention_head_size = int(hidden_size / nhead)
914 |         self.all_head_size = nhead * self.attention_head_size
915 | 
916 |         self.query = nn.Linear(hidden_size, self.all_head_size, bias=bias)
917 |         self.key = nn.Linear(hidden_size, self.all_head_size, bias=bias)
918 |         self.value = nn.Linear(hidden_size, self.all_head_size, bias=bias)
919 | 
920 |         self.out = nn.Linear(hidden_size, hidden_size)
921 |         self.attn_dropout = nn.Dropout(dropout)
922 |         self.proj_dropout = nn.Dropout(dropout)
923 | 
928 |         self.softmax = nn.Softmax(dim=-1)
929 | 
930 |         self.repr = f"selfAttentionBlock, hidden_size={hidden_size}, nhead={nhead}, attention_head_size={self.attention_head_size}"
931 | 
932 |     def transpose_for_scores(self, x):
933 |         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
934 |         x = x.view(*new_x_shape)
935 |         return x.permute(0, 2, 1, 3)
936 | 
937 |     def forward(self, hidden_states):
938 |         mixed_query_layer = self.query(hidden_states)
939 |         mixed_key_layer = self.key(hidden_states)
940 |         mixed_value_layer = self.value(hidden_states)
941 | 
942 |         query_layer = self.transpose_for_scores(mixed_query_layer)
943 |         key_layer = self.transpose_for_scores(mixed_key_layer)
944 |         value_layer = self.transpose_for_scores(mixed_value_layer)
945 | 
946 |         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
947 |         attention_scores = attention_scores / np.sqrt(self.attention_head_size)
948 |         attention_probs = self.softmax(attention_scores)
949 |         weights = attention_probs if self.vis else None
950 |         attention_probs = self.attn_dropout(attention_probs)
951 | 
952 |         context_layer = torch.matmul(attention_probs, value_layer)
953 |         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
954 |         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
955 |         context_layer = context_layer.view(*new_context_layer_shape)
956 |         attention_output = self.out(context_layer)
957 |         attention_output = self.proj_dropout(attention_output)
958 |         return attention_output, weights
959 | 
960 |     def __repr__(self):
961 |         return self.repr
962 | 
963 | class TransformerBlock(nn.Module):
964 |     def __init__(self, hidden_size, nhead=12, dropout=0, attn_bias=True, dim_feedforward=None, layer_norm_eps=1e-6, vis=False):
965 |         super().__init__()
966 |         self.attention_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
967 |         self.ffn_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)  # use the configured eps, matching attention_norm, instead of a hard-coded 1e-6
968 |         self.ffn = Mlp(hidden_size, dim_feedforward, dropout)
969 |         self.attn = selfAttentionBlock(hidden_size, nhead, dropout=dropout, bias=attn_bias, vis=vis)
970 | 
971 |         self.repr = f"TransformerBlock, hidden_size={hidden_size}, nhead={nhead}"
972 | 
973 |     def forward(self, x):
974 |         h = x
975 |         x = self.attention_norm(x)
976 |         x, weights = self.attn(x)
977 |         x = x + h
978 | 
979 |         h = x
980 |         x = self.ffn_norm(x)
981 |         x = self.ffn(x)
982 |         x = x + h
983 |         return x, weights
984 | 
985 |     def __repr__(self):
986 |         return self.repr
987 | 
988 | class TransformerEncoder(nn.Module):
989 |     def __init__(self, hidden_size, num_layers=12, nhead=12, dropout=0, attn_bias=False, dim_feedforward=None, layer_norm_eps=1e-6, vis=False):
990 |         super().__init__()
991 |         self.vis = vis
992 |         self.layer = nn.ModuleList()
993 |         self.encoder_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
994 |         for _ in range(num_layers):
995 |             layer = TransformerBlock(hidden_size, nhead, dropout, attn_bias, dim_feedforward, layer_norm_eps, vis)
996 |             self.layer.append(copy.deepcopy(layer))
997 |         self.repr = f"TransformerEncoder, hidden_size={hidden_size}, num_layers={num_layers}, nhead={nhead}"
998 | 
999 |     def forward(self, hidden_states):
1000 |         attn_weights = []
1001 |         for layer_block in self.layer:
1002 |             hidden_states, weights = layer_block(hidden_states)
1003 |             if self.vis:
1004 |                 attn_weights.append(weights)
1005 |         encoded = self.encoder_norm(hidden_states)
1006 |         return encoded, attn_weights
1007 | 
1008 |     def __repr__(self):
1009 |         return self.repr
1010 | 
1011 | class TransformerDecoder(nn.Module):
1012 |     def __init__(self, hidden_size, out_ch, out_size, grid_size=16, up="conv", bn=True) -> None:
1013 |         super().__init__()
1014 |         ### channels go from hidden_size to out_ch
1015 |         ### resolution goes from grid_size to out_size
1016 |         ##########
1017 |         # BN eps???
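        # note: nn.BatchNorm2d defaults to eps=1e-5; if it should match the transformer's
        # layer_norm_eps (1e-6 above), pass it explicitly, e.g. nn.BatchNorm2d(out_ch, eps=1e-6)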
1018 | 
1019 |         ##########
1020 |         self.out = nn.Sequential(
1021 |             nn.Conv2d(hidden_size, out_ch, kernel_size=3, padding=1),
1022 |         )
1023 |         if out_size > grid_size:
1024 |             if up in ["nearest", "bilinear"]:
1025 |                 self.out.append(nn.Upsample(size=out_size, mode=up))
1026 |             elif up=="conv":
1027 |                 assert out_size%grid_size==0, f"in TransformerDecoder grid_size={grid_size} and out_size={out_size} don't work with mode conv"
1028 |                 factor = int(out_size//grid_size)
1029 |                 self.out.append(nn.ConvTranspose2d(out_ch, out_ch, kernel_size=factor, stride=factor))
1030 |                 self.out.append(nn.ReLU())
1031 |                 if bn:
1032 |                     self.out.append(nn.BatchNorm2d(out_ch))
1033 | 
1034 |         self.repr = f"TransformerDecoder, hidden_size={hidden_size}, out_ch={out_ch}, out_size={out_size}, grid_size={grid_size}"
1035 | 
1036 |     def forward(self, hidden_states):
1037 |         B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, hidden, h, w)
1038 |         h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
1039 |         x = hidden_states.permute(0, 2, 1)
1040 |         x = x.contiguous().view(B, hidden, h, w)
1041 | 
1042 |         return self.out(x)
1043 | 
1044 |     def __repr__(self):
1045 |         return self.repr
1046 | 
1047 | class ViT(nn.Module):
1048 |     def __init__(self, in_ch, img_size, hidden_size=768, out_size=None, out_ch=None, grid_size=16, num_layers=12, nhead=12, dropout=0.1, attn_bias=False,\
1049 |                  dim_feedforward=None, layer_norm_eps=1e-6, positional_embedding="zero", vis=False, up_out="conv", res=False, **kwargs) -> None:
1050 |         super().__init__()
1051 |         ### add assertion regarding hidden_size,...
1052 | 
1053 |         if grid_size is None or grid_size > img_size:
1054 |             print(f"got grid_size={grid_size} for img_size={img_size}, adjusting grid_size to img_size")
1055 |             grid_size = int(img_size)
1056 | 
1057 |         if out_size is None:
1058 |             out_size = img_size
1059 |         if res and img_size!=out_size:
1060 |             raise NotImplementedError(f"ViT with img_size={img_size} != out_size={out_size} and res={res}")
1061 |         if out_ch is None:
1062 |             out_ch = in_ch
1063 | 
1064 |         img_size = int(img_size)
1065 | 
1066 |         self.embedder = Embedder(in_ch, img_size, hidden_size, grid_size, positional_embedding, dropout)
1067 |         self.encoder = TransformerEncoder(hidden_size, num_layers=num_layers, nhead=nhead, dropout=dropout, attn_bias=attn_bias, dim_feedforward=dim_feedforward, layer_norm_eps=layer_norm_eps, vis=vis)
1068 |         self.out = TransformerDecoder(hidden_size, out_ch, out_size, grid_size, up=up_out)
1069 | 
1070 |         self.res = res
1071 | 
1072 |         self.repr = f"ViT, in_ch={in_ch}, out_ch={out_ch}, img_size={img_size}, hidden_size={hidden_size}, out_size={out_size}, grid_size={grid_size}, num_layers={num_layers}, nhead={nhead}"
1073 | 
1074 |     def forward(self, x):
1075 |         residual = x
1076 |         #print(f"ViT got input {x.shape}")
1077 |         x = self.embedder(x)
1078 |         #print(f"from embedder {x.shape}")
1079 |         x, attn_weights = self.encoder(x)
1080 |         #print(f"from transformer encoder {x.shape}")
1081 |         x = self.out(x)
1082 | 
1083 |         if self.res:
1084 |             x = x + residual
1085 |         return x
1086 | 
1087 |     def __repr__(self):
1088 |         return self.repr
1089 | 
1090 | ### used to assign layers in Block
1091 | def get_layer_dict():
1092 |     return {
1093 |         'get_padding' : get_padding,
1094 |         'nConvBlocks' : nConvBlocks,
1095 |         'convBlock' : convBlock,
1096 |         'resNeXtBlock' : resNeXtBlock,
1097 |         'dilationBlock' : dilationBlock,
1098 |         'down' : down,
1099 |         'up' : up,
1100 |         'DCNBlock' : DCNBlock,
1101 |         'nDCNBlocks' : nDCNBlocks,
1102 |         'Encoder' : Encoder,
1103 |         'Decoder' : Decoder,
1104 |         #### torch
1105 | 
'Identity' : nn.Identity, 1106 | #### experimental 1107 | 'ViT' : ViT, 1108 | } 1109 | -------------------------------------------------------------------------------- /lib/utils_coords.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file contains several functions to produce tensors containing coordinates and encode antenna gain as inputs to the CNN models. 3 | ''' 4 | import torch 5 | import warnings 6 | import numpy as np 7 | 8 | xy_range = torch.arange(256, dtype=torch.float32) 9 | xc_base = torch.repeat_interleave(xy_range.reshape((-1, 1)), repeats=256, dim=1) 10 | yc_base = torch.repeat_interleave(xy_range.reshape((1, -1)), repeats=256, dim=0) 11 | 12 | def GA_coords(tx_coords : tuple | list): 13 | ''' 14 | out: xt, yt, zt (coordinates of Tx, constant tensors), xs, ys (coordinates of spatial positions) 15 | 16 | grid anchor from RadioTrans paper (https://ieeexplore.ieee.org/document/9753644) 17 | ''' 18 | xt = tx_coords[0] * torch.ones((256, 256)) 19 | yt = tx_coords[1] * torch.ones((256, 256)) 20 | zt = tx_coords[2] * torch.ones((256, 256)) 21 | 22 | return xt, yt, zt, xc_base, yc_base 23 | 24 | def euclidian_coords(nbuild : torch.Tensor, nveg : torch.Tensor, tx_coords : tuple | list): 25 | ''' 26 | out: cartesian coordinates of top of building/veg/floor in each 2D location, cut off at tx-height, relative to tx position 27 | 28 | xc, yc, zc_build, zc_veg, zc_floor 29 | ''' 30 | xt, yt, zt = tx_coords 31 | xc = xc_base - xt 32 | yc = yc_base - yt 33 | 34 | zc_build = torch.minimum(torch.zeros_like(nbuild), nbuild - zt) 35 | zc_veg = torch.minimum(torch.zeros_like(nbuild), nveg - zt) 36 | zc_floor = -zt * torch.ones_like(nbuild) 37 | 38 | return xc, yc, zc_build, zc_veg, zc_floor 39 | 40 | def cylindrical_coords(nbuild : torch.Tensor, nveg : torch.Tensor, tx_coords : tuple | list, phi_base : float): 41 | ''' 42 | out: cylindrical coordinates of top of building/veg/floor in each 2D location, cut off at tx-height, relative to tx position and orientation 43 | 44 | dist2d, phi, zc_build, zc_veg, zc_floor 45 | ''' 46 | xc, yc, zc_build, zc_veg, zc_floor = euclidian_coords(nbuild, nveg, tx_coords) 47 | 48 | phi = torch.arctan2(yc, xc) 49 | ### rotate 50 | phi = phi - phi_base 51 | ### correct 52 | phi = torch.where(phi < -1 * torch.pi, 2 * torch.pi + phi, phi) 53 | phi = torch.where(phi > torch.pi, phi - 2 * torch.pi, phi) 54 | 55 | r = torch.sqrt(xc**2 + yc**2) 56 | 57 | return r, phi, zc_build, zc_veg, zc_floor 58 | 59 | def spherical_coords(nbuild : torch.Tensor, nveg : torch.Tensor, tx_coords : tuple | list, phi_base : float, theta_base : float): 60 | '''out: spherical coordinates of top of building/veg/floor in each 2D location, cut off at tx-height, relative to tx position and orientation 61 | 62 | dist_3d_build, dist_3d_veg, dist_3d_floor, theta_build, theta_veg, theta_floor, phi''' 63 | assert theta_base >= 0 and theta_base < torch.pi, f"check phi_base={phi_base}, theta_base={theta_base}" 64 | ### REMOVE LATER 65 | 66 | if phi_base < -torch.pi or phi_base > torch.pi: 67 | # print(f'correcting phi={phi_base} to {(phi_base+torch.pi)%(2*torch.pi) -torch.pi}') 68 | phi_base = (phi_base + torch.pi)%(2 * torch.pi) - torch.pi 69 | 70 | xc, yc, zc_build, zc_veg, zc_floor = euclidian_coords(nbuild, nveg, tx_coords) 71 | 72 | dist_3d_build = torch.sqrt(xc**2 + yc**2 + zc_build**2) 73 | dist_3d_veg = torch.sqrt(xc**2 + yc**2 + zc_veg**2) 74 | dist_3d_floor = torch.sqrt(xc**2 + yc**2 + zc_floor**2) 75 | 76 | with warnings.catch_warnings(): 
77 |         warnings.simplefilter("ignore")
78 |         theta_build = torch.where(dist_3d_build==0, torch.pi, torch.arccos(zc_build / dist_3d_build))
79 |         theta_veg = torch.where(dist_3d_veg==0, torch.pi, torch.arccos(zc_veg / dist_3d_veg))
80 |         theta_floor = torch.where(dist_3d_floor==0, torch.pi, torch.arccos(zc_floor / dist_3d_floor))
81 | 
82 |     phi = torch.arctan2(yc, xc)
83 |     ### rotate
84 |     phi = phi - phi_base
85 |     theta_build = theta_build - theta_base + torch.pi/2
86 |     theta_veg = theta_veg - theta_base + torch.pi/2
87 |     theta_floor = theta_floor - theta_base + torch.pi/2
88 |     ### correct
89 |     phi = torch.where(phi < -1 * torch.pi, 2 * torch.pi + phi, phi)
90 |     phi = torch.where(phi > torch.pi, phi - 2 * torch.pi, phi)
91 | 
92 |     return dist_3d_build, dist_3d_veg, dist_3d_floor, theta_build, theta_veg, theta_floor, phi
93 | 
94 | def spherical_slices(tx_coords : tuple | list, theta_base : float, z_max : int = 32, z_step : int = 2):
95 |     '''
96 |     generates thetas, distances of spherical coordinates of the 3D positions (not buildings/veg) on the 2D grid and at heights 0, 0+z_step, 0+2*z_step,...z_max-z_step,
97 |     for generating gain in slices
98 | 
99 |     out: theta, dist 3D tensors
100 |     '''
101 |     xt, yt, zt = tx_coords
102 |     xc = (torch.unsqueeze(xc_base, dim=0) - xt).repeat(z_max//z_step, 1, 1)
103 |     yc = (torch.unsqueeze(yc_base, dim=0) - yt).repeat(z_max//z_step, 1, 1)
104 |     zc = (torch.arange(start=0, end=z_max, step=z_step, dtype=torch.float32).reshape((-1,1,1)) - zt).repeat(1, 256, 256)
105 | 
106 |     r = torch.sqrt(xc**2 + yc**2 + zc**2)
107 | 
108 |     with warnings.catch_warnings():
109 |         warnings.simplefilter("ignore")
110 |         theta = torch.where(r==0, torch.pi, torch.arccos(torch.maximum(zc / r, torch.tensor(-1))))
111 | 
112 |     ### rotate
113 |     theta = theta - theta_base + torch.pi/2
114 | 
115 |     return theta, r
116 | 
117 | def dist_2d(tx_coords : tuple | list):
118 |     '''
119 |     out: tensor containing 2D distance of each spatial location to Tx location (only in the x-y plane)
120 |     '''
121 |     xt, yt = tx_coords[0], tx_coords[1]
122 |     xc = xc_base - xt
123 |     yc = yc_base - yt
124 |     dist2d = torch.sqrt(xc**2 + yc**2)
125 |     return dist2d
126 | 
127 | def dist_3d(tx_coords : tuple | list):
128 |     '''
129 |     out: tensor containing 3D distance of Rx at 1.5m height (Rx in our dataset) to Tx
130 |     '''
131 |     xt, yt, zt = tx_coords[0], tx_coords[1], tx_coords[2]
132 |     xc = xc_base - xt
133 |     yc = yc_base - yt
134 |     dist3d = torch.sqrt(xc**2 + yc**2 + (zt - 1.5)**2)
135 |     return dist3d
136 | 
137 | def get_heights(theta : torch.Tensor, dist_2d : torch.Tensor, tx_z : float, theta_base : float, z_max : int = 32):
138 |     '''
139 |     takes theta, dist_2d and calculates back to the height value in each pixel location
140 |     z_max specifies the value for locations with theta<=-3 (assigned to pixels which are behind Tx by our LoS-algorithm)
141 | 
142 |     out: height values corresponding to theta in each spatial location
143 |     '''
144 |     with warnings.catch_warnings():
145 |         warnings.simplefilter("ignore")
146 |         heights = torch.minimum(tx_z - dist_2d / torch.tan(3 / 2 * torch.pi - theta - theta_base), torch.tensor(z_max, dtype=torch.float32))
147 |     ### set values
148 |     heights[(theta<=-3)] = z_max
149 |     heights[dist_2d==0] = tx_z
150 | 
151 |     return heights
152 | 
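### hedged usage sketch (illustration only, not called anywhere in this file); the Tx position,
### the orientation and the 180x360 dB pattern `gain_array` below are made-up example values:
# tx = (128., 128., 25.)
# theta, r = spherical_slices(tx, theta_base=torch.pi/2)            # angles/distances per height slice
# heights = get_heights(theta[0], dist_2d(tx), tx_z=tx[2], theta_base=torch.pi/2)
# gain = project_gain(phi, theta[0], gain_array)                    # phi e.g. from cylindrical_coords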
153 | def project_gain(phi : torch.Tensor, theta : torch.Tensor, gain_array : np.ndarray, theta_max : torch.Tensor | None = None, normalize : bool = False, dist : torch.Tensor | None = None, z_max : int = 0):
154 |     ''' "draw" antenna gain on the map according to antenna pattern and angles
155 |     antenna gain is expected to be in the form theta x phi, angles in int steps 0,...,179 x 0,...,359
156 |     theta_max can be given optionally, to include visibility (0 gain if LoS to the pixel is obstructed)
157 |     slices at different heights can be taken by varying theta
158 |     if dist is given, we subtract 20*log_10(dist), according to free space path loss
159 |     values are between -250 (no gain) and 0 without normalization, or shifted and scaled to [0,1]
160 |     '''
161 |     assert gain_array[0,0]==-250.0, f'gain_array[0,0]={gain_array[0,0]}'
162 |     ### set everything behind Tx to -250dB
163 |     behind = (theta <= 0) | (theta>=torch.pi) | (phi >= torch.pi/2) | (phi <= -torch.pi/2)
164 | 
165 |     ### if theta_max is given, also set all pixels in shadows to (0,0), i.e. -250 dB
166 |     if theta_max is not None:
167 |         shadow = (theta_max + 1e-4 < theta)
168 |     else:
169 |         shadow = torch.full(theta.shape, False, dtype=torch.bool)
170 | 
171 |     phi_deg = torch.where(behind | shadow, 0, torch.floor(((phi / (2*torch.pi) * 360)))).type(torch.int)%360
172 |     theta_deg = torch.where(behind | shadow, 0, torch.floor(((theta / (2*torch.pi) * 360)))).type(torch.int)%360
173 | 
174 |     assert ((phi_deg >= 0) & (phi_deg < 360) & (theta_deg >= 0) & (theta_deg < 180)).all(), f'torch.amin(phi_deg)={torch.amin(phi_deg)}, torch.amax(phi_deg)={torch.amax(phi_deg)}, torch.amin(theta_deg)={torch.amin(theta_deg)}, torch.amax(theta_deg)={torch.amax(theta_deg)}'
175 | 
176 |     ### torch doesn't offer a ravel_multi_index function unfortunately
177 |     gain_proj = torch.tensor(np.take(gain_array[:180,:360], np.ravel_multi_index((theta_deg.numpy(), phi_deg.numpy()), (180, 360))), dtype=torch.float32)
178 |     if dist is not None:
179 |         ### correct in Tx position to avoid nans (no Rx here anyways)
180 |         ### keep minimal power -250dB
181 |         dist[dist==0] = 1
182 |         gain_proj = torch.max(gain_proj - 20 * torch.log10(dist), -250 * torch.ones_like(gain_proj))
183 |         if normalize:
184 |             gain_proj = (gain_proj + 250) / 250
185 |             # gain_proj = (gain_proj + 250 + 20 * np.log10(np.sqrt(2 * 255**2 + z_max**2))) / (250 + 20 * np.log10(np.sqrt(2 * 255**2 + z_max**2)))
186 |     elif normalize:
187 |         gain_proj = (gain_proj + 250) / 250
188 | 
189 |     return gain_proj
--------------------------------------------------------------------------------
/main_cli.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.cli import LightningCLI, SaveConfigCallback, ArgsType
2 | import lib.pl_datamodule # noqa: F401
3 | import lib.pl_lightningmodule # noqa: F401
4 | from lib import pl_callbacks
5 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, TQDMProgressBar
6 | import datetime
7 | import warnings
8 | warnings.filterwarnings('ignore', message='.*shuffling enabled', )
9 | warnings.filterwarnings('ignore', message='.*You have overridden', )
10 | warnings.filterwarnings('ignore', message='.*number of training batches')
11 | 
12 | class CLICustom(LightningCLI):
13 |     '''
14 |     LightningCLI for training any model (potentially with any dataset).
15 |     By default, results will be saved to ./logs/DATETIME/. TensorBoard logs are in a subdirectory showing the model name and inputs, checkpoints in a separate subdirectory.
16 | ''' 17 | def add_arguments_to_parser(self, parser) -> None: 18 | 19 | ### define callbacks we usually use with defaults, can still be changed from CLI 20 | parser.add_lightning_class_args(EarlyStopping, 'cb_early_stopping') 21 | parser.add_lightning_class_args(ModelCheckpoint, 'cb_ckpt_best') 22 | parser.add_lightning_class_args(ModelCheckpoint, 'cb_ckpt_last') 23 | parser.add_lightning_class_args(pl_callbacks.SummaryCustom, 'cb_summary') 24 | parser.add_lightning_class_args(pl_callbacks.AdjustThreads, 'cb_threads') 25 | parser.add_lightning_class_args(pl_callbacks.Precision, 'cb_precision') 26 | parser.add_lightning_class_args(TQDMProgressBar, 'cb_progbar') 27 | 28 | ### set default arguments for organization of logs, checkpoints, callbacks 29 | parser.set_defaults( 30 | { 31 | 'model' : 'LitUNetDCN', 32 | 'data' : 'LitRM_directional', 33 | 'trainer.logger' : { 34 | 'class_path' : 'pytorch_lightning.loggers.TensorBoardLogger', 35 | 'init_args' : { 36 | 'version' : '', 37 | 'sub_dir' : '' 38 | } 39 | }, 40 | 'trainer.default_root_dir' : f'./logs/{datetime.datetime.now().strftime("%y%m%d-%H%M%S")}', 41 | 'cb_early_stopping.monitor' : "loss_val_avg", 42 | 'cb_early_stopping.min_delta' : 1e-4, 43 | 'cb_early_stopping.patience' : 10, 44 | 'cb_early_stopping.check_on_train_epoch_end' : False, 45 | 'cb_early_stopping.verbose' : True, 46 | 'cb_early_stopping.check_finite' : False, 47 | 'cb_ckpt_best.filename' : "ep_{epoch}_loss_{loss_val_avg:.5f}", 48 | 'cb_ckpt_best.save_top_k' : 1, 49 | 'cb_ckpt_best.monitor' : "loss_val_avg", 50 | 'cb_ckpt_best.save_on_train_epoch_end' : False, 51 | 'cb_ckpt_last.filename' : "ep_{epoch}_step_{step}", 52 | 'cb_ckpt_last.save_top_k' : 1, 53 | 'cb_ckpt_last.train_time_interval' : datetime.timedelta(minutes=30), 54 | 'cb_progbar.refresh_rate' : 10, 55 | 56 | } 57 | ) 58 | ### link some values between data and model 59 | parser.link_arguments('data.PL_scale', 'model.init_args.PL_scale', apply_on='instantiate') 60 | parser.link_arguments('data.in_ch', 'model.init_args.in_ch', apply_on='instantiate') 61 | ### link the used inputs and the model name to the directory for the logger, checkpoints and so on 62 | parser.link_arguments('trainer.default_root_dir', 'trainer.logger.init_args.save_dir') 63 | parser.link_arguments(('model.name', 'data.name'), 'trainer.logger.init_args.name', compute_fn=lambda w, x: f"{w}_{x}", apply_on='instantiate') 64 | 65 | def main(args: ArgsType = None): 66 | CLICustom( 67 | model_class = lib.pl_lightningmodule.LitCNN, 68 | datamodule_class = lib.pl_datamodule.LitRM, 69 | subclass_mode_model = True, 70 | subclass_mode_data = True, 71 | seed_everything_default = 123, 72 | trainer_defaults = { 73 | 'max_epochs' : 120, 74 | 'devices' : [0], 75 | 'default_root_dir' : '.', 76 | 'precision' : '16-mixed', 77 | 'num_sanity_val_steps' : '0', 78 | 'enable_model_summary' : False, 79 | 'deterministic' : False, 80 | 'benchmark' : True, 81 | }, 82 | args = args 83 | ) 84 | 85 | if __name__ == '__main__': 86 | main() -------------------------------------------------------------------------------- /sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fabja19/RML/aaa88503fcdd0810ac4544b942fee2548d01f866/sample.png --------------------------------------------------------------------------------