├── .gitignore
├── LICENSE
├── README.MD
├── dataset
│   ├── __init__.py
│   └── locPair.py
├── modules
│   ├── Attention.py
│   ├── Convolutional_RE_Net.py
│   ├── Linear_RE_Net.py
│   ├── __init__.py
│   ├── algorithmSim.py
│   └── weightGen.py
├── performance_evaluation.py
├── test.py
├── training.py
├── utils
│   ├── Config.py
│   ├── Map.py
│   ├── __init__.py
│   ├── dist.py
│   ├── line_calc.py
│   └── load_dict.py
├── weight_ob_2.pth
├── weight_pretrain.py
└── weight_rss_2.pth

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                    GNU GENERAL PUBLIC LICENSE
                       Version 2, June 1991

 Copyright (C) 1989, 1991 Free Software Foundation, Inc.,
 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 Everyone is permitted to copy and distribute verbatim copies
 of this license document, but changing it is not allowed.

                            Preamble

  The licenses for most software are designed to take away your
freedom to share and change it.  By contrast, the GNU General Public
License is intended to guarantee your freedom to share and change free
software--to make sure the software is free for all its users.  This
General Public License applies to most of the Free Software
Foundation's software and to any other program whose authors commit to
using it.  (Some other Free Software Foundation software is covered by
the GNU Lesser General Public License instead.)  You can apply it to
your programs, too.

  When we speak of free software, we are referring to freedom, not
price.  Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
this service if you wish), that you receive source code or can get it
if you want it, that you can change the software or use pieces of it
in new free programs; and that you know you can do these things.

  To protect your rights, we need to make restrictions that forbid
anyone to deny you these rights or to ask you to surrender the rights.
These restrictions translate to certain responsibilities for you if you
distribute copies of the software, or if you modify it.

  For example, if you distribute copies of such a program, whether
gratis or for a fee, you must give the recipients all the rights that
you have.  You must make sure that they, too, receive or can get the
source code.  And you must show them these terms so they know their
rights.

  We protect your rights with two steps: (1) copyright the software, and
(2) offer you this license which gives you legal permission to copy,
distribute and/or modify the software.

  Also, for each author's protection and ours, we want to make certain
that everyone understands that there is no warranty for this free
software.  If the software is modified by someone else and passed on, we
want its recipients to know that what they have is not the original, so
that any problems introduced by others will not reflect on the original
authors' reputations.

  Finally, any free program is threatened constantly by software
patents.  We wish to avoid the danger that redistributors of a free
program will individually obtain patent licenses, in effect making the
program proprietary.  To prevent this, we have made it clear that any
patent must be licensed for everyone's free use or not licensed at all.

  The precise terms and conditions for copying, distribution and
modification follow.

                    GNU GENERAL PUBLIC LICENSE
   TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION

  0. This License applies to any program or other work which contains
a notice placed by the copyright holder saying it may be distributed
under the terms of this General Public License.  The "Program", below,
refers to any such program or work, and a "work based on the Program"
means either the Program or any derivative work under copyright law:
that is to say, a work containing the Program or a portion of it,
either verbatim or with modifications and/or translated into another
language.  (Hereinafter, translation is included without limitation in
the term "modification".)  Each licensee is addressed as "you".

Activities other than copying, distribution and modification are not
covered by this License; they are outside its scope.  The act of
running the Program is not restricted, and the output from the Program
is covered only if its contents constitute a work based on the
Program (independent of having been made by running the Program).
Whether that is true depends on what the Program does.

  1. You may copy and distribute verbatim copies of the Program's
source code as you receive it, in any medium, provided that you
conspicuously and appropriately publish on each copy an appropriate
copyright notice and disclaimer of warranty; keep intact all the
notices that refer to this License and to the absence of any warranty;
and give any other recipients of the Program a copy of this License
along with the Program.

You may charge a fee for the physical act of transferring a copy, and
you may at your option offer warranty protection in exchange for a fee.

  2. You may modify your copy or copies of the Program or any portion
of it, thus forming a work based on the Program, and copy and
distribute such modifications or work under the terms of Section 1
above, provided that you also meet all of these conditions:

    a) You must cause the modified files to carry prominent notices
    stating that you changed the files and the date of any change.

    b) You must cause any work that you distribute or publish, that in
    whole or in part contains or is derived from the Program or any
    part thereof, to be licensed as a whole at no charge to all third
    parties under the terms of this License.

    c) If the modified program normally reads commands interactively
    when run, you must cause it, when started running for such
    interactive use in the most ordinary way, to print or display an
    announcement including an appropriate copyright notice and a
    notice that there is no warranty (or else, saying that you provide
    a warranty) and that users may redistribute the program under
    these conditions, and telling the user how to view a copy of this
    License.  (Exception: if the Program itself is interactive but
    does not normally print such an announcement, your work based on
    the Program is not required to print an announcement.)

These requirements apply to the modified work as a whole.  If
identifiable sections of that work are not derived from the Program,
and can be reasonably considered independent and separate works in
themselves, then this License, and its terms, do not apply to those
sections when you distribute them as separate works.  But when you
distribute the same sections as part of a whole which is a work based
on the Program, the distribution of the whole must be on the terms of
this License, whose permissions for other licensees extend to the
entire whole, and thus to each and every part regardless of who wrote it.

Thus, it is not the intent of this section to claim rights or contest
your rights to work written entirely by you; rather, the intent is to
exercise the right to control the distribution of derivative or
collective works based on the Program.

In addition, mere aggregation of another work not based on the Program
with the Program (or with a work based on the Program) on a volume of
a storage or distribution medium does not bring the other work under
the scope of this License.

  3. You may copy and distribute the Program (or a work based on it,
under Section 2) in object code or executable form under the terms of
Sections 1 and 2 above provided that you also do one of the following:

    a) Accompany it with the complete corresponding machine-readable
    source code, which must be distributed under the terms of Sections
    1 and 2 above on a medium customarily used for software interchange; or,

    b) Accompany it with a written offer, valid for at least three
    years, to give any third party, for a charge no more than your
    cost of physically performing source distribution, a complete
    machine-readable copy of the corresponding source code, to be
    distributed under the terms of Sections 1 and 2 above on a medium
    customarily used for software interchange; or,

    c) Accompany it with the information you received as to the offer
    to distribute corresponding source code.  (This alternative is
    allowed only for noncommercial distribution and only if you
    received the program in object code or executable form with such
    an offer, in accord with Subsection b above.)

The source code for a work means the preferred form of the work for
making modifications to it.  For an executable work, complete source
code means all the source code for all modules it contains, plus any
associated interface definition files, plus the scripts used to
control compilation and installation of the executable.  However, as a
special exception, the source code distributed need not include
anything that is normally distributed (in either source or binary
form) with the major components (compiler, kernel, and so on) of the
operating system on which the executable runs, unless that component
itself accompanies the executable.

If distribution of executable or object code is made by offering
access to copy from a designated place, then offering equivalent
access to copy the source code from the same place counts as
distribution of the source code, even though third parties are not
compelled to copy the source along with the object code.

  4. You may not copy, modify, sublicense, or distribute the Program
except as expressly provided under this License.  Any attempt
otherwise to copy, modify, sublicense or distribute the Program is
void, and will automatically terminate your rights under this License.
However, parties who have received copies, or rights, from you under
this License will not have their licenses terminated so long as such
parties remain in full compliance.

  5. You are not required to accept this License, since you have not
signed it.  However, nothing else grants you permission to modify or
distribute the Program or its derivative works.  These actions are
prohibited by law if you do not accept this License.  Therefore, by
modifying or distributing the Program (or any work based on the
Program), you indicate your acceptance of this License to do so, and
all its terms and conditions for copying, distributing or modifying
the Program or works based on it.

  6. Each time you redistribute the Program (or any work based on the
Program), the recipient automatically receives a license from the
original licensor to copy, distribute or modify the Program subject to
these terms and conditions.  You may not impose any further
restrictions on the recipients' exercise of the rights granted herein.
You are not responsible for enforcing compliance by third parties to
this License.

  7. If, as a consequence of a court judgment or allegation of patent
infringement or for any other reason (not limited to patent issues),
conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License.  If you cannot
distribute so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you
may not distribute the Program at all.  For example, if a patent
license would not permit royalty-free redistribution of the Program by
all those who receive copies directly or indirectly through you, then
the only way you could satisfy both it and this License would be to
refrain entirely from distribution of the Program.

If any portion of this section is held invalid or unenforceable under
any particular circumstance, the balance of the section is intended to
apply and the section as a whole is intended to apply in other
circumstances.

It is not the purpose of this section to induce you to infringe any
patents or other property right claims or to contest validity of any
such claims; this section has the sole purpose of protecting the
integrity of the free software distribution system, which is
implemented by public license practices.  Many people have made
generous contributions to the wide range of software distributed
through that system in reliance on consistent application of that
system; it is up to the author/donor to decide if he or she is willing
to distribute software through any other system and a licensee cannot
impose that choice.

This section is intended to make thoroughly clear what is believed to
be a consequence of the rest of this License.

  8. If the distribution and/or use of the Program is restricted in
certain countries either by patents or by copyrighted interfaces, the
original copyright holder who places the Program under this License
may add an explicit geographical distribution limitation excluding
those countries, so that distribution is permitted only in or among
countries not thus excluded.  In such case, this License incorporates
the limitation as if written in the body of this License.

  9. The Free Software Foundation may publish revised and/or new versions
of the General Public License from time to time.  Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.

Each version is given a distinguishing version number.
If the Program
specifies a version number of this License which applies to it and "any
later version", you have the option of following the terms and conditions
either of that version or of any later version published by the Free
Software Foundation.  If the Program does not specify a version number of
this License, you may choose any version ever published by the Free Software
Foundation.

  10. If you wish to incorporate parts of the Program into other free
programs whose distribution conditions are different, write to the author
to ask for permission.  For software which is copyrighted by the Free
Software Foundation, write to the Free Software Foundation; we sometimes
make exceptions for this.  Our decision will be guided by the two goals
of preserving the free status of all derivatives of our free software and
of promoting the sharing and reuse of software generally.

                            NO WARRANTY

  11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW.  EXCEPT WHEN
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.  THE ENTIRE RISK AS
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU.  SHOULD THE
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
REPAIR OR CORRECTION.

  12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR
REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED
TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY
YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGES.

                     END OF TERMS AND CONDITIONS

            How to Apply These Terms to Your New Programs

  If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.

  To do so, attach the following notices to the program.  It is safest
to attach them to the start of each source file to most effectively
convey the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.

    <one line to give the program's name and a brief idea of what it does.>
    Copyright (C) <year>  <name of author>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along
    with this program; if not, write to the Free Software Foundation, Inc.,
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

Also add information on how to contact you by electronic and paper mail.

If the program is interactive, make it output a short notice like this
when it starts in an interactive mode:

    Gnomovision version 69, Copyright (C) year name of author
    Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
    This is free software, and you are welcome to redistribute it
    under certain conditions; type `show c' for details.

The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License.  Of course, the commands you use may
be called something other than `show w' and `show c'; they could even be
mouse-clicks or menu items--whatever suits your program.

You should also get your employer (if you work as a programmer) or your
school, if any, to sign a "copyright disclaimer" for the program, if
necessary.  Here is a sample; alter the names:

  Yoyodyne, Inc., hereby disclaims all copyright interest in the program
  `Gnomovision' (which makes passes at compilers) written by James Hacker.

  <signature of Ty Coon>, 1 April 1989
  Ty Coon, President of Vice

This General Public License does not permit incorporating your program into
proprietary programs.  If your program is a subroutine library, you may
consider it more useful to permit linking proprietary applications with the
library.  If this is what you want to do, use the GNU Lesser General
Public License instead of this License.
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
## modules

Deep learning models.

+ **algorithmSim.py**: The main framework of the whole network, combining all the standalone modules. The **Multichannel Gain** and the **Obstruction Network** are implemented directly in the main *Algorithm* class.
+ **Convolutional_RE_Net.py**: Implements the **Convolutional RE-Net**.
+ **weightGen.py**: Implements the two nonlinear models in the **Obstruction Network** and the **RM-Net**.
+ **Linear_RE_Net.py**: Implements the **Linear RE-Net**.

## utils

Useful functions and settings.

+ **Config.py**: Basic settings.
+ **dist.py**: Computes the log-distance between a user and a UAV.
+ **line_calc.py**: Computes the equation of the line through two points.
+ **load_dict.py**: Loads a pre-trained sub-network into the combined framework.
+ **Map.py**: Transforms a location pair into a position map.

## dataset

Dataset generation.

+ **locPair.py**: Loads the location pairs and the corresponding RSS measurements.



------



## Running Files

+ **training.py**: Trains our framework.
+ **weight_pretrain.py**: Pre-trains the nonlinear models in the **Obstruction Network** and the **RM-Net**.
+ **performance_evaluation.py**: Evaluates the performance of our framework on the whole dataset.

## Workflow

+ Adjust the basic settings: the dataset filename in *training.py* and all the parameters in *Config.py*.
+ Initialize the parameters of the Multichannel Gain in *algorithmSim.py*.
+ Run *weight_pretrain.py*.
+ Run *training.py*.
+ Run *performance_evaluation.py* to see the performance.
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zppppppx/deep-learning-based-radioMap/aff528a1281f0432e86064d5ddb2537449e89d06/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/locPair.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch import tensor as t
from torch.utils.data import Dataset
from scipy.io import loadmat
from utils import Map


class RadioMap(Dataset):
    def __init__(self, file_path, choice='synthetic', fraction=None, seed=1):
        All_data = loadmat(file_path)
        self.radio_map = All_data['RadioMap']
        self.length = self.radio_map.shape[0]
        np.random.seed(seed)

        # Sample (with replacement) a fraction of the rows; the default
        # fraction keeps roughly 10000 samples.
        if fraction is None:
            fraction = 10000./self.length
        rand_index = np.random.randint(0, self.length, int(fraction*self.length))
        self.locs = self.radio_map[rand_index, :6]

        # Shift the coordinates so that they begin from 0.
        if choice == 'synthetic':
            self.locs[:, np.array([0, 1, 3, 4])] -= 3
        if choice == 'simulated':
            self.locs[:, np.array([0, 1, 3, 4])] += np.array([26.71, 131.3, 26.71, 131.3])

        self.rss = self.radio_map[rand_index, -1]

    def __len__(self):
        return self.rss.shape[0]

    def __getitem__(self, index):
        return t(self.locs[index]).float(), t(self.rss[index]).float()


if __name__ == '__main__':
    file_path = '../data/radiomap_simulated100tx_3class.mat'
    data = loadmat(file_path)
    print(data['RadioMap'].shape)

    # Smoke test for the map generation with an example location pair.
    loc_u, loc_d = torch.tensor([80., 70., 0.]), torch.tensor([20., 10., 10.])
    print(Map.loc_pair_to_map(loc_u, loc_d, resolution=3., map_size=[94, 94]))
--------------------------------------------------------------------------------
/modules/Attention.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

############################################################################
# Not used in the main frame; may be useful in the Convolutional RE-Net.   #
############################################################################

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class spatial_attention(nn.Module):
    def __init__(self):
        super(spatial_attention, self).__init__()

        # conv
        self.conv = nn.Sequential(
            nn.Conv2d(2, 1, 7, 1, 3),
            nn.BatchNorm2d(1, eps=1e-5, momentum=0.01, affine=True),
        )

    def forward(self, x):
        """
        Args:
            x: the position map of the user and the UAV.
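                Expected shape: [B, C, H, W]. The map is mean- and max-pooled
                over the channel dimension into a 2-channel tensor before the
                7x7 convolution produces the spatial attention mask.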
        """

        Position_avgpool = torch.mean(x, dim=1)[:, None, :]
        Position_maxpool = torch.max(x, dim=1)[0][:, None, :]

        Spatial_attention = self.conv(torch.cat((Position_maxpool, Position_avgpool), dim=1))
        Spatial_attention = torch.sigmoid(Spatial_attention)

        return Spatial_attention*x


class channel_attention(nn.Module):
    def __init__(self, in_ch, reduction_ratio=16, pool_types=['avg', 'max']):
        super(channel_attention, self).__init__()
        self.in_ch = in_ch
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(in_ch, in_ch//reduction_ratio),
            nn.ReLU(),
            nn.Linear(in_ch//reduction_ratio, in_ch)
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_raw = self.mlp(max_pool)
            elif pool_type == 'lp':
                # lp_pool2d takes the norm type as its second argument (2 assumed here).
                lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_raw = self.mlp(lp_pool)

            if channel_sum is None:
                channel_sum = channel_raw
            else:
                channel_sum += channel_raw

        scale = torch.sigmoid(channel_sum).unsqueeze(2).unsqueeze(3)
        return scale*x

class CBAM(nn.Module):
    def __init__(self, in_ch, reduction_ratio=16, pool_types=['avg', 'max']):
        super(CBAM, self).__init__()
        self.ch_att = channel_attention(in_ch, reduction_ratio, pool_types)
        self.spa_att = spatial_attention()

    def forward(self, x):
        x = self.ch_att(x)
        x = self.spa_att(x)

        return x


if __name__ == '__main__':
    x = torch.ones([1, 64, 94, 94])
    ch_att = channel_attention(64)
    y = ch_att(x)
    sp_att = spatial_attention()
    y = sp_att(x)
    cbam = CBAM(64)
    y = cbam(x)
    print(y.shape)
--------------------------------------------------------------------------------
/modules/Convolutional_RE_Net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


##########################################
# Basic blocks mentioned in the Appendix #
##########################################
class double_conv(nn.Module):
    """
    Double-Conv Layer.
    """
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
        self.conv.apply(self.init_weights)

    def forward(self, x):
        x = self.conv(x)
        return x

    @staticmethod
    def init_weights(m):
        if type(m) == nn.Conv2d:
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)

class up_net_sole(nn.Module):
    """
    Up-sampling Deconvolutional Layer.
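
    A ConvTranspose2d(in_ch, in_ch, 2, 2) doubles the spatial resolution; the
    result is then padded (replicate mode) to the requested target_size and
    passed through a Double-Conv block.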
    """
    def __init__(self, in_ch, out_ch):
        super(up_net_sole, self).__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch, 2, 2)
        self.conv = double_conv(in_ch, out_ch)
        self.up.apply(self.init_weights)
        self.conv.apply(self.init_weights)

    def forward(self, x1, target_size):
        x1 = self.up(x1)

        # Pad the up-sampled feature map to exactly match the target size.
        diffY = target_size[0] - x1.size()[2]
        diffX = target_size[1] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
                        diffY // 2, diffY - diffY//2), 'replicate')

        x = self.conv(x1)
        return x

    @staticmethod
    def init_weights(m):
        if type(m) == nn.Conv2d:
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)


#############################################
# Overall structure of Convolutional RE-Net #
#############################################
class map_block(nn.Module):
    """
    Single structure mentioned in the Convolutional RE-Net.
    """
    def __init__(self, opt):
        super(map_block, self).__init__()
        self.map_size = opt.map_size
        self.map_order = opt.map_order
        self.target_size = [
            [*map(lambda x: x//2//2, self.map_size)],
            [*map(lambda x: x//2, self.map_size)],
            [*map(lambda x: x, self.map_size)],
        ]

        self.inconv = nn.Sequential(
            double_conv(opt.shapeUnits.num, 128),
            double_conv(128, 128),
            double_conv(128, 256),
        )

        self.up1 = up_net_sole(256, 256)  # up_net 1
        self.up2 = up_net_sole(256, 256)  # up_net 2
        self.up3 = up_net_sole(256, 512)  # up_net 3

        self.outconv = nn.Sequential(
            double_conv(512, 256),
            double_conv(256, 128),
            double_conv(128, 64),
        )
        self.outcome = nn.Sequential(
            nn.Conv2d(64, 1, 1),
            nn.ReLU()
        )

    def forward(self, x):
        # Shallow Convolutional Network
        x = self.inconv(x)

        # Up-sampling Deconvolutional Network
        x = self.up1(x, self.target_size[0])
        x = self.up2(x, self.target_size[1])
        x = self.up3(x, self.target_size[2])

        # Shallow Convolutional Averaging Network
        x = self.outconv(x)

        # Convolutional Layer
        x = self.outcome(x)

        return x


class map_generate(nn.Module):
    """
    Combine the K identical structures.
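
    Each of the K = opt.map_order map_block branches produces one obstacle
    layer; the layers are stacked along the channel dimension into a
    [1, K, H, W] tensor.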
    """
    def __init__(self, opt):
        super(map_generate, self).__init__()
        self.map_blocks = nn.ModuleList([map_block(opt) for _ in range(opt.map_order)])
        self.opt = opt

    def forward(self, x):
        device = x.device
        out = torch.zeros(1, self.opt.map_order, self.opt.map_size[0], self.opt.map_size[1]).to(device)
        for i in range(self.opt.map_order):
            out[0, i] = self.map_blocks[i](x)[0, 0]

        return out
--------------------------------------------------------------------------------
/modules/Linear_RE_Net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.init as init


class linear_unit(nn.Module):
    """
    The basic linear layer.
    """
    def __init__(self, ch):
        super(linear_unit, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(1, ch, bias=False),
            nn.ReLU(inplace=True)
        )
        self.fc.apply(self.init_weight)

    def forward(self, x):
        return self.fc(x)

    @staticmethod
    def init_weight(m):
        if type(m) == nn.Linear:
            init.constant_(m.weight, 0.1)  # initialize


class map_generate(nn.Module):
    """
    Reshape the linear layers to form an obstacle map.
    """
    def __init__(self, opt):
        super(map_generate, self).__init__()
        self.opt = opt
        self.units = self._set_units(opt.map_size[0]*opt.map_size[1], opt.map_order)

    def forward(self, x):
        map_seed = x.to(self.opt.device)

        return self._units_calc(self.units, map_seed, self.opt)

    def _set_units(self, num, map_order, unit=linear_unit):
        """
        Create the required number of linear layers.

        Args:
            num: the number of linear layers.
        """
        units = nn.ModuleList([unit(map_order) for i in range(num)])
        return units

    def _units_calc(self, units, x, opt):
        """
        Obtain each linear layer's output.
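
        Each of the map_size[0]*map_size[1] units maps the scalar seed to
        map_order values; the stacked outputs are reshaped into a
        [1, map_order, H, W] obstacle map.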
        """
        out = torch.zeros(opt.map_size[0]*opt.map_size[1], opt.map_order).to(x.device)
        for i in range(len(units)):
            out_i = units[i](x)
            out[i] = out_i

        return out.permute([1, 0])[None,].view(1, opt.map_order, opt.map_size[0], opt.map_size[1])

    @staticmethod
    def units_loss(out, std, loss_func, opt):
        loss = torch.zeros([1]).to(opt.device)
        for i in range(len(out)):
            lossi = loss_func(out[i].to(opt.device), std[i].to(opt.device))
            loss += lossi
        return loss
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zppppppx/deep-learning-based-radioMap/aff528a1281f0432e86064d5ddb2537449e89d06/modules/__init__.py
--------------------------------------------------------------------------------
/modules/algorithmSim.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.weightGen import *
from utils import Map, dist
from modules import Convolutional_RE_Net
from modules import Linear_RE_Net


def barrier_check_pos(Position_map, Obstacle_maps):
    """
    Generates the obstruction indications (see the Obstruction Network).
    "pos" means positive: the obstructed elements are obtained where the
    Obstacle Maps are higher than the Position Map.

    #################################
    # We mainly adopt this solution #
    #################################

    Args:
        Position_map: generated from the location pairs.
        Obstacle_maps: generated from the RE-Net.

    Return:
        out: obstruction relationships turned into a vector, see the Obstruction Net part.
    """
    # Set inputs to the same shape to compare.
    pos = Position_map.repeat(1, Obstacle_maps.size(1), 1, 1)
    obs = Obstacle_maps.repeat(Position_map.size(0), 1, 1, 1)

    # The ReLU keeps only the cells where the obstacle map exceeds the position map.
    diff = obs*pos - pos*pos
    out = F.relu(diff)

    out = out.sum(-1).sum(-1)

    return out

def barrier_check_neg(Position_map, Obstacle_maps):
    """
    Generates the obstruction indications (see the Obstruction Network).
    "neg" means negative: the obstructed elements are obtained where the
    Position Map is lower than the Obstacle Map.

    Args:
        Position_map: generated from the location pairs.
        Obstacle_maps: generated from the RE-Net.

    Return:
        out: obstruction relationships turned into a vector, see the Obstruction Net part.
    """
    pos = Position_map.repeat(1, Obstacle_maps.size(1), 1, 1)
    obs = Obstacle_maps.repeat(Position_map.size(0), 1, 1, 1)

    diff = F.relu(pos-obs)*pos + obs*pos
    diff = torch.abs(diff - pos*pos)
    out = diff.sum(-1).sum(-1)

    return out


######################
# Proposed Framework #
######################

class Algorithm(nn.Module):
    """
    Args:
        x: location pair.
        mapseed: metainfo, the stacked fundamental shapes for the Convolutional RE-Net and 1 for the Linear RE-Net.

    Return:
        out: predicted RSS.
        obs_weight: the propagation condition, namely Sk, generated by the Obstruction Network.
        Obstacle_maps: the virtual obstacle maps.
        param_output: the multichannel gains.
    """
    def __init__(self, opt, scale=50):
        super(Algorithm, self).__init__()
        self.opt = opt
        self.scale = scale
        self.params_model = nn.Linear(2, opt.map_order+1, bias=False)  # Multichannel Gain.
        self.map = Convolutional_RE_Net.map_generate(opt) if opt.RE_Net == 'Conv' else Linear_RE_Net.map_generate(opt)  # RE-Net
        self.weight_gen_ob = weightGen_ob(opt)  # nonlinear model in the Obstruction Network.

        ##################################################
        # Initialize the parameters in Multichannel Gain #
        ##################################################
        # (Alternative initializations for other map orders were tried during the experiments.)
        self.params_model.weight = torch.nn.Parameter(torch.tensor([[-22., -27.], [-28., -24.], [-36., -23.]]), requires_grad=True)

    def forward(self, x, mapseed):
        device = x.device

        # Branch 1: The distance-parameter models.
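        # params_model is a bias-free Linear(2, map_order+1): its input is
        # [log10(d), 1], so each output k equals a_k*log10(d) + b_k, i.e. one
        # log-distance path-loss model per propagation class.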
        log_dist = dist.distance(x.cpu()).to(device)
        param_output = self.params_model(log_dist)

        # Branch 2: Generate the position map and the obstacle maps.
        Position_maps = Map.locs_to_map(x.cpu(), self.opt)/self.scale
        Position_maps = Position_maps.to(device)
        Obstacle_maps = self.map(mapseed)

        Barrier_vec = barrier_check_pos(Position_maps, Obstacle_maps)
        obs_weight = self.weight_gen_ob(Barrier_vec)  # The obstruction degree

        # Calculate the model's output.
        out = param_output*obs_weight
        out = out.sum(-1)

        return out, obs_weight, Obstacle_maps, param_output
--------------------------------------------------------------------------------
/modules/weightGen.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from utils import Config


class weightGen_ob(nn.Module):
    """
    The nonlinear model in the Obstruction Network.
    """
    def __init__(self, opt=Config.Config()):
        super(weightGen_ob, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(opt.map_order, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, opt.map_order+1),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.mlp(x)
        return x

class weightGen_rss(nn.Module):
    """
    The nonlinear model in the RM-Net.
    """
    def __init__(self, opt=Config.Config()):
        super(weightGen_rss, self).__init__()
        self.opt = opt
        self.mlp = nn.Sequential(
            nn.Linear(opt.map_order+1, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, opt.map_order+1),
            nn.Softmax(dim=1)
        )

    def forward(self, rss_calc, rss_std):
        x = torch.abs(rss_calc - rss_std)
        return self.mlp(x)
--------------------------------------------------------------------------------
/performance_evaluation.py:
--------------------------------------------------------------------------------
import os
import torch
from dataset import locPair
from modules import algorithmSim
import utils.Config as Config
from torch.utils.data import DataLoader


opt = Config.Config()
file_path = '../data/radiomap_simulated100tx_3class_noise.mat'
# file_path = '../data/radiomap_simulated100tx.mat'
net_path = opt.RE_Net + '_3_' + str(opt.map_order+1) + '.pth'  # 3 for 3class. Suit yourself.

radio_net = algorithmSim.Algorithm(opt).to(opt.device)
radio_data_val = locPair.RadioMap(file_path, fraction=1, seed=5)


if os.path.exists(net_path):
    state = torch.load(net_path)
    radio_net.load_state_dict(state['net'])
    print(radio_net.params_model.state_dict())

criterion = torch.nn.L1Loss(reduction='sum').to(opt.device)


with torch.no_grad():
    loader_val = DataLoader(radio_data_val, batch_size=opt.batchsize)
    running_loss = 0.
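    # Accumulates |rss_pre - rss| over the whole loader; with reduction='sum',
    # the running total divided by the number of evaluated samples is the MAE.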

    for idx, data in enumerate(loader_val, 0):
        locs, rss = data
        locs, rss = locs.to(opt.device), rss.to(opt.device)

        rss_pre, weight, Obstacle_maps, rss_calc = radio_net(locs, opt.metainfo.to(opt.device))
        loss = criterion(rss_pre, rss)

        running_loss += loss.item()

        if idx % 10 == 9:
            print('idx: %d, Accumulated MAE: %.3f, Batch MAE: %.3f' %
                  (idx, running_loss/((idx+1)*opt.batchsize), loss.item()/opt.batchsize))
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
from modules.Convolutional_RE_Net import map_generate
from utils import Config

opt = Config.Config()
mapseed = opt.shapeUnits.units

Map = map_generate(opt)
out = Map(mapseed)

print(out.shape)
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
import os
import time
import torch
from dataset import locPair
from modules import algorithmSim, weightGen
import utils.Config as Config
from torch.utils.data import DataLoader
import torch.optim as optim
from utils.load_dict import load_dict

##################
# Basic Settings #
##################
opt = Config.Config()
file_path = '../data/radiomap_simulated100tx_3class_noise.mat'
# file_path = '../data/radiomap_simulated100tx.mat'
net_path = opt.RE_Net + '_3_' + str(opt.map_order+1) + '.pth'  # 3 for 3class. Suit yourself.


weightNet_ob_path = './weight_ob_' + str(opt.map_order) + '.pth'
weightNet_rss_path = './weight_rss_' + str(opt.map_order) + '.pth'


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
radio_data_tr = locPair.RadioMap(file_path, choice='synthetic')
batchsize = opt.batchsize

weight_rss = weightGen.weightGen_rss(opt).to(device)
radio_net = algorithmSim.Algorithm(opt).to(device)  # RadioNet

# Load the pre-trained parts.
radio_net = load_dict(radio_net, weightNet_ob_path, 'weight_gen_ob')
weight_rss.load_state_dict(torch.load(weightNet_rss_path))


#################
# Training Loss #
#################
criterion_rss = torch.nn.MSELoss().to(device)
criterion_ob = torch.nn.MSELoss().to(device)


#############
# Optimizer #
#############
# A learning rate of 0 freezes that parameter group, so each optimizer
# updates only one part of the framework.
optimizer_radio_RENet = optim.Adam([{'params': radio_net.params_model.parameters(), 'lr': 0},
                                    {'params': radio_net.map.parameters(), 'lr': opt.lr_map},
                                    {'params': radio_net.weight_gen_ob.parameters(), 'lr': 0}])

optimizer_radio_multiChannel = optim.SGD([{'params': radio_net.params_model.parameters(), 'lr': opt.lr_param},
                                          {'params': radio_net.map.parameters(), 'lr': 0},
                                          {'params': radio_net.weight_gen_ob.parameters(), 'lr': 0}])

optimizer_radio_finetune = optim.Adam([{'params': radio_net.params_model.parameters(), 'lr': 5e-3},
                                       {'params': radio_net.map.parameters(), 'lr': 5e-3},
                                       {'params': radio_net.weight_gen_ob.parameters(), 'lr': 0}])


####################
# Training section #
####################
train_rss = True
train_ob = True
finetune = False
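
# The stages alternate within each outer epoch: train_ob updates the RE-Net
# with the multichannel gains frozen, train_rss then updates the gains with
# the RE-Net frozen, and the optional finetune stage updates both together.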

if os.path.exists(net_path):
    state = torch.load(net_path)
    radio_net.load_state_dict(state['net'])
    print(radio_net.params_model.state_dict())

time_start = time.time()

for epoch in range(opt.epoch):
    ######## Obstacle map
    if train_ob:
        for i in range(opt.epoch_map):
            loader_tr = DataLoader(radio_data_tr, batch_size=batchsize, shuffle=True)

            running_loss = 0.

            for idx, data in enumerate(loader_tr, 0):
                locs, rss = data
                locs = locs.to(device)
                rss = rss.to(device)

                rss_pre, weight, Obstacle_maps, rss_calc = radio_net(locs, opt.metainfo.to(device))
                std = weight_rss(rss_calc, rss.unsqueeze(1))

                loss = 10*criterion_ob(weight, std)  # scaled MSE loss
                loss.backward()
                optimizer_radio_RENet.step()
                optimizer_radio_RENet.zero_grad()
                running_loss += loss.item()

                if idx % 10 == 9:
                    print('RE-Net: outer: %d, inner: %d, idx: %d, loss: %.4f' % (epoch, i, idx, running_loss/10))
                    state = {'net': radio_net.state_dict()}
                    torch.save(state, net_path)

                    running_loss = 0.

            state = {'net': radio_net.state_dict()}
            torch.save(state, net_path)


    ############ RSS
    if train_rss:
        for i in range(opt.epoch_param):
            loader_tr = DataLoader(radio_data_tr, batch_size=batchsize, shuffle=True)

            running_loss = 0.
            for idx, data in enumerate(loader_tr, 0):
                locs, rss = data
                locs = locs.to(device)
                rss = rss.to(device)

                rss_pre, weight, Obstacle_maps, rss_calc = radio_net(locs, opt.metainfo.to(device))

                loss = criterion_rss(rss_pre, rss)
                loss.backward()
                optimizer_radio_multiChannel.step()
                optimizer_radio_multiChannel.zero_grad()
                running_loss += loss.item()

                if idx % 10 == 9:
                    print('Multi-Channel: outer: %d, inner: %d, idx: %d, loss: %.4f' % (epoch, i, idx, running_loss/10))
                    print(radio_net.params_model.state_dict())
                    state = {'net': radio_net.state_dict()}
                    torch.save(state, net_path)

                    running_loss = 0.

            state = {'net': radio_net.state_dict()}
            torch.save(state, net_path)

time_end = time.time()
print('Time consumed: ', time_end-time_start)


if finetune:
    for i in range(opt.epoch_finetune):
        loader_tr = DataLoader(radio_data_tr, batch_size=batchsize, shuffle=True)

        running_loss = 0.
        for idx, data in enumerate(loader_tr, 0):
            locs, rss = data
            locs = locs.to(device)
            rss = rss.to(device)

            rss_pre, weight, Obstacle_maps, rss_calc = radio_net(locs, opt.metainfo.to(device))
            std = weight_rss(rss_calc, rss.unsqueeze(1))

            loss = criterion_rss(rss_pre, rss)
            loss.backward()
            optimizer_radio_finetune.step()
            optimizer_radio_finetune.zero_grad()
            running_loss += loss.item()

            if idx % 2 == 1:
                # running_loss is accumulated over 2 batches before the reset below.
                print('Fine-tune: whole: %d, epoch: %d, idx: %d, loss: %.4f' % (epoch, i, idx, running_loss/2))
                print(radio_net.params_model.state_dict())
                state = {'net': radio_net.state_dict()}
                torch.save(state, net_path)

                running_loss = 0.

        state = {'net': radio_net.state_dict()}
        torch.save(state, net_path)
--------------------------------------------------------------------------------
/utils/Config.py:
--------------------------------------------------------------------------------
import torch
import matplotlib.pyplot as plt


class fundamentalShape:
    def __init__(self) -> None:
        square = torch.tensor([
            [0,0,0,0,0,0,0,0,0,0,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,0,0,0,0,0,0,0,0,0,0]
        ])
        rectangular = torch.tensor([
            [0,0,0,0,0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0,0,0,0,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,0,0,0,0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0,0,0,0,0]
        ])
        triangle_ortho = torch.tensor([
            [0,0,0,0,0,0,0,0,0,0,0],
            [0,1,0,0,0,0,0,0,0,0,0],
            [0,1,1,0,0,0,0,0,0,0,0],
            [0,1,1,1,0,0,0,0,0,0,0],
            [0,1,1,1,1,0,0,0,0,0,0],
            [0,1,1,1,1,1,0,0,0,0,0],
            [0,1,1,1,1,1,1,0,0,0,0],
            [0,1,1,1,1,1,1,1,0,0,0],
            [0,1,1,1,1,1,1,1,1,0,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,0,0,0,0,0,0,0,0,0,0]
        ])
        circle = torch.tensor([
            [0,0,0,0,0,0,0,0,0,0,0],
            [0,0,0,0,1,1,1,0,0,0,0],
            [0,0,1,1,1,1,1,1,1,0,0],
            [0,0,1,1,1,1,1,1,1,0,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,0,1,1,1,1,1,1,1,0,0],
            [0,0,1,1,1,1,1,1,1,0,0],
            [0,0,0,0,1,1,1,0,0,0,0],
            [0,0,0,0,0,0,0,0,0,0,0]
        ])
        ellipse = torch.tensor([
            [0,0,0,0,0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0,0,0,0,0],
            [0,0,0,1,1,1,1,1,0,0,0],
            [0,0,1,1,1,1,1,1,1,0,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,0,1,1,1,1,1,1,1,0,0],
            [0,0,0,1,1,1,1,1,0,0,0],
            [0,0,0,0,0,0,0,0,0,0,0],
            [0,0,0,0,0,0,0,0,0,0,0]
        ])
        square_holo = torch.tensor([
            [0,0,0,0,0,0,0,0,0,0,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,1,0,0,0,0,0,0,0,1,0],
            [0,1,0,0,0,0,0,0,0,1,0],
            [0,1,0,0,0,0,0,0,0,1,0],
            [0,1,0,0,0,0,0,0,0,1,0],
            [0,1,0,0,0,0,0,0,0,1,0],
            [0,1,0,0,0,0,0,0,0,1,0],
            [0,1,0,0,0,0,0,0,0,1,0],
            [0,1,1,1,1,1,1,1,1,1,0],
            [0,0,0,0,0,0,0,0,0,0,0]
        ])

        self.size = square.shape
        self.shapes = [square, square_holo,
                       rectangular, triangle_ortho, circle, ellipse]
        self.num = len(self.shapes)
        self.units = self._unitsCombine(self.shapes)

    def _unitsCombine(self, shapes):
        out = torch.tensor([])
        for shape in shapes:
            shape_i = shape[None,]
            out = torch.cat([out, shape_i], dim=0)

        return out[None]


class Config:
    map_order = 2

    # PAY ATTENTION: we set the resolution to 3.
    resolution = 3.
    map_size = [94, 94]  # [94, 94] for 3class, [101, 101] for 2class, [114, 114] for simulation.

    RE_Net = 'Linear'  # 'Conv' for the Convolutional RE-Net and 'Linear' for the Linear RE-Net.
    data_fraction = {"3class": 0.012, "2class": 0.0098}  # To ensure the amount used is close to 10000.

    lr = 1e-5
    lr_param = 1e-2
    lr_map = 1e-4 if RE_Net == 'Conv' else 1e-2

    batchsize = 128
    epoch = 5
    epoch_param = 1
    epoch_map = 2
    epoch_finetune = 5
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
    shapeUnits = fundamentalShape()
    metainfo = shapeUnits.units if RE_Net == 'Conv' else torch.ones(1, 1, requires_grad=True)


if __name__ == '__main__':
    shape = fundamentalShape()
    print(shape.units.shape)
    plt.imshow(shape.shapes[0])
    plt.show()
--------------------------------------------------------------------------------
/utils/Map.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import torch


def meter_to_index(loc_meter, resolution):
    """
    Transforms a real-world location into an index-form location; the
    resolution indicates how accurate the index is.

    PAY ATTENTION: we don't need to transform the height of the location into
    index form, because we need that information more precisely.

    Args:
        loc_meter: location in meters (other units also work).
        resolution: how accurate the index is; the smaller the resolution,
            the more accurate the indexes.

    Returns:
        loc_index: location in index form.
    """
    loc_index = torch.round(loc_meter/resolution).int()
    return loc_index

def loc_pair_to_map(loc_u, loc_d, resolution, map_size, mode='height'):
    """
    Transforms the location pair of a user and a UAV into a map with the shape
    of map_size. Each value indicates the height of the line between them at
    the specific index.

    Args:
        loc_u: the location of the user in real-world form.
        loc_d: the location of the UAV in real-world form.
        resolution: how accurate the index is; the smaller the resolution,
            the more accurate the indexes.
        map_size: the size of the map we need to generate.
        mode: the output mode, 'height' for a height map and 'indication' for
            only an indication of the connection between them.

    Returns:
        Position_map: a 2D map with the specific size map_size.
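            Where several sampled points of the line fall into the same cell,
            the accumulated heights are averaged (Sum_map counts the hits).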
43 | """ 44 | Position_map = torch.zeros(map_size) # Initlialize the map 45 | if torch.abs(loc_u[:2]-loc_d[:2]).sum()<0.001: 46 | loc_line = meter_to_index(loc_u[:2], resolution) 47 | Position_map[loc_line[0], loc_line[1]] = torch.abs(loc_d[-1]-loc_u[-1]) 48 | 49 | return Position_map 50 | 51 | 52 | 53 | Sum_map = torch.zeros(map_size) # Calculate how many times the height at the same location has been calculated. 54 | delta = loc_u - loc_d # Set the location of the Drone as the standard location. 55 | row_delta, col_delta, height_delta = delta 56 | 57 | row_num, col_num = \ 58 | torch.abs(torch.round(row_delta/resolution)).int(), torch.abs(torch.round(col_delta/resolution)).int() 59 | 60 | for i in range(row_num): 61 | tangent = col_delta/row_delta 62 | row = row_delta/row_num*i + loc_d[0] 63 | col = tangent*row_delta/row_num*i + loc_d[1] 64 | height = loc_d[2] + height_delta/row_num*i 65 | 66 | loc_row_index = meter_to_index(torch.tensor([row, col], dtype=float), resolution) 67 | # print('row, loc_row_index',loc_row_index) 68 | Position_map[loc_row_index[0], loc_row_index[1]] += height if mode=='height' else 255 69 | Sum_map[loc_row_index[0], loc_row_index[1]] += 1 70 | 71 | # Calculate for the col direction 72 | for i in range(col_num): 73 | tangent = row_delta/col_delta 74 | col = col_delta/col_num*i + loc_d[1] 75 | row = tangent*col_delta/col_num*i + loc_d[0] 76 | height = loc_d[2] + height_delta/col_num*i 77 | 78 | loc_row_index = meter_to_index(torch.tensor([row, col], dtype=float), resolution) 79 | # print('col, loc_row_index',loc_row_index) 80 | Position_map[loc_row_index[0], loc_row_index[1]] += height if mode=='height' else 255 81 | Sum_map[loc_row_index[0], loc_row_index[1]] += 1 82 | 83 | 84 | # print(Position_map) 85 | Sum_map[Sum_map == 0] += 1 86 | Fraction = 1/Sum_map 87 | Position_map = Position_map*Fraction 88 | 89 | return Position_map 90 | 91 | 92 | def locs_to_map(locs, opt): 93 | """ 94 | To generate the Position Map tensor. 95 | """ 96 | batch_size = locs.size(0) 97 | Position_maps = torch.zeros([batch_size, 1, opt.map_size[-2], opt.map_size[-1]]) 98 | 99 | for i in range(batch_size): 100 | Position_maps[i, 0, :] = loc_pair_to_map(locs[i, :3].cpu(), locs[i, 3:6].cpu(), opt.resolution, opt.map_size[-2:]) 101 | 102 | return Position_maps 103 | 104 | 105 | 106 | 107 | 108 | if __name__ == '__main__': 109 | size = [101,101] 110 | loc_u = torch.tensor([80, 70, 0]) 111 | loc_d = torch.tensor([20, 10, 10]) 112 | # loc_u = np.zeros([100, 3]) 113 | # loc_d = np.zeros([100, 3]) 114 | print(loc_d.shape) 115 | resolution = 1 116 | 117 | Position_map = loc_pair_to_map(loc_u, loc_d, resolution, size).cpu() 118 | print(Position_map[None, None,].shape) 119 | print(Position_map) 120 | plt.imshow(Position_map) 121 | plt.show() 122 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zppppppx/deep-learning-based-radioMap/aff528a1281f0432e86064d5ddb2537449e89d06/utils/__init__.py -------------------------------------------------------------------------------- /utils/dist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataset import locPair 3 | import matplotlib.pyplot as plt 4 | 5 | def distance(locs): 6 | """ 7 | This function realizes the function of computing the log distances between 8 | the user and the drone. 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zppppppx/deep-learning-based-radioMap/aff528a1281f0432e86064d5ddb2537449e89d06/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dataset import locPair
3 | import matplotlib.pyplot as plt
4 | 
5 | def distance(locs):
6 |     """
7 |     Computes the log distances between the users and the drones for a batch.
8 |     A constant 1 is appended to the tail of each row for calculation convenience.
9 |     """
10 |     locs1 = locs[:, :3]
11 |     locs2 = locs[:, 3:6]
12 |     dist_absolute = (locs1 - locs2).pow(2).sum(-1).pow(0.5)
13 |     log_dist = torch.log10(dist_absolute).unsqueeze(1)
14 |     ones = torch.ones([locs.size(0), 1])
15 |     return torch.cat((log_dist, ones), dim=1)
16 | 
17 | def dist(loc):
18 |     """
19 |     Computes the log distance between a single user and a single drone,
20 |     given one concatenated location vector.
21 |     """
22 |     loc1 = loc[:3]
23 |     loc2 = loc[3:6]
24 |     dist_absolute = (loc1 - loc2).pow(2).sum(-1).pow(0.5)
25 |     log_dist = torch.log10(dist_absolute)
26 | 
27 |     return log_dist
28 | 
29 | def dists(locs):
30 |     """
31 |     Computes the log distances between the users and the drones for a whole
32 |     batch of location pairs, without the appended column of ones.
33 |     """
34 |     locs1 = locs[:, :3]
35 |     locs2 = locs[:, 3:6]
36 |     dist_absolute = (locs1 - locs2).pow(2).sum(-1).pow(0.5)
37 |     log_dist = torch.log10(dist_absolute)
38 | 
39 |     return log_dist
40 | 
41 | 
42 | if __name__ == '__main__':
43 |     file_path = '../data/radiomap_simulated100tx_3class.mat'
44 |     radio_data_tr = locPair.RadioMap(file_path, fraction=0.0005)
45 | 
46 |     length = len(radio_data_tr)
47 |     for i in range(length):
48 |         locs, _, rss = radio_data_tr[i]
49 |         log_dist = dist(locs)
50 | 
--------------------------------------------------------------------------------
/utils/line_calc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | def line_calc(dot1, dot2):  # Returns the slope k and intercept b of the line through dot1 and dot2.
4 |     diff = dot2-dot1
5 |     k = diff[1]/diff[0]
6 |     b = dot1[1]-k*dot1[0]
7 | 
8 |     return k,b
9 | 
10 | if __name__ == '__main__':
11 |     dot1 = np.array([2.316,-136.3])
12 |     dot2 = np.array([1.82,-79.6])
13 | 
14 |     k, b = line_calc(dot1, dot2)
15 |     print('The line is y = %.3f x + %.3f' % (k,b))
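The trailing column of ones returned by distance() lets a fitted path-loss line y = k*log10(d) + b be applied as a single matrix product; a sketch using the slope and intercept that line_calc.py's __main__ produces (k ≈ -114.31, b ≈ 128.45):

    import torch
    from utils.dist import distance

    locs = torch.tensor([[80., 70., 0., 20., 10., 10.]])  # one user/UAV pair
    feats = distance(locs)                  # shape [1, 2]: [log10(d), 1]

    k, b = -114.31, 128.45                  # from the line_calc.py example above
    rss = feats @ torch.tensor([[k], [b]])  # k*log10(d) + b, shape [1, 1]
    print(rss)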
--------------------------------------------------------------------------------
/utils/load_dict.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def load_dict(mainnet, subnetPath, preffix):
4 |     """
5 |     Loads a subnet's state dict into the main net.
6 | 
7 |     Args:
8 |         mainnet: the net that the subnet should be loaded into.
9 |         subnetPath: the path of the subnet's pth file.
10 |         preffix: the subnet's attribute name (key prefix) inside the mainnet.
11 |     """
12 |     preffix = preffix + '.'
13 |     mainnet_stateDict = mainnet.state_dict()
14 |     subnet_stateDict = torch.load(subnetPath)
15 |     stateDict_forupdate = {preffix+k:v for k, v in subnet_stateDict.items() if preffix+k in mainnet_stateDict.keys()}
16 |     # print(stateDict_forupdate)
17 |     # UNet_dict = {k[start:]:v for k, v in model_dict.items() if k[start:] in map_dict.keys()}
18 |     mainnet_stateDict.update(stateDict_forupdate)
19 |     mainnet.load_state_dict(mainnet_stateDict)
20 | 
21 |     return mainnet
22 | 
--------------------------------------------------------------------------------
/weight_ob_2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zppppppx/deep-learning-based-radioMap/aff528a1281f0432e86064d5ddb2537449e89d06/weight_ob_2.pth
--------------------------------------------------------------------------------
/weight_pretrain.py:
--------------------------------------------------------------------------------
1 | 
2 | from modules.weightGen import *
3 | from modules.algorithmSim import barrier_check_pos
4 | import utils.Config as Config
5 | import torch.optim as optim
6 | import torch
7 | import torch.nn as nn
8 | import matplotlib.pyplot as plt
9 | import os
10 | import torch.nn.functional as F
11 | import numpy as np
12 | 
13 | 
14 | def std_out_ob(barrier_vec, opt):
15 |     """
16 |     Generates the standard obstacle-map labels used to train the nonlinear model in the Obstruction Network.
17 | 
18 |     Args:
19 |         barrier_vec: the obstruction indications flattened from a tensor into a vector (see the Obstruction Net part).
20 |         opt: Config.Config(). Configuration.
21 | 
22 |     Returns:
23 |         out_std: artificial labels.
24 |     """
25 |     batch_size = barrier_vec.size(0)
26 | 
27 |     out_std = torch.zeros([batch_size, opt.map_order+1])
28 |     for i in range(batch_size):
29 |         for j in range(opt.map_order):
30 |             check_line = F.relu(barrier_vec[i, opt.map_order-j-1])
31 |             if check_line > 0:
32 |                 out_std[i, opt.map_order-j] += 1
33 |                 break
34 | 
35 |         if out_std[i].sum(-1) == 0:
36 |             out_std[i][0] = 1
37 | 
38 |     return out_std
39 | 
40 | def std_out_rss(rss_calc, rss_std, opt):
41 |     """
42 |     Generates the standard rss labels used to train the nonlinear model in the RM-Net.
43 | 
44 |     Args:
45 |         rss_calc: calculated rss.
46 |         rss_std: standard rss.
47 | 
48 |     Returns:
49 |         out_std: the standard category, one-hot over the calculated rss entry closest to the standard rss.
50 |     """
51 |     diff = torch.abs(rss_calc - rss_std)
52 |     args = torch.argmin(diff, dim=1)
53 |     out_std = torch.zeros(rss_std.size(0), opt.map_order+1)
54 |     for i in range(rss_std.size(0)):
55 |         out_std[i, args[i]] += 1
56 | 
57 |     return out_std
58 | 
59 | 
60 | 
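# Worked illustration of std_out_ob for map_order == 2 (example values only):
#   barrier_vec row [0.0, 0.0] -> label [1, 0, 0]  (no entry is nonzero)
#   barrier_vec row [0.3, 0.0] -> label [0, 1, 0]  (entry 0 is the highest-indexed nonzero entry)
#   barrier_vec row [0.3, 0.7] -> label [0, 0, 1]  (entry 1 is the highest-indexed nonzero entry)
# i.e. the one-hot class is (index of the highest-indexed nonzero entry) + 1,
# or class 0 if all entries are zero.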
61 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
62 | opt = Config.Config()
63 | scale = 1
64 | 
65 | feature_net = weightGen_ob(opt).to(device)
66 | net_path_ob = './weight_ob_' + str(opt.map_order) + '.pth'
67 | 
68 | 
69 | if os.path.exists(net_path_ob):
70 |     state = torch.load(net_path_ob)
71 |     feature_net.load_state_dict(state)
72 | 
73 | 
74 | optimizer = optim.Adam([p for p in feature_net.parameters()], lr=1e-2)
75 | criterion = nn.MSELoss().to(device)
76 | 
77 | # Training for the obstacle net
78 | running_loss = 0.
79 | for i in range(30):
80 |     barriers_vec = torch.tensor([])
81 | 
82 |     for j in range(opt.map_order):
83 |         bar = torch.rand([opt.batchsize, opt.map_order])
84 |         bar[:, j:] = 0
85 |         barriers_vec = torch.cat((barriers_vec, bar), dim=0)
86 |         # print(i, '\t', bar)
87 |     bar = torch.rand([opt.batchsize, opt.map_order])
88 |     barriers_vec = torch.cat((barriers_vec, bar), dim=0)
89 |     # print(barriers_vec)
90 |     rand_index = np.random.randint(0,opt.batchsize*(opt.map_order+1), opt.batchsize*(opt.map_order+1))  # Remix the concatenated batches (indices drawn with replacement).
91 |     # print(rand_index)
92 |     barriers_vec = barriers_vec[rand_index]
93 |     barriers_vec = barriers_vec.to(device)
94 | 
95 | 
96 |     # print(barriers.shape)
97 | 
98 |     # plt.subplot(2,2,1)
99 |     # plt.imshow(Position_map[0,0])
100 |     # plt.subplot(2,2,2)
101 |     # plt.imshow(Obstacle_map[0,0])
102 |     # plt.subplot(2,2,3)
103 |     # plt.imshow(barriers[0,0].cpu().detach())
104 |     # plt.show()
105 | 
106 |     out_std = std_out_ob(barriers_vec, opt).to(device)
107 |     # print(out_std)
108 | 
109 |     optimizer.zero_grad()
110 |     out = feature_net(barriers_vec)
111 |     loss = criterion(out, out_std)
112 |     loss.backward()
113 | 
114 |     optimizer.step()
115 |     running_loss += loss.item()
116 | 
117 |     print('epoch: %d, loss: %.4f' % (i, running_loss))
118 |     print(out[0])
119 | 
120 |     running_loss = 0
121 | 
122 |     state = feature_net.state_dict()
123 |     torch.save(state, net_path_ob)
124 | 
125 | state = feature_net.state_dict()
126 | torch.save(state, net_path_ob)
127 | 
128 | # Training for the rss net
129 | feature_net = weightGen_rss(opt).to(device)
130 | net_path_rss = './weight_rss_' + str(opt.map_order) + '.pth'
131 | 
132 | if os.path.exists(net_path_rss):
133 |     state = torch.load(net_path_rss)
134 |     feature_net.load_state_dict(state)
135 | 
136 | 
137 | optimizer = optim.Adam([p for p in feature_net.parameters()], lr=1e-2)
138 | criterion = nn.MSELoss().to(device)
139 | 
140 | 
141 | running_loss = 0.
142 | for i in range(200):
143 |     rss_calc = -100*torch.rand(opt.batchsize, opt.map_order+1).to(device)
144 |     rss_std = -100*torch.rand(opt.batchsize, 1).to(device)
145 |     out_std = std_out_rss(rss_calc, rss_std, opt).to(device)
146 |     # print(out_std)
147 | 
148 |     optimizer.zero_grad()
149 |     out = feature_net(rss_calc, rss_std)
150 |     loss = criterion(out, out_std)
151 |     loss.backward()
152 | 
153 |     optimizer.step()
154 |     running_loss += loss.item()
155 | 
156 |     # print('epoch: %d, loss: %.4f' % (i, running_loss))
157 |     # print(out[0])
158 | 
159 |     running_loss = 0
160 | 
161 |     state = feature_net.state_dict()
162 |     torch.save(state, net_path_rss)
163 | 
164 | state = feature_net.state_dict()
165 | torch.save(state, net_path_rss)
--------------------------------------------------------------------------------
/weight_rss_2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zppppppx/deep-learning-based-radioMap/aff528a1281f0432e86064d5ddb2537449e89d06/weight_rss_2.pth
--------------------------------------------------------------------------------
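A minimal sketch of loading these pretrained weight generators back into a larger model with utils.load_dict; the prefixes 'weight_ob' and 'weight_rss' are assumptions and must match the attribute names the subnets have inside the main network:

    from utils.load_dict import load_dict

    radio_net = ...  # the main network, constructed as in training.py
    radio_net = load_dict(radio_net, './weight_ob_2.pth', 'weight_ob')    # hypothetical prefix
    radio_net = load_dict(radio_net, './weight_rss_2.pth', 'weight_rss')  # hypothetical prefix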