├── fig ├── demo.png ├── model1.png ├── model2.png ├── transition1.png └── transition2.png ├── .gitignore ├── README.md ├── vae_generate_lidar.ipynb └── cGAN_generate_lidar.ipynb /fig/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangjuite/radar-navigation/HEAD/fig/demo.png -------------------------------------------------------------------------------- /fig/model1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangjuite/radar-navigation/HEAD/fig/model1.png -------------------------------------------------------------------------------- /fig/model2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangjuite/radar-navigation/HEAD/fig/model2.png -------------------------------------------------------------------------------- /fig/transition1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangjuite/radar-navigation/HEAD/fig/transition1.png -------------------------------------------------------------------------------- /fig/transition2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangjuite/radar-navigation/HEAD/fig/transition2.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | devel/ 3 | result* 4 | *.egg-info 5 | *.pyc 6 | .catkin_workspace 7 | log 8 | .vscode/ 9 | procman/bot2-procman/lcmtypes/c/* 10 | procman/bot2-procman/lcmtypes/cpp/* 11 | procman/bot2-procman/pod-build/* 12 | procman/bot2-procman/python/src/bot_procman/build_prefix.py 13 | .ipynb_checkpoints/ 14 | *.pkl 15 | bags/*/ 16 | __pycache__ 17 | events.* 18 | *.pth 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Enabling Learning-based Navigation in Obscurants with Lightweight, Low-cost Millimeter Wave Radar Using Cross-modal Contrastive Learning of Representations 2 | 3 | ## Intro 4 | This repo demonstrates using generative models to reconstruct mmWave radar range data into dense range data closer to the LiDAR ground truth. 5 | 6 | The reconstructed range data can be used as input signals for control policies. 7 | For further details, please refer to our [website](https://ARG-NCTU.github.io/projects/deeprl-mmWave.html). 8 | 9 | 11 | 12 | 13 | ## Dataset 14 | [dataset on our Google Drive](https://drive.google.com/drive/u/0/folders/1FMkjvJl070_LxqcNBFeBedPsZFoy0VNe) 15 | 16 | To run the inference model on Colab, please add a shortcut to the dataset in your own Google Drive. 17 | 18 | 20 | 22 | 23 | ## Inference model 24 | [pretrained models on our Google Drive](https://drive.google.com/drive/u/2/folders/1oz7vF7SROx8Q85B1cLGpNItQHwsZkCKr) 25 | 26 | To run the inference model on Colab: 
Please also create a short cut of pretrained models to your own google drive 27 | 28 | 30 | 32 | 33 | 34 | ## run colab 35 | - cGAN generate 36 | - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huangjuite/radar-navigation/blob/master/cGAN_generate_lidar.ipynb) 37 | 38 | - VAE generate 39 | - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huangjuite/radar-navigation/blob/master/vae_generate_lidar.ipynb) 40 | 41 | 42 | -------------------------------------------------------------------------------- /vae_generate_lidar.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.6.9" 21 | }, 22 | "colab": { 23 | "name": "vae_generate_lidar.ipynb", 24 | "provenance": [] 25 | }, 26 | "accelerator": "GPU" 27 | }, 28 | "cells": [ 29 | { 30 | "cell_type": "code", 31 | "metadata": { 32 | "id": "UJLg_uaFqf2V" 33 | }, 34 | "source": [ 35 | "import os\n", 36 | "import io\n", 37 | "import cv2\n", 38 | "import copy\n", 39 | "import math\n", 40 | "import random\n", 41 | "import numpy as np\n", 42 | "import pickle as pkl\n", 43 | "from tqdm import tqdm, trange\n", 44 | "from typing import Deque, Dict, List, Tuple\n", 45 | "import matplotlib.pyplot as plt\n", 46 | "\n", 47 | "\n", 48 | "import torch\n", 49 | "import torch.nn as nn\n", 50 | "import torch.nn.functional as F\n", 51 | "import torch.optim as optim\n", 52 | "from torch.utils.data.dataset import Dataset\n", 53 
| "from torch.utils.data import DataLoader, random_split\n", 54 | "\n" 55 | ], 56 | "execution_count": 1, 57 | "outputs": [] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": { 62 | "id": "Ca4eB4Gxqf2Z" 63 | }, 64 | "source": [ 65 | "## dataset\n", 66 | "\n", 67 | " Load dataset from your google drive.\n", 68 | " Please add a short cut of our dataset on google drive to your own google drive.\n", 69 | " Change the \"main_path\" of the dataset if necessary." 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "metadata": { 75 | "id": "VtEy0lUxqp6M", 76 | "outputId": "d2c4be62-2d5f-499c-a26a-fd37109077c4", 77 | "colab": { 78 | "base_uri": "https://localhost:8080/", 79 | "height": 35 80 | } 81 | }, 82 | "source": [ 83 | "from google.colab import drive\n", 84 | "drive.mount('/content/gdrive')" 85 | ], 86 | "execution_count": 2, 87 | "outputs": [ 88 | { 89 | "output_type": "stream", 90 | "text": [ 91 | "Mounted at /content/gdrive\n" 92 | ], 93 | "name": "stdout" 94 | } 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "metadata": { 100 | "scrolled": true, 101 | "id": "7vg8gr-Yqf2Z", 102 | "outputId": "5ac148d7-1f83-4df6-ee9d-0fa25bb29adc", 103 | "colab": { 104 | "base_uri": "https://localhost:8080/", 105 | "height": 35 106 | } 107 | }, 108 | "source": [ 109 | "paths = []\n", 110 | "main_path = '/content/gdrive/My Drive/transitions/'\n", 111 | "dirs = os.listdir(main_path)\n", 112 | "dirs.sort()\n", 113 | "for d in dirs:\n", 114 | " dirs1 = os.listdir(main_path+'/'+d)\n", 115 | " dirs1.sort()\n", 116 | " for p in dirs1:\n", 117 | " paths.append(main_path+'/'+d+'/'+p)\n", 118 | " # print(paths[-1])\n", 119 | "print('%d episodes'%len(paths))\n" 120 | ], 121 | "execution_count": 3, 122 | "outputs": [ 123 | { 124 | "output_type": "stream", 125 | "text": [ 126 | "228 episodes\n" 127 | ], 128 | "name": "stdout" 129 | } 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "metadata": { 135 | "id": "pgYz_jarqf2d", 136 | "outputId": 
"3b6ced08-2a0e-4ac9-c6b1-d834c6e8edf0", 137 | "colab": { 138 | "base_uri": "https://localhost:8080/", 139 | "height": 35 140 | } 141 | }, 142 | "source": [ 143 | "class MMDataset(Dataset):\n", 144 | " def __init__(self, paths):\n", 145 | " self.transitions = []\n", 146 | "\n", 147 | " for p in tqdm(paths):\n", 148 | " with open(p, \"rb\") as f:\n", 149 | " demo = pkl.load(f, encoding=\"bytes\")\n", 150 | " self.transitions.extend(demo)\n", 151 | " \n", 152 | " def __getitem__(self,index):\n", 153 | " mm_scan = self.transitions[index][b'mm_scan']\n", 154 | " laser_scan = self.transitions[index][b'laser_scan']\n", 155 | " mm_scan = torch.Tensor(mm_scan).reshape(1,-1)\n", 156 | " laser_scan = torch.Tensor(laser_scan).reshape(1,-1)\n", 157 | " \n", 158 | " return mm_scan, laser_scan\n", 159 | " \n", 160 | " def __len__(self):\n", 161 | " return len(self.transitions)\n", 162 | "\n", 163 | " \n", 164 | "batch_size = 16\n", 165 | "mm_dataset = MMDataset(paths)\n", 166 | "\n", 167 | "loader = DataLoader(dataset=mm_dataset,\n", 168 | " batch_size=batch_size,\n", 169 | " shuffle=True,\n", 170 | " num_workers=4)\n" 171 | ], 172 | "execution_count": 4, 173 | "outputs": [ 174 | { 175 | "output_type": "stream", 176 | "text": [ 177 | "100%|██████████| 228/228 [02:25<00:00, 1.57it/s]\n" 178 | ], 179 | "name": "stderr" 180 | } 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": { 186 | "id": "xqmER4KDqf2g" 187 | }, 188 | "source": [ 189 | "## hyper parameters" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "metadata": { 195 | "id": "_w0mwItAqf2h" 196 | }, 197 | "source": [ 198 | "hyper_parameter = dict(\n", 199 | " kernel=3,\n", 200 | " stride=2,\n", 201 | " padding=2,\n", 202 | " latent=128,\n", 203 | " deconv_dim=32,\n", 204 | " deconv_channel=128,\n", 205 | " adjust_linear=235,\n", 206 | " epoch=100,\n", 207 | " learning_rate=0.001,\n", 208 | ")\n", 209 | "class Struct:\n", 210 | " def __init__(self, **entries):\n", 211 | " 
self.__dict__.update(entries)\n", 212 | "config = Struct(**hyper_parameter)" 213 | ], 214 | "execution_count": 5, 215 | "outputs": [] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": { 220 | "id": "8oKinQAWqf2j" 221 | }, 222 | "source": [ 223 | "## model" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "metadata": { 229 | "id": "HSCKIfJfqf2k" 230 | }, 231 | "source": [ 232 | "class MMvae(nn.Module):\n", 233 | " def __init__(self):\n", 234 | " super(MMvae, self).__init__()\n", 235 | " kernel = 3\n", 236 | " stride = 2\n", 237 | " self.conv = nn.Sequential(\n", 238 | " nn.Conv1d(1, 64, kernel_size=kernel, stride=stride),\n", 239 | " nn.ReLU(),\n", 240 | " nn.Conv1d(64, 64, kernel_size=kernel, stride=stride),\n", 241 | " nn.ReLU()\n", 242 | " )\n", 243 | " \n", 244 | " dim = 64*59\n", 245 | " self.linear1=nn.Sequential(\n", 246 | " nn.Linear(dim,512),\n", 247 | " nn.ReLU()\n", 248 | " )\n", 249 | " self.en_fc1=nn.Linear(512,config.latent)\n", 250 | " self.en_fc2=nn.Linear(512,config.latent)\n", 251 | " \n", 252 | " self.de_fc1=nn.Sequential(\n", 253 | " nn.Linear(config.latent,config.deconv_channel*config.deconv_dim),\n", 254 | " nn.ReLU()\n", 255 | " )\n", 256 | " \n", 257 | " self.de_conv =nn.Sequential(\n", 258 | " nn.ConvTranspose1d(config.deconv_channel, config.deconv_channel//2, kernel, stride=stride, padding=config.padding),\n", 259 | "# nn.ReLU(),\n", 260 | " nn.ConvTranspose1d(config.deconv_channel//2, config.deconv_channel//4, kernel, stride=stride, padding=config.padding),\n", 261 | "# nn.ReLU(),\n", 262 | " nn.ConvTranspose1d(config.deconv_channel//4, 1, kernel, stride=stride, padding=config.padding),\n", 263 | "# nn.ReLU(),\n", 264 | " )\n", 265 | " self.adjust_linear=nn.Sequential(\n", 266 | " nn.Linear(config.adjust_linear,241),\n", 267 | " nn.ReLU()\n", 268 | " )\n", 269 | "\n", 270 | " \n", 271 | " def encoder(self,x):\n", 272 | " x = self.conv(x)\n", 273 | " x = x.view(x.size(0),-1)\n", 274 | " x = self.linear1(x)\n", 275 | 
" mean = self.en_fc1(x)\n", 276 | " logvar = self.en_fc2(x)\n", 277 | " return mean, logvar\n", 278 | "\n", 279 | " def reparameter(self, mean, logvar):\n", 280 | " std = torch.exp(0.5*logvar)\n", 281 | " eps = torch.randn_like(std)\n", 282 | " return mean + eps*std\n", 283 | "\n", 284 | " def decoder(self,x):\n", 285 | " x = self.de_fc1(x)\n", 286 | " x = x.view(-1, config.deconv_channel, config.deconv_dim)\n", 287 | " x = self.de_conv(x)\n", 288 | " x = self.adjust_linear(x)\n", 289 | " return x\n", 290 | "\n", 291 | " def forward(self,x):\n", 292 | " mean, logvar = self.encoder(x)\n", 293 | " x = self.reparameter(mean, logvar)\n", 294 | " x = self.decoder(x)\n", 295 | " return x ,mean ,logvar" 296 | ], 297 | "execution_count": 6, 298 | "outputs": [] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": { 303 | "id": "M9p8UDYFqf2n" 304 | }, 305 | "source": [ 306 | "## load model\n", 307 | "\n", 308 | " Load model from your google drive.\n", 309 | " Please add a short cut of our inference model on google drive to your own google drive.\n", 310 | " Change the \"model_path\" of the dataset if necessary. 
" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "metadata": { 316 | "scrolled": false, 317 | "id": "_-OT3hAEqf2n", 318 | "outputId": "157db66c-0138-47b8-9de1-e71d45dac4fe", 319 | "colab": { 320 | "base_uri": "https://localhost:8080/", 321 | "height": 54 322 | } 323 | }, 324 | "source": [ 325 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 326 | "print('device, ',device)\n", 327 | "model = MMvae()\n", 328 | "model.to(device)\n", 329 | "model_path = '/content/gdrive/My Drive/deploy_model/vae/0726_1557.pth'\n", 330 | "model.load_state_dict(torch.load(model_path))" 331 | ], 332 | "execution_count": 7, 333 | "outputs": [ 334 | { 335 | "output_type": "stream", 336 | "text": [ 337 | "device, cuda:0\n" 338 | ], 339 | "name": "stdout" 340 | }, 341 | { 342 | "output_type": "execute_result", 343 | "data": { 344 | "text/plain": [ 345 | "" 346 | ] 347 | }, 348 | "metadata": { 349 | "tags": [] 350 | }, 351 | "execution_count": 7 352 | } 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "metadata": { 358 | "id": "Cv74SaEpqf2r" 359 | }, 360 | "source": [ 361 | "## visualize examples" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "metadata": { 367 | "id": "FVdIzAeGqf2s" 368 | }, 369 | "source": [ 370 | "def laser_visual(lasers=[], show=False, range_limit=6):\n", 371 | " colors = ['#3483EB','#FFA500','#15B01D']\n", 372 | " fig = plt.figure(figsize=(8, 8))\n", 373 | " for i, l in enumerate(lasers):\n", 374 | " # fig = plt.figure(figsize=(8, 8))\n", 375 | " angle = 120\n", 376 | " xp = []\n", 377 | " yp = []\n", 378 | " for r in l:\n", 379 | " if r <= range_limit:\n", 380 | " yp.append(r * math.cos(math.radians(angle)))\n", 381 | " xp.append(r * math.sin(math.radians(angle)))\n", 382 | " angle -= 1\n", 383 | " plt.xlim(-6, 6)\n", 384 | " plt.ylim(-6, 6)\n", 385 | " # plt.axis('off')\n", 386 | " plt.plot(xp, yp, 'x', color=colors[i])\n", 387 | " plt.show()\n" 388 | ], 389 | "execution_count": 8, 390 | "outputs": [] 
391 | }, 392 | { 393 | "cell_type": "code", 394 | "metadata": { 395 | "id": "ewAqFzV3qf2v", 396 | "outputId": "461b16d2-3dc1-4cdc-85dc-195cc88a11f0", 397 | "colab": { 398 | "base_uri": "https://localhost:8080/", 399 | "height": 487 400 | } 401 | }, 402 | "source": [ 403 | "data1 = None\n", 404 | "for mm_scan, laser_scan in loader:\n", 405 | " mm_scan = mm_scan.to(device)\n", 406 | " \n", 407 | " x_hat ,mean ,logvar = model(mm_scan)\n", 408 | " \n", 409 | " x = x_hat.detach().cpu().numpy().reshape(batch_size,-1)[0]\n", 410 | " laser = laser_scan.numpy().reshape(batch_size,-1)[0]\n", 411 | " mm = mm_scan.detach().cpu().numpy().reshape(batch_size,-1)[0]\n", 412 | " \n", 413 | " laser_visual([laser, x, mm], show=True, range_limit=4.9)\n", 414 | " data1 = [laser, x, mm]\n", 415 | " break" 416 | ], 417 | "execution_count": 9, 418 | "outputs": [ 419 | { 420 | "output_type": "display_data", 421 | "data": { 422 | "image/png": "\n", 423 | "text/plain": [ 424 | "
" 425 | ] 426 | }, 427 | "metadata": { 428 | "tags": [], 429 | "needs_background": "light" 430 | } 431 | } 432 | ] 433 | } 434 | ] 435 | } -------------------------------------------------------------------------------- /cGAN_generate_lidar.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.6.9" 21 | }, 22 | "colab": { 23 | "name": "cGAN_generate_lidar.ipynb", 24 | "provenance": [] 25 | }, 26 | "accelerator": "GPU" 27 | }, 28 | "cells": [ 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "P_VunP7mcLdF" 33 | }, 34 | "source": [ 35 | "## cGAN generate LiDAR" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "metadata": { 41 | "id": "jXbqhZBHcLdG" 42 | }, 43 | "source": [ 44 | "import os\n", 45 | "import io\n", 46 | "import cv2\n", 47 | "import copy\n", 48 | "import math\n", 49 | "import random\n", 50 | "import numpy as np\n", 51 | "import pickle as pkl\n", 52 | "from tqdm import tqdm, trange\n", 53 | "from typing import Deque, Dict, List, Tuple\n", 54 | "import matplotlib.pyplot as plt\n", 55 | "\n", 56 | "\n", 57 | "import torch\n", 58 | "import torch.nn as nn\n", 59 | "import torch.nn.functional as F\n", 60 | "import torch.optim as optim\n", 61 | "from torch.autograd import Variable\n", 62 | "from torch.utils.data.dataset import Dataset\n", 63 | "from torch.utils.data import DataLoader, random_split\n", 64 | "\n" 65 | ], 66 | "execution_count": 1, 67 | "outputs": [] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": "f1qtyFZicLdK" 73 | }, 74 
| "source": [ 75 | "## dataset\n", 76 | "\n", 77 | " Load dataset from your google drive.\n", 78 | " Please add a short cut of our dataset on google drive to your own google drive.\n", 79 | " Change the \"main_path\" of the dataset if necessary." 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "metadata": { 85 | "id": "79mIlikZiu92", 86 | "outputId": "0d709ee2-7966-48ec-86b1-1026c2401595", 87 | "colab": { 88 | "base_uri": "https://localhost:8080/", 89 | "height": 35 90 | } 91 | }, 92 | "source": [ 93 | "from google.colab import drive\n", 94 | "drive.mount('/content/gdrive')" 95 | ], 96 | "execution_count": 2, 97 | "outputs": [ 98 | { 99 | "output_type": "stream", 100 | "text": [ 101 | "Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount(\"/content/gdrive\", force_remount=True).\n" 102 | ], 103 | "name": "stdout" 104 | } 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "metadata": { 110 | "scrolled": true, 111 | "id": "YqeNMpGCcLdK", 112 | "outputId": "1e762023-ed41-41cd-f3c5-c6b5617462ae", 113 | "colab": { 114 | "base_uri": "https://localhost:8080/", 115 | "height": 35 116 | } 117 | }, 118 | "source": [ 119 | "paths = []\n", 120 | "main_path = '/content/gdrive/My Drive/transitions/'\n", 121 | "dirs = os.listdir(main_path)\n", 122 | "dirs.sort()\n", 123 | "for d in dirs:\n", 124 | " dirs1 = os.listdir(main_path+'/'+d)\n", 125 | " dirs1.sort()\n", 126 | " for p in dirs1:\n", 127 | " paths.append(main_path+'/'+d+'/'+p)\n", 128 | " # print(paths[-1])\n", 129 | "print('%d episodes'%len(paths))\n" 130 | ], 131 | "execution_count": 3, 132 | "outputs": [ 133 | { 134 | "output_type": "stream", 135 | "text": [ 136 | "228 episodes\n" 137 | ], 138 | "name": "stdout" 139 | } 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "metadata": { 145 | "id": "14HAQLFucLdO", 146 | "outputId": "7338d1ce-f233-4a77-9db5-3f73f6c0b310", 147 | "colab": { 148 | "base_uri": "https://localhost:8080/", 149 | "height": 35 150 | } 151 | }, 
152 | "source": [ 153 | "class MMDataset(Dataset):\n", 154 | " def __init__(self, paths):\n", 155 | " self.transitions = []\n", 156 | "\n", 157 | " for p in tqdm(paths):\n", 158 | " with open(p, \"rb\") as f:\n", 159 | " demo = pkl.load(f, encoding=\"bytes\")\n", 160 | " self.transitions.extend(demo)\n", 161 | " \n", 162 | " def __getitem__(self,index):\n", 163 | " mm_scan = self.transitions[index][b'mm_scan']\n", 164 | " laser_scan = self.transitions[index][b'laser_scan']\n", 165 | " mm_scan = torch.Tensor(mm_scan).reshape(1,-1)\n", 166 | " laser_scan = torch.Tensor(laser_scan).reshape(1,-1)\n", 167 | " \n", 168 | " return mm_scan, laser_scan\n", 169 | " \n", 170 | " def __len__(self):\n", 171 | " return len(self.transitions)\n", 172 | "\n", 173 | " \n", 174 | "batch_size = 16\n", 175 | "mm_dataset = MMDataset(paths)\n", 176 | "\n", 177 | "loader = DataLoader(dataset=mm_dataset,\n", 178 | " batch_size=batch_size,\n", 179 | " shuffle=True,\n", 180 | " num_workers=4)\n" 181 | ], 182 | "execution_count": 4, 183 | "outputs": [ 184 | { 185 | "output_type": "stream", 186 | "text": [ 187 | "100%|██████████| 228/228 [00:05<00:00, 40.39it/s]\n" 188 | ], 189 | "name": "stderr" 190 | } 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "metadata": { 196 | "id": "vR_682nicLdR" 197 | }, 198 | "source": [ 199 | "## hyper parameters" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "metadata": { 205 | "id": "Y9F8bdeOcLdR" 206 | }, 207 | "source": [ 208 | "hyper_parameter = dict(\n", 209 | " kernel=3,\n", 210 | " stride=2,\n", 211 | " padding=2,\n", 212 | " deconv_dim=32,\n", 213 | " deconv_channel=128,\n", 214 | " adjust_linear=235,\n", 215 | " epoch=500,\n", 216 | " beta1=0.5,\n", 217 | " learning_rate=0.0002,\n", 218 | " nz=100,\n", 219 | " lambda_l1=100,\n", 220 | ")\n", 221 | "class Struct:\n", 222 | " def __init__(self, **entries):\n", 223 | " self.__dict__.update(entries)\n", 224 | "config = Struct(**hyper_parameter)" 225 | ], 226 | "execution_count": 
5, 227 | "outputs": [] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": { 232 | "id": "85dWFlDCcLdU" 233 | }, 234 | "source": [ 235 | "## model" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "metadata": { 241 | "id": "yNAlrgJXcLdU" 242 | }, 243 | "source": [ 244 | "class Generator(nn.Module):\n", 245 | " def __init__(self):\n", 246 | " super(Generator, self).__init__()\n", 247 | " kernel = 3\n", 248 | " stride = 2\n", 249 | " self.conv = nn.Sequential(\n", 250 | " nn.Conv1d(1, 64, kernel_size=kernel, stride=stride),\n", 251 | " nn.ReLU(),\n", 252 | " nn.Conv1d(64, 64, kernel_size=kernel, stride=stride),\n", 253 | " nn.ReLU()\n", 254 | " )\n", 255 | " \n", 256 | " dim = 64*59\n", 257 | " self.linear=nn.Sequential(\n", 258 | " nn.Linear(dim,512),\n", 259 | " nn.ReLU(),\n", 260 | " nn.Linear(512,128)\n", 261 | " )\n", 262 | " \n", 263 | "# self.n_fc1=nn.Linear(config.nz, 128)\n", 264 | "# self.n_fc2=nn.Linear(128, 128)\n", 265 | " \n", 266 | "# self.fc_combine=nn.Linear(128*2, 128)\n", 267 | " \n", 268 | " self.de_fc1=nn.Sequential(\n", 269 | " nn.Linear(128,config.deconv_channel*config.deconv_dim),\n", 270 | " nn.ReLU()\n", 271 | " )\n", 272 | " \n", 273 | " self.de_conv =nn.Sequential(\n", 274 | " nn.ConvTranspose1d(config.deconv_channel, config.deconv_channel//2, kernel, stride=stride, padding=config.padding),\n", 275 | " nn.ConvTranspose1d(config.deconv_channel//2, config.deconv_channel//4, kernel, stride=stride, padding=config.padding),\n", 276 | " nn.ConvTranspose1d(config.deconv_channel//4, 1, kernel, stride=stride, padding=config.padding),\n", 277 | " )\n", 278 | " self.adjust_linear=nn.Sequential(\n", 279 | " nn.Linear(config.adjust_linear,241),\n", 280 | " nn.ReLU()\n", 281 | " )\n", 282 | "\n", 283 | " \n", 284 | " def encoder(self,x):\n", 285 | " x = self.conv(x)\n", 286 | " x = x.view(x.size(0),-1)\n", 287 | " x = self.linear(x)\n", 288 | " return x\n", 289 | "\n", 290 | " def decoder(self,x):\n", 291 | " x = 
self.de_fc1(x)\n", 292 | " x = x.view(-1, config.deconv_channel, config.deconv_dim)\n", 293 | " x = self.de_conv(x)\n", 294 | " x = self.adjust_linear(x)\n", 295 | " return x\n", 296 | "\n", 297 | " def forward(self, x):\n", 298 | " x = self.encoder(x)\n", 299 | "# n = self.n_fc1(n)\n", 300 | "# n = self.n_fc2(n)\n", 301 | " \n", 302 | "# x = torch.cat((x,n),dim=-1)\n", 303 | "# x = self.fc_combine(x)\n", 304 | " \n", 305 | " x = self.decoder(x)\n", 306 | " return x" 307 | ], 308 | "execution_count": 6, 309 | "outputs": [] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "metadata": { 314 | "id": "sTNajyBWcLdX" 315 | }, 316 | "source": [ 317 | "class Discriminator(nn.Module):\n", 318 | " def __init__(self):\n", 319 | " super(Discriminator, self).__init__()\n", 320 | " kernel = 3\n", 321 | " stride = 2\n", 322 | " self.conv = nn.Sequential(\n", 323 | " nn.Conv1d(2, 64, kernel_size=kernel, stride=stride),\n", 324 | " nn.ReLU(),\n", 325 | " nn.Conv1d(64, 64, kernel_size=kernel, stride=stride),\n", 326 | " nn.ReLU()\n", 327 | " )\n", 328 | " \n", 329 | " dim = 64*59\n", 330 | " self.linear=nn.Sequential(\n", 331 | " nn.Linear(dim,512),\n", 332 | " nn.ReLU(),\n", 333 | " nn.Linear(512,128),\n", 334 | " nn.ReLU(),\n", 335 | " nn.Linear(128, 1),\n", 336 | " nn.Sigmoid(),\n", 337 | " )\n", 338 | "\n", 339 | " def forward(self, x):\n", 340 | " \n", 341 | " x = self.conv(x)\n", 342 | " x = x.view(x.size(0),-1)\n", 343 | " x = self.linear(x)\n", 344 | " \n", 345 | " return x" 346 | ], 347 | "execution_count": 7, 348 | "outputs": [] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "metadata": { 353 | "id": "npt_JqzGcLda" 354 | }, 355 | "source": [ 356 | "## load model\n", 357 | "\n", 358 | " Load model from your google drive.\n", 359 | " Please add a short cut of our inference model on google drive to your own google drive.\n", 360 | " Change the \"model_path\" of the dataset if necessary. 
" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "metadata": { 366 | "id": "1r_OdviacLda", 367 | "outputId": "a1bcb9f9-da76-4e28-c727-ef894d360881", 368 | "colab": { 369 | "base_uri": "https://localhost:8080/", 370 | "height": 508 371 | } 372 | }, 373 | "source": [ 374 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 375 | "print('device, ',device)\n", 376 | "model = Generator()\n", 377 | "\n", 378 | "# bce logits loss L1:0.1163\n", 379 | "model_path = '/content/gdrive/My Drive/deploy_model/cgan/0827_1851.pth'\n", 380 | "\n", 381 | "model.load_state_dict(torch.load(model_path))\n", 382 | "model.to(device)\n" 383 | ], 384 | "execution_count": 8, 385 | "outputs": [ 386 | { 387 | "output_type": "stream", 388 | "text": [ 389 | "device, cuda:0\n" 390 | ], 391 | "name": "stdout" 392 | }, 393 | { 394 | "output_type": "execute_result", 395 | "data": { 396 | "text/plain": [ 397 | "Generator(\n", 398 | " (conv): Sequential(\n", 399 | " (0): Conv1d(1, 64, kernel_size=(3,), stride=(2,))\n", 400 | " (1): ReLU()\n", 401 | " (2): Conv1d(64, 64, kernel_size=(3,), stride=(2,))\n", 402 | " (3): ReLU()\n", 403 | " )\n", 404 | " (linear): Sequential(\n", 405 | " (0): Linear(in_features=3776, out_features=512, bias=True)\n", 406 | " (1): ReLU()\n", 407 | " (2): Linear(in_features=512, out_features=128, bias=True)\n", 408 | " )\n", 409 | " (de_fc1): Sequential(\n", 410 | " (0): Linear(in_features=128, out_features=4096, bias=True)\n", 411 | " (1): ReLU()\n", 412 | " )\n", 413 | " (de_conv): Sequential(\n", 414 | " (0): ConvTranspose1d(128, 64, kernel_size=(3,), stride=(2,), padding=(2,))\n", 415 | " (1): ConvTranspose1d(64, 32, kernel_size=(3,), stride=(2,), padding=(2,))\n", 416 | " (2): ConvTranspose1d(32, 1, kernel_size=(3,), stride=(2,), padding=(2,))\n", 417 | " )\n", 418 | " (adjust_linear): Sequential(\n", 419 | " (0): Linear(in_features=235, out_features=241, bias=True)\n", 420 | " (1): ReLU()\n", 421 | " )\n", 422 | ")" 423 | ] 424 | 
}, 425 | "metadata": { 426 | "tags": [] 427 | }, 428 | "execution_count": 8 429 | } 430 | ] 431 | }, 432 | { 433 | "cell_type": "markdown", 434 | "metadata": { 435 | "id": "fofwexSxcLdd" 436 | }, 437 | "source": [ 438 | "## visualize" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "metadata": { 444 | "id": "T0ZBUAA3cLde" 445 | }, 446 | "source": [ 447 | "def laser_visual(lasers=[], show=False, range_limit=6):\n", 448 | " colors = ['#3483EB','#FFA500','#15B01D']\n", 449 | " fig = plt.figure(figsize=(8, 8))\n", 450 | " for i, l in enumerate(lasers):\n", 451 | " # fig = plt.figure(figsize=(8, 8))\n", 452 | " angle = 120\n", 453 | " xp = []\n", 454 | " yp = []\n", 455 | " for r in l:\n", 456 | " if r <= range_limit:\n", 457 | " yp.append(r * math.cos(math.radians(angle)))\n", 458 | " xp.append(r * math.sin(math.radians(angle)))\n", 459 | " angle -= 1\n", 460 | " plt.xlim(-6, 6)\n", 461 | " plt.ylim(-6, 6)\n", 462 | " # plt.axis('off')\n", 463 | " plt.plot(xp, yp, 'x', color=colors[i])\n", 464 | " plt.show()\n" 465 | ], 466 | "execution_count": 9, 467 | "outputs": [] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "metadata": { 472 | "scrolled": false, 473 | "id": "IFUUegxMcLdg", 474 | "outputId": "a56a43ef-4603-454f-fe7b-5e57fd3d9542", 475 | "colab": { 476 | "base_uri": "https://localhost:8080/", 477 | "height": 487 478 | } 479 | }, 480 | "source": [ 481 | "data1 = None\n", 482 | "for mm_scan, laser_scan in loader:\n", 483 | " mm_scan = mm_scan.to(device)\n", 484 | " x_hat = model(mm_scan)\n", 485 | " \n", 486 | " x = x_hat.detach().cpu().numpy().reshape(batch_size,-1)[0]\n", 487 | " laser = laser_scan.numpy().reshape(batch_size,-1)[0]\n", 488 | " mm = mm_scan.detach().cpu().numpy().reshape(batch_size,-1)[0]\n", 489 | " \n", 490 | " laser_visual([laser, x, mm], show=True, range_limit=4.9)\n", 491 | " data1 = [laser, x, mm]\n", 492 | " \n", 493 | " break" 494 | ], 495 | "execution_count": 15, 496 | "outputs": [ 497 | { 498 | "output_type": "display_data", 
499 | "data": { 500 | "image/png": "\n", 501 | "text/plain": [ 502 | "
" 503 | ] 504 | }, 505 | "metadata": { 506 | "tags": [], 507 | "needs_background": "light" 508 | } 509 | } 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "metadata": { 515 | "id": "rNiHJNispA6U" 516 | }, 517 | "source": [ 518 | "" 519 | ], 520 | "execution_count": 10, 521 | "outputs": [] 522 | } 523 | ] 524 | } --------------------------------------------------------------------------------