├── fig
│   ├── demo.png
│   ├── model1.png
│   ├── model2.png
│   ├── transition1.png
│   └── transition2.png
├── .gitignore
├── README.md
├── vae_generate_lidar.ipynb
└── cGAN_generate_lidar.ipynb
/fig/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjuite/radar-navigation/HEAD/fig/demo.png
--------------------------------------------------------------------------------
/fig/model1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjuite/radar-navigation/HEAD/fig/model1.png
--------------------------------------------------------------------------------
/fig/model2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjuite/radar-navigation/HEAD/fig/model2.png
--------------------------------------------------------------------------------
/fig/transition1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjuite/radar-navigation/HEAD/fig/transition1.png
--------------------------------------------------------------------------------
/fig/transition2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjuite/radar-navigation/HEAD/fig/transition2.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | devel/
3 | result*
4 | *.egg-info
5 | *.pyc
6 | .catkin_workspace
7 | log
8 | .vscode/
9 | procman/bot2-procman/lcmtypes/c/*
10 | procman/bot2-procman/lcmtypes/cpp/*
11 | procman/bot2-procman/pod-build/*
12 | procman/bot2-procman/python/src/bot_procman/build_prefix.py
13 | .ipynb_checkpoints/
14 | *.pkl
15 | bags/*/
16 | __pycache__
17 | events.*
18 | *.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Enabling Learning-based Navigation in Obscurants with Lightweight, Low-cost Millimeter Wave Radar Using Cross-modal Contrastive Learning of Representations
2 |
3 | ## Intro
4 | This repo demonstrates how generative models can reconstruct mmWave radar range data into dense range data that is closer to the LiDAR ground truth.
5 |
6 | The reconstructed range data can be used as input signals for control policies.
7 | For further details, please refer to our [website](https://ARG-NCTU.github.io/projects/deeprl-mmWave.html).
8 |
9 |
11 |
12 |
13 | ## Dataset
14 | [dataset on our google drive](https://drive.google.com/drive/u/0/folders/1FMkjvJl070_LxqcNBFeBedPsZFoy0VNe)
15 |
16 | To run the inference model on Colab, please add a shortcut to the dataset in your own Google Drive.
17 |
18 |
20 |
22 |
23 | ## Inference model
24 | [pretrained model on our google drive](https://drive.google.com/drive/u/2/folders/1oz7vF7SROx8Q85B1cLGpNItQHwsZkCKr)
25 |
26 | To run the inference model on Colab, please also add a shortcut to the pretrained models in your own Google Drive.
27 |
28 |
30 |
32 |
33 |
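34 | ## Run on Colab
35 | - cGAN generate
36 | - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huangjuite/radar-navigation/blob/master/cGAN_generate_lidar.ipynb)
37 |
38 | - VAE generate
39 | - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huangjuite/radar-navigation/blob/master/vae_generate_lidar.ipynb)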
40 |
41 |
42 |
--------------------------------------------------------------------------------
/vae_generate_lidar.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "display_name": "Python 3",
7 | "language": "python",
8 | "name": "python3"
9 | },
10 | "language_info": {
11 | "codemirror_mode": {
12 | "name": "ipython",
13 | "version": 3
14 | },
15 | "file_extension": ".py",
16 | "mimetype": "text/x-python",
17 | "name": "python",
18 | "nbconvert_exporter": "python",
19 | "pygments_lexer": "ipython3",
20 | "version": "3.6.9"
21 | },
22 | "colab": {
23 | "name": "vae_generate_lidar.ipynb",
24 | "provenance": []
25 | },
26 | "accelerator": "GPU"
27 | },
28 | "cells": [
29 | {
30 | "cell_type": "code",
31 | "metadata": {
32 | "id": "UJLg_uaFqf2V"
33 | },
34 | "source": [
35 | "import os\n",
36 | "import io\n",
37 | "import cv2\n",
38 | "import copy\n",
39 | "import math\n",
40 | "import random\n",
41 | "import numpy as np\n",
42 | "import pickle as pkl\n",
43 | "from tqdm import tqdm, trange\n",
44 | "from typing import Deque, Dict, List, Tuple\n",
45 | "import matplotlib.pyplot as plt\n",
46 | "\n",
47 | "\n",
48 | "import torch\n",
49 | "import torch.nn as nn\n",
50 | "import torch.nn.functional as F\n",
51 | "import torch.optim as optim\n",
52 | "from torch.utils.data.dataset import Dataset\n",
53 | "from torch.utils.data import DataLoader, random_split\n",
54 | "\n"
55 | ],
56 | "execution_count": 1,
57 | "outputs": []
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "metadata": {
62 | "id": "Ca4eB4Gxqf2Z"
63 | },
64 | "source": [
65 | "## dataset\n",
66 | "\n",
67 | " Load the dataset from your Google Drive.\n",
68 | " Please add a shortcut to our shared dataset folder in your own Google Drive.\n",
69 | " Change `main_path` to the dataset location if necessary."
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "metadata": {
75 | "id": "VtEy0lUxqp6M",
76 | "outputId": "d2c4be62-2d5f-499c-a26a-fd37109077c4",
77 | "colab": {
78 | "base_uri": "https://localhost:8080/",
79 | "height": 35
80 | }
81 | },
82 | "source": [
83 | "from google.colab import drive  # Drive hosts the dataset and the pretrained checkpoints\n",
84 | "drive.mount(mountpoint='/content/gdrive')"
85 | ],
86 | "execution_count": 2,
87 | "outputs": [
88 | {
89 | "output_type": "stream",
90 | "text": [
91 | "Mounted at /content/gdrive\n"
92 | ],
93 | "name": "stdout"
94 | }
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "metadata": {
100 | "scrolled": true,
101 | "id": "7vg8gr-Yqf2Z",
102 | "outputId": "5ac148d7-1f83-4df6-ee9d-0fa25bb29adc",
103 | "colab": {
104 | "base_uri": "https://localhost:8080/",
105 | "height": 35
106 | }
107 | },
108 | "source": [
109 | "# Gather every file under <main_path>/<sub-folder>/ in sorted order; each file counts as one episode.\n",
110 | "paths = []\n",
111 | "main_path = '/content/gdrive/My Drive/transitions/'\n",
112 | "for d in sorted(os.listdir(main_path)):\n",
113 | "    episode_dir = os.path.join(main_path, d)\n",
114 | "    # Skip stray files (anything that is not a folder) instead of crashing in os.listdir.\n",
115 | "    if not os.path.isdir(episode_dir):\n",
116 | "        continue\n",
117 | "    for p in sorted(os.listdir(episode_dir)):\n",
118 | "        paths.append(os.path.join(episode_dir, p))\n",
119 | "print('%d episodes'%len(paths))\n"
120 | ],
121 | "execution_count": 3,
122 | "outputs": [
123 | {
124 | "output_type": "stream",
125 | "text": [
126 | "228 episodes\n"
127 | ],
128 | "name": "stdout"
129 | }
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "metadata": {
135 | "id": "pgYz_jarqf2d",
136 | "outputId": "3b6ced08-2a0e-4ac9-c6b1-d834c6e8edf0",
137 | "colab": {
138 | "base_uri": "https://localhost:8080/",
139 | "height": 35
140 | }
141 | },
142 | "source": [
143 | "class MMDataset(Dataset):\n",
144 | "    # Paired (mm_scan, laser_scan) samples pooled from every pickled file in `paths`.\n",
145 | "\n",
146 | "    def __init__(self, paths):\n",
147 | "        self.transitions = []\n",
148 | "        for episode_path in tqdm(paths):\n",
149 | "            # NOTE(review): pickle can run arbitrary code; only load files from a trusted source.\n",
150 | "            with open(episode_path, \"rb\") as fh:\n",
151 | "                self.transitions.extend(pkl.load(fh, encoding=\"bytes\"))\n",
152 | "\n",
153 | "    def __getitem__(self, index):\n",
154 | "        # Keys are bytes because the pickles are read with encoding=\"bytes\".\n",
155 | "        sample = self.transitions[index]\n",
156 | "        # Each scan is returned as a (1, N) tensor.\n",
157 | "        mm_scan, laser_scan = (torch.Tensor(sample[key]).reshape(1, -1)\n",
158 | "                               for key in (b'mm_scan', b'laser_scan'))\n",
159 | "        return mm_scan, laser_scan\n",
160 | "\n",
161 | "    def __len__(self):\n",
162 | "        return len(self.transitions)\n",
163 | "\n",
164 | "\n",
165 | "batch_size = 16\n",
166 | "mm_dataset = MMDataset(paths)\n",
167 | "\n",
168 | "loader = DataLoader(\n",
169 | "    dataset=mm_dataset, batch_size=batch_size, shuffle=True, num_workers=4)\n"
171 | ],
172 | "execution_count": 4,
173 | "outputs": [
174 | {
175 | "output_type": "stream",
176 | "text": [
177 | "100%|██████████| 228/228 [02:25<00:00, 1.57it/s]\n"
178 | ],
179 | "name": "stderr"
180 | }
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "metadata": {
186 | "id": "xqmER4KDqf2g"
187 | },
188 | "source": [
189 | "## hyper parameters"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "metadata": {
195 | "id": "_w0mwItAqf2h"
196 | },
197 | "source": [
198 | "# Hyper-parameters, exposed below as attributes of `config` (e.g. config.latent).\n",
199 | "hyper_parameter = {\n",
200 | "    'kernel': 3,\n",
201 | "    'stride': 2,\n",
202 | "    'padding': 2,\n",
203 | "    'latent': 128,\n",
204 | "    'deconv_dim': 32,\n",
205 | "    'deconv_channel': 128,\n",
206 | "    'adjust_linear': 235,\n",
207 | "    'epoch': 100,\n",
208 | "    'learning_rate': 0.001,\n",
209 | "}\n",
210 | "\n",
211 | "class Struct:\n",
212 | "    # Minimal attribute-access wrapper: Struct(a=1).a == 1\n",
213 | "    def __init__(self, **entries):\n",
214 | "        for key, value in entries.items():\n",
215 | "            setattr(self, key, value)\n",
216 | "\n",
217 | "config = Struct(**hyper_parameter)"
213 | ],
214 | "execution_count": 5,
215 | "outputs": []
216 | },
217 | {
218 | "cell_type": "markdown",
219 | "metadata": {
220 | "id": "8oKinQAWqf2j"
221 | },
222 | "source": [
223 | "## model"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "metadata": {
229 | "id": "HSCKIfJfqf2k"
230 | },
231 | "source": [
232 | "class MMvae(nn.Module):\n",
233 | "    # Variational autoencoder: encodes a 1-channel scan into a Gaussian latent of size\n",
234 | "    # config.latent and decodes a 241-sample range scan. Keep the layer attribute names:\n",
235 | "    # they are the keys of the state_dict loaded from the pretrained checkpoint.\n",
236 | "    def __init__(self):\n",
237 | "        super(MMvae, self).__init__()\n",
238 | "        # Read from config so the layers follow the hyper-parameter cell.\n",
239 | "        kernel = config.kernel\n",
240 | "        stride = config.stride\n",
241 | "        self.conv = nn.Sequential(\n",
242 | "            nn.Conv1d(1, 64, kernel_size=kernel, stride=stride),\n",
243 | "            nn.ReLU(),\n",
244 | "            nn.Conv1d(64, 64, kernel_size=kernel, stride=stride),\n",
245 | "            nn.ReLU()\n",
246 | "        )\n",
247 | "\n",
248 | "        # Flattened encoder size: two kernel-3/stride-2 convs shrink a 241-sample input to 59.\n",
249 | "        # NOTE(review): assumes mm_scan has 241 (or 242) samples - confirm against the dataset.\n",
250 | "        dim = 64*59\n",
251 | "        self.linear1 = nn.Sequential(\n",
252 | "            nn.Linear(dim, 512),\n",
253 | "            nn.ReLU()\n",
254 | "        )\n",
255 | "        self.en_fc1 = nn.Linear(512, config.latent)  # mean head\n",
256 | "        self.en_fc2 = nn.Linear(512, config.latent)  # log-variance head\n",
257 | "\n",
258 | "        self.de_fc1 = nn.Sequential(\n",
259 | "            nn.Linear(config.latent, config.deconv_channel*config.deconv_dim),\n",
260 | "            nn.ReLU()\n",
261 | "        )\n",
262 | "\n",
263 | "        # With the current config (kernel 3, stride 2, padding 2) the length grows\n",
264 | "        # 32 -> 61 -> 119 -> 235, which must match config.adjust_linear.\n",
265 | "        self.de_conv = nn.Sequential(\n",
266 | "            nn.ConvTranspose1d(config.deconv_channel, config.deconv_channel//2, kernel, stride=stride, padding=config.padding),\n",
267 | "            # nn.ReLU(),\n",
268 | "            nn.ConvTranspose1d(config.deconv_channel//2, config.deconv_channel//4, kernel, stride=stride, padding=config.padding),\n",
269 | "            # nn.ReLU(),\n",
270 | "            nn.ConvTranspose1d(config.deconv_channel//4, 1, kernel, stride=stride, padding=config.padding),\n",
271 | "            # nn.ReLU(),\n",
272 | "        )\n",
273 | "        # Resizes the 235-sample deconv output to 241 ranges; the ReLU keeps them >= 0.\n",
274 | "        self.adjust_linear = nn.Sequential(\n",
275 | "            nn.Linear(config.adjust_linear, 241),\n",
276 | "            nn.ReLU()\n",
277 | "        )\n",
278 | "\n",
279 | "    def encoder(self, x):\n",
280 | "        # Returns (mean, logvar), each of shape (batch, config.latent).\n",
281 | "        x = self.conv(x)\n",
282 | "        x = x.view(x.size(0), -1)\n",
283 | "        x = self.linear1(x)\n",
284 | "        mean = self.en_fc1(x)\n",
285 | "        logvar = self.en_fc2(x)\n",
286 | "        return mean, logvar\n",
287 | "\n",
288 | "    def reparameter(self, mean, logvar):\n",
289 | "        # Reparameterization trick: mean + eps * exp(0.5 * logvar) with eps ~ N(0, 1).\n",
290 | "        std = torch.exp(0.5*logvar)\n",
291 | "        eps = torch.randn_like(std)\n",
292 | "        return mean + eps*std\n",
293 | "\n",
294 | "    def decoder(self, x):\n",
295 | "        # (batch, config.latent) -> (batch, 1, 241)\n",
296 | "        x = self.de_fc1(x)\n",
297 | "        x = x.view(-1, config.deconv_channel, config.deconv_dim)\n",
298 | "        x = self.de_conv(x)\n",
299 | "        x = self.adjust_linear(x)\n",
300 | "        return x\n",
301 | "\n",
302 | "    def forward(self, x):\n",
303 | "        # The latent is sampled on every call (inference included), so the output is stochastic.\n",
304 | "        mean, logvar = self.encoder(x)\n",
305 | "        x = self.reparameter(mean, logvar)\n",
306 | "        x = self.decoder(x)\n",
307 | "        return x, mean, logvar"
296 | ],
297 | "execution_count": 6,
298 | "outputs": []
299 | },
300 | {
301 | "cell_type": "markdown",
302 | "metadata": {
303 | "id": "M9p8UDYFqf2n"
304 | },
305 | "source": [
306 | "## load model\n",
307 | "\n",
308 | " Load the model from your Google Drive.\n",
309 | " Please add a shortcut to our shared inference model folder in your own Google Drive.\n",
310 | " Change `model_path` to the checkpoint location if necessary."
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "metadata": {
316 | "scrolled": false,
317 | "id": "_-OT3hAEqf2n",
318 | "outputId": "157db66c-0138-47b8-9de1-e71d45dac4fe",
319 | "colab": {
320 | "base_uri": "https://localhost:8080/",
321 | "height": 54
322 | }
323 | },
324 | "source": [
325 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
326 | "print('device, ',device)\n",
327 | "model = MMvae().to(device)\n",
328 | "model.eval()  # inference only\n",
329 | "model_path = '/content/gdrive/My Drive/deploy_model/vae/0726_1557.pth'\n",
330 | "# map_location places the checkpoint tensors on `device`, so this also works without a GPU.\n",
331 | "model.load_state_dict(torch.load(model_path, map_location=device))"
331 | ],
332 | "execution_count": 7,
333 | "outputs": [
334 | {
335 | "output_type": "stream",
336 | "text": [
337 | "device, cuda:0\n"
338 | ],
339 | "name": "stdout"
340 | },
341 | {
342 | "output_type": "execute_result",
343 | "data": {
344 | "text/plain": [
345 | ""
346 | ]
347 | },
348 | "metadata": {
349 | "tags": []
350 | },
351 | "execution_count": 7
352 | }
353 | ]
354 | },
355 | {
356 | "cell_type": "markdown",
357 | "metadata": {
358 | "id": "Cv74SaEpqf2r"
359 | },
360 | "source": [
361 | "## visualize examples"
362 | ]
363 | },
364 | {
365 | "cell_type": "code",
366 | "metadata": {
367 | "id": "FVdIzAeGqf2s"
368 | },
369 | "source": [
370 | "def laser_visual(lasers=(), show=False, range_limit=6, axis_limit=6):\n",
371 | "    \"\"\"Plot range scans as 'x' markers on one 8x8-inch figure.\n",
372 | "\n",
373 | "    Each scan is a sequence of ranges; sample k is drawn at angle (120 - k) degrees\n",
374 | "    (x = r*sin, y = r*cos). Ranges greater than `range_limit` are skipped.\n",
375 | "    Scans are colored blue, orange, green in order; colors repeat after three scans.\n",
376 | "    Both axes span [-axis_limit, axis_limit]. plt.show() is called only when `show` is True.\n",
377 | "    \"\"\"\n",
378 | "    colors = ['#3483EB', '#FFA500', '#15B01D']\n",
379 | "    plt.figure(figsize=(8, 8))\n",
380 | "    for i, scan in enumerate(lasers):\n",
381 | "        xp, yp = [], []\n",
382 | "        for k, r in enumerate(scan):\n",
383 | "            if r <= range_limit:\n",
384 | "                theta = math.radians(120 - k)\n",
385 | "                xp.append(r * math.sin(theta))\n",
386 | "                yp.append(r * math.cos(theta))\n",
387 | "        plt.xlim(-axis_limit, axis_limit)\n",
388 | "        plt.ylim(-axis_limit, axis_limit)\n",
389 | "        plt.plot(xp, yp, 'x', color=colors[i % len(colors)])\n",
390 | "    if show:\n",
391 | "        plt.show()\n"
388 | ],
389 | "execution_count": 8,
390 | "outputs": []
391 | },
392 | {
393 | "cell_type": "code",
394 | "metadata": {
395 | "id": "ewAqFzV3qf2v",
396 | "outputId": "461b16d2-3dc1-4cdc-85dc-195cc88a11f0",
397 | "colab": {
398 | "base_uri": "https://localhost:8080/",
399 | "height": 487
400 | }
401 | },
402 | "source": [
403 | "# Reconstruct the first sample of one shuffled batch and plot it:\n",
404 | "# blue = laser_scan (LiDAR), orange = model output, green = mm_scan (radar).\n",
405 | "data1 = None\n",
406 | "model.eval()\n",
407 | "with torch.no_grad():  # inference only: no autograd graph needed\n",
408 | "    for mm_scan, laser_scan in loader:\n",
409 | "        mm_scan = mm_scan.to(device)\n",
410 | "        x_hat, mean, logvar = model(mm_scan)\n",
411 | "\n",
412 | "        # Take sample 0 directly, so a batch smaller than batch_size still works.\n",
413 | "        x = x_hat[0].cpu().numpy().reshape(-1)\n",
414 | "        laser = laser_scan[0].numpy().reshape(-1)\n",
415 | "        mm = mm_scan[0].cpu().numpy().reshape(-1)\n",
416 | "\n",
417 | "        laser_visual([laser, x, mm], show=True, range_limit=4.9)\n",
418 | "        data1 = [laser, x, mm]\n",
419 | "        break"
416 | ],
417 | "execution_count": 9,
418 | "outputs": [
419 | {
420 | "output_type": "display_data",
421 | "data": {
422 | "image/png": "\n",
423 | "text/plain": [
424 | ""
425 | ]
426 | },
427 | "metadata": {
428 | "tags": [],
429 | "needs_background": "light"
430 | }
431 | }
432 | ]
433 | }
434 | ]
435 | }
--------------------------------------------------------------------------------
/cGAN_generate_lidar.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "display_name": "Python 3",
7 | "language": "python",
8 | "name": "python3"
9 | },
10 | "language_info": {
11 | "codemirror_mode": {
12 | "name": "ipython",
13 | "version": 3
14 | },
15 | "file_extension": ".py",
16 | "mimetype": "text/x-python",
17 | "name": "python",
18 | "nbconvert_exporter": "python",
19 | "pygments_lexer": "ipython3",
20 | "version": "3.6.9"
21 | },
22 | "colab": {
23 | "name": "cGAN_generate_lidar.ipynb",
24 | "provenance": []
25 | },
26 | "accelerator": "GPU"
27 | },
28 | "cells": [
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "P_VunP7mcLdF"
33 | },
34 | "source": [
35 | "## cGAN generate LiDAR"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "metadata": {
41 | "id": "jXbqhZBHcLdG"
42 | },
43 | "source": [
44 | "import os\n",
45 | "import io\n",
46 | "import cv2\n",
47 | "import copy\n",
48 | "import math\n",
49 | "import random\n",
50 | "import numpy as np\n",
51 | "import pickle as pkl\n",
52 | "from tqdm import tqdm, trange\n",
53 | "from typing import Deque, Dict, List, Tuple\n",
54 | "import matplotlib.pyplot as plt\n",
55 | "\n",
56 | "\n",
57 | "import torch\n",
58 | "import torch.nn as nn\n",
59 | "import torch.nn.functional as F\n",
60 | "import torch.optim as optim\n",
61 | "from torch.autograd import Variable\n",
62 | "from torch.utils.data.dataset import Dataset\n",
63 | "from torch.utils.data import DataLoader, random_split\n",
64 | "\n"
65 | ],
66 | "execution_count": 1,
67 | "outputs": []
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "metadata": {
72 | "id": "f1qtyFZicLdK"
73 | },
74 | "source": [
75 | "## dataset\n",
76 | "\n",
77 | " Load the dataset from your Google Drive.\n",
78 | " Please add a shortcut to our shared dataset folder in your own Google Drive.\n",
79 | " Change `main_path` to the dataset location if necessary."
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "metadata": {
85 | "id": "79mIlikZiu92",
86 | "outputId": "0d709ee2-7966-48ec-86b1-1026c2401595",
87 | "colab": {
88 | "base_uri": "https://localhost:8080/",
89 | "height": 35
90 | }
91 | },
92 | "source": [
93 | "from google.colab import drive  # Drive hosts the dataset and the pretrained checkpoints\n",
94 | "drive.mount(mountpoint='/content/gdrive')"
95 | ],
96 | "execution_count": 2,
97 | "outputs": [
98 | {
99 | "output_type": "stream",
100 | "text": [
101 | "Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount(\"/content/gdrive\", force_remount=True).\n"
102 | ],
103 | "name": "stdout"
104 | }
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "metadata": {
110 | "scrolled": true,
111 | "id": "YqeNMpGCcLdK",
112 | "outputId": "1e762023-ed41-41cd-f3c5-c6b5617462ae",
113 | "colab": {
114 | "base_uri": "https://localhost:8080/",
115 | "height": 35
116 | }
117 | },
118 | "source": [
119 | "# Gather every file under <main_path>/<sub-folder>/ in sorted order; each file counts as one episode.\n",
120 | "paths = []\n",
121 | "main_path = '/content/gdrive/My Drive/transitions/'\n",
122 | "for d in sorted(os.listdir(main_path)):\n",
123 | "    episode_dir = os.path.join(main_path, d)\n",
124 | "    # Skip stray files (anything that is not a folder) instead of crashing in os.listdir.\n",
125 | "    if not os.path.isdir(episode_dir):\n",
126 | "        continue\n",
127 | "    for p in sorted(os.listdir(episode_dir)):\n",
128 | "        paths.append(os.path.join(episode_dir, p))\n",
129 | "print('%d episodes'%len(paths))\n"
130 | ],
131 | "execution_count": 3,
132 | "outputs": [
133 | {
134 | "output_type": "stream",
135 | "text": [
136 | "228 episodes\n"
137 | ],
138 | "name": "stdout"
139 | }
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "metadata": {
145 | "id": "14HAQLFucLdO",
146 | "outputId": "7338d1ce-f233-4a77-9db5-3f73f6c0b310",
147 | "colab": {
148 | "base_uri": "https://localhost:8080/",
149 | "height": 35
150 | }
151 | },
152 | "source": [
153 | "class MMDataset(Dataset):\n",
154 | "    # Paired (mm_scan, laser_scan) samples pooled from every pickled file in `paths`.\n",
155 | "\n",
156 | "    def __init__(self, paths):\n",
157 | "        self.transitions = []\n",
158 | "        for episode_path in tqdm(paths):\n",
159 | "            # NOTE(review): pickle can run arbitrary code; only load files from a trusted source.\n",
160 | "            with open(episode_path, \"rb\") as fh:\n",
161 | "                self.transitions.extend(pkl.load(fh, encoding=\"bytes\"))\n",
162 | "\n",
163 | "    def __getitem__(self, index):\n",
164 | "        # Keys are bytes because the pickles are read with encoding=\"bytes\".\n",
165 | "        sample = self.transitions[index]\n",
166 | "        # Each scan is returned as a (1, N) tensor.\n",
167 | "        mm_scan, laser_scan = (torch.Tensor(sample[key]).reshape(1, -1)\n",
168 | "                               for key in (b'mm_scan', b'laser_scan'))\n",
169 | "        return mm_scan, laser_scan\n",
170 | "\n",
171 | "    def __len__(self):\n",
172 | "        return len(self.transitions)\n",
173 | "\n",
174 | "\n",
175 | "batch_size = 16\n",
176 | "mm_dataset = MMDataset(paths)\n",
177 | "\n",
178 | "loader = DataLoader(\n",
179 | "    dataset=mm_dataset, batch_size=batch_size, shuffle=True, num_workers=4)\n"
181 | ],
182 | "execution_count": 4,
183 | "outputs": [
184 | {
185 | "output_type": "stream",
186 | "text": [
187 | "100%|██████████| 228/228 [00:05<00:00, 40.39it/s]\n"
188 | ],
189 | "name": "stderr"
190 | }
191 | ]
192 | },
193 | {
194 | "cell_type": "markdown",
195 | "metadata": {
196 | "id": "vR_682nicLdR"
197 | },
198 | "source": [
199 | "## hyper parameters"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "metadata": {
205 | "id": "Y9F8bdeOcLdR"
206 | },
207 | "source": [
208 | "# Hyper-parameters, exposed below as attributes of `config` (e.g. config.deconv_dim).\n",
209 | "hyper_parameter = {\n",
210 | "    'kernel': 3,\n",
211 | "    'stride': 2,\n",
212 | "    'padding': 2,\n",
213 | "    'deconv_dim': 32,\n",
214 | "    'deconv_channel': 128,\n",
215 | "    'adjust_linear': 235,\n",
216 | "    'epoch': 500,\n",
217 | "    'beta1': 0.5,\n",
218 | "    'learning_rate': 0.0002,\n",
219 | "    'nz': 100,\n",
220 | "    'lambda_l1': 100,\n",
221 | "}\n",
222 | "\n",
223 | "class Struct:\n",
224 | "    # Minimal attribute-access wrapper: Struct(a=1).a == 1\n",
225 | "    def __init__(self, **entries):\n",
226 | "        for key, value in entries.items():\n",
227 | "            setattr(self, key, value)\n",
228 | "\n",
229 | "config = Struct(**hyper_parameter)"
225 | ],
226 | "execution_count": 5,
227 | "outputs": []
228 | },
229 | {
230 | "cell_type": "markdown",
231 | "metadata": {
232 | "id": "85dWFlDCcLdU"
233 | },
234 | "source": [
235 | "## model"
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "metadata": {
241 | "id": "yNAlrgJXcLdU"
242 | },
243 | "source": [
244 | "class Generator(nn.Module):\n",
245 | "    # cGAN generator: encodes a 1-channel scan into a 128-d code and decodes a\n",
246 | "    # 241-sample range scan. Keep the layer attribute names: they are the keys of\n",
247 | "    # the state_dict loaded from the pretrained checkpoint.\n",
248 | "    def __init__(self):\n",
249 | "        super(Generator, self).__init__()\n",
250 | "        # Read from config so the layers follow the hyper-parameter cell.\n",
251 | "        kernel = config.kernel\n",
252 | "        stride = config.stride\n",
253 | "        self.conv = nn.Sequential(\n",
254 | "            nn.Conv1d(1, 64, kernel_size=kernel, stride=stride),\n",
255 | "            nn.ReLU(),\n",
256 | "            nn.Conv1d(64, 64, kernel_size=kernel, stride=stride),\n",
257 | "            nn.ReLU()\n",
258 | "        )\n",
259 | "\n",
260 | "        # Flattened encoder size: two kernel-3/stride-2 convs shrink a 241-sample input to 59.\n",
261 | "        # NOTE(review): assumes mm_scan has 241 (or 242) samples - confirm against the dataset.\n",
262 | "        dim = 64*59\n",
263 | "        self.linear = nn.Sequential(\n",
264 | "            nn.Linear(dim, 512),\n",
265 | "            nn.ReLU(),\n",
266 | "            nn.Linear(512, 128)\n",
267 | "        )\n",
268 | "\n",
269 | "        # Commented-out branch that would mix an extra input n (size config.nz) into the code.\n",
270 | "        # self.n_fc1 = nn.Linear(config.nz, 128)\n",
271 | "        # self.n_fc2 = nn.Linear(128, 128)\n",
272 | "        # self.fc_combine = nn.Linear(128*2, 128)\n",
273 | "\n",
274 | "        self.de_fc1 = nn.Sequential(\n",
275 | "            nn.Linear(128, config.deconv_channel*config.deconv_dim),\n",
276 | "            nn.ReLU()\n",
277 | "        )\n",
278 | "\n",
279 | "        # With the current config (kernel 3, stride 2, padding 2) the length grows\n",
280 | "        # 32 -> 61 -> 119 -> 235, which must match config.adjust_linear.\n",
281 | "        self.de_conv = nn.Sequential(\n",
282 | "            nn.ConvTranspose1d(config.deconv_channel, config.deconv_channel//2, kernel, stride=stride, padding=config.padding),\n",
283 | "            nn.ConvTranspose1d(config.deconv_channel//2, config.deconv_channel//4, kernel, stride=stride, padding=config.padding),\n",
284 | "            nn.ConvTranspose1d(config.deconv_channel//4, 1, kernel, stride=stride, padding=config.padding),\n",
285 | "        )\n",
286 | "        # Resizes the 235-sample deconv output to 241 ranges; the ReLU keeps them >= 0.\n",
287 | "        self.adjust_linear = nn.Sequential(\n",
288 | "            nn.Linear(config.adjust_linear, 241),\n",
289 | "            nn.ReLU()\n",
290 | "        )\n",
291 | "\n",
292 | "    def encoder(self, x):\n",
293 | "        # (batch, 1, L) -> (batch, 128)\n",
294 | "        x = self.conv(x)\n",
295 | "        x = x.view(x.size(0), -1)\n",
296 | "        x = self.linear(x)\n",
297 | "        return x\n",
298 | "\n",
299 | "    def decoder(self, x):\n",
300 | "        # (batch, 128) -> (batch, 1, 241)\n",
301 | "        x = self.de_fc1(x)\n",
302 | "        x = x.view(-1, config.deconv_channel, config.deconv_dim)\n",
303 | "        x = self.de_conv(x)\n",
304 | "        x = self.adjust_linear(x)\n",
305 | "        return x\n",
306 | "\n",
307 | "    def forward(self, x):\n",
308 | "        # Deterministic: the extra-input branch below is commented out.\n",
309 | "        x = self.encoder(x)\n",
310 | "        # n = self.n_fc1(n)\n",
311 | "        # n = self.n_fc2(n)\n",
312 | "        # x = torch.cat((x, n), dim=-1)\n",
313 | "        # x = self.fc_combine(x)\n",
314 | "        x = self.decoder(x)\n",
315 | "        return x"
307 | ],
308 | "execution_count": 6,
309 | "outputs": []
310 | },
311 | {
312 | "cell_type": "code",
313 | "metadata": {
314 | "id": "sTNajyBWcLdX"
315 | },
316 | "source": [
317 | "class Discriminator(nn.Module):\n",
318 | "    # Maps a 2-channel tensor of shape (batch, 2, L) to one sigmoid score in [0, 1] per sample.\n",
319 | "    # NOTE(review): the two channels are presumably the conditioning radar scan and a\n",
320 | "    # real/generated LiDAR scan - confirm against the training code.\n",
321 | "    def __init__(self):\n",
322 | "        super(Discriminator, self).__init__()\n",
323 | "        # Read from config so the layers follow the hyper-parameter cell.\n",
324 | "        kernel = config.kernel\n",
325 | "        stride = config.stride\n",
326 | "        self.conv = nn.Sequential(\n",
327 | "            nn.Conv1d(2, 64, kernel_size=kernel, stride=stride),\n",
328 | "            nn.ReLU(),\n",
329 | "            nn.Conv1d(64, 64, kernel_size=kernel, stride=stride),\n",
330 | "            nn.ReLU()\n",
331 | "        )\n",
332 | "\n",
333 | "        # Flattened conv size; 59 assumes a 241 (or 242) sample input - TODO confirm.\n",
334 | "        dim = 64*59\n",
335 | "        self.linear = nn.Sequential(\n",
336 | "            nn.Linear(dim, 512),\n",
337 | "            nn.ReLU(),\n",
338 | "            nn.Linear(512, 128),\n",
339 | "            nn.ReLU(),\n",
340 | "            nn.Linear(128, 1),\n",
341 | "            nn.Sigmoid(),\n",
342 | "        )\n",
343 | "\n",
344 | "    def forward(self, x):\n",
345 | "        # (batch, 2, L) -> (batch, 1)\n",
346 | "        x = self.conv(x)\n",
347 | "        x = x.view(x.size(0), -1)\n",
348 | "        return self.linear(x)"
346 | ],
347 | "execution_count": 7,
348 | "outputs": []
349 | },
350 | {
351 | "cell_type": "markdown",
352 | "metadata": {
353 | "id": "npt_JqzGcLda"
354 | },
355 | "source": [
356 | "## load model\n",
357 | "\n",
358 | " Load the model from your Google Drive.\n",
359 | " Please add a shortcut to our shared inference model folder in your own Google Drive.\n",
360 | " Change `model_path` to the checkpoint location if necessary."
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "metadata": {
366 | "id": "1r_OdviacLda",
367 | "outputId": "a1bcb9f9-da76-4e28-c727-ef894d360881",
368 | "colab": {
369 | "base_uri": "https://localhost:8080/",
370 | "height": 508
371 | }
372 | },
373 | "source": [
374 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
375 | "print('device, ',device)\n",
376 | "model = Generator()\n",
377 | "\n",
378 | "# bce logits loss L1:0.1163\n",
379 | "model_path = '/content/gdrive/My Drive/deploy_model/cgan/0827_1851.pth'\n",
380 | "# map_location places the checkpoint tensors on `device`, so this also works without a GPU.\n",
381 | "model.load_state_dict(torch.load(model_path, map_location=device))\n",
382 | "model.eval()  # inference only\n",
383 | "model.to(device)\n"
383 | ],
384 | "execution_count": 8,
385 | "outputs": [
386 | {
387 | "output_type": "stream",
388 | "text": [
389 | "device, cuda:0\n"
390 | ],
391 | "name": "stdout"
392 | },
393 | {
394 | "output_type": "execute_result",
395 | "data": {
396 | "text/plain": [
397 | "Generator(\n",
398 | " (conv): Sequential(\n",
399 | " (0): Conv1d(1, 64, kernel_size=(3,), stride=(2,))\n",
400 | " (1): ReLU()\n",
401 | " (2): Conv1d(64, 64, kernel_size=(3,), stride=(2,))\n",
402 | " (3): ReLU()\n",
403 | " )\n",
404 | " (linear): Sequential(\n",
405 | " (0): Linear(in_features=3776, out_features=512, bias=True)\n",
406 | " (1): ReLU()\n",
407 | " (2): Linear(in_features=512, out_features=128, bias=True)\n",
408 | " )\n",
409 | " (de_fc1): Sequential(\n",
410 | " (0): Linear(in_features=128, out_features=4096, bias=True)\n",
411 | " (1): ReLU()\n",
412 | " )\n",
413 | " (de_conv): Sequential(\n",
414 | " (0): ConvTranspose1d(128, 64, kernel_size=(3,), stride=(2,), padding=(2,))\n",
415 | " (1): ConvTranspose1d(64, 32, kernel_size=(3,), stride=(2,), padding=(2,))\n",
416 | " (2): ConvTranspose1d(32, 1, kernel_size=(3,), stride=(2,), padding=(2,))\n",
417 | " )\n",
418 | " (adjust_linear): Sequential(\n",
419 | " (0): Linear(in_features=235, out_features=241, bias=True)\n",
420 | " (1): ReLU()\n",
421 | " )\n",
422 | ")"
423 | ]
424 | },
425 | "metadata": {
426 | "tags": []
427 | },
428 | "execution_count": 8
429 | }
430 | ]
431 | },
432 | {
433 | "cell_type": "markdown",
434 | "metadata": {
435 | "id": "fofwexSxcLdd"
436 | },
437 | "source": [
438 | "## visualize"
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "metadata": {
444 | "id": "T0ZBUAA3cLde"
445 | },
446 | "source": [
447 | "def laser_visual(lasers=(), show=False, range_limit=6, axis_limit=6):\n",
448 | "    \"\"\"Plot range scans as 'x' markers on one 8x8-inch figure.\n",
449 | "\n",
450 | "    Each scan is a sequence of ranges; sample k is drawn at angle (120 - k) degrees\n",
451 | "    (x = r*sin, y = r*cos). Ranges greater than `range_limit` are skipped.\n",
452 | "    Scans are colored blue, orange, green in order; colors repeat after three scans.\n",
453 | "    Both axes span [-axis_limit, axis_limit]. plt.show() is called only when `show` is True.\n",
454 | "    \"\"\"\n",
455 | "    colors = ['#3483EB', '#FFA500', '#15B01D']\n",
456 | "    plt.figure(figsize=(8, 8))\n",
457 | "    for i, scan in enumerate(lasers):\n",
458 | "        xp, yp = [], []\n",
459 | "        for k, r in enumerate(scan):\n",
460 | "            if r <= range_limit:\n",
461 | "                theta = math.radians(120 - k)\n",
462 | "                xp.append(r * math.sin(theta))\n",
463 | "                yp.append(r * math.cos(theta))\n",
464 | "        plt.xlim(-axis_limit, axis_limit)\n",
465 | "        plt.ylim(-axis_limit, axis_limit)\n",
466 | "        plt.plot(xp, yp, 'x', color=colors[i % len(colors)])\n",
467 | "    if show:\n",
468 | "        plt.show()\n"
465 | ],
466 | "execution_count": 9,
467 | "outputs": []
468 | },
469 | {
470 | "cell_type": "code",
471 | "metadata": {
472 | "scrolled": false,
473 | "id": "IFUUegxMcLdg",
474 | "outputId": "a56a43ef-4603-454f-fe7b-5e57fd3d9542",
475 | "colab": {
476 | "base_uri": "https://localhost:8080/",
477 | "height": 487
478 | }
479 | },
480 | "source": [
481 | "# Generate from the first sample of one shuffled batch and plot it:\n",
482 | "# blue = laser_scan (LiDAR), orange = generator output, green = mm_scan (radar).\n",
483 | "data1 = None\n",
484 | "model.eval()\n",
485 | "with torch.no_grad():  # inference only: no autograd graph needed\n",
486 | "    for mm_scan, laser_scan in loader:\n",
487 | "        mm_scan = mm_scan.to(device)\n",
488 | "        x_hat = model(mm_scan)\n",
489 | "\n",
490 | "        # Take sample 0 directly, so a batch smaller than batch_size still works.\n",
491 | "        x = x_hat[0].cpu().numpy().reshape(-1)\n",
492 | "        laser = laser_scan[0].numpy().reshape(-1)\n",
493 | "        mm = mm_scan[0].cpu().numpy().reshape(-1)\n",
494 | "\n",
495 | "        laser_visual([laser, x, mm], show=True, range_limit=4.9)\n",
496 | "        data1 = [laser, x, mm]\n",
497 | "        break"
494 | ],
495 | "execution_count": 15,
496 | "outputs": [
497 | {
498 | "output_type": "display_data",
499 | "data": {
500 | "image/png": "\n",
501 | "text/plain": [
502 | ""
503 | ]
504 | },
505 | "metadata": {
506 | "tags": [],
507 | "needs_background": "light"
508 | }
509 | }
510 | ]
511 | },
512 | {
513 | "cell_type": "code",
514 | "metadata": {
515 | "id": "rNiHJNispA6U"
516 | },
517 | "source": [
518 | ""
519 | ],
520 | "execution_count": 10,
521 | "outputs": []
522 | }
523 | ]
524 | }
--------------------------------------------------------------------------------