├── .DS_Store ├── Geometry.pyc ├── LICENSE ├── README.md ├── accesspoint.pyc ├── bash.sh ├── center_server.pyc ├── global_parameters.py ├── myplotlib.py ├── rainbow_hac ├── .DS_Store ├── ap_agent.py ├── basic_block.py ├── basic_block_center.py ├── basic_block_center_mix.py ├── center_agent.py ├── game.py ├── memory.py ├── test.py └── train.py ├── requirements.txt ├── tsne.ipynb └── user_correlation.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperflight/Hierarchical-Multi-agent-DRL-with-Federated-Learning/a34213418ed4542a1e2a2641bb9ea26a91ce051b/.DS_Store -------------------------------------------------------------------------------- /Geometry.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperflight/Hierarchical-Multi-agent-DRL-with-Federated-Learning/a34213418ed4542a1e2a2641bb9ea26a91ce051b/Geometry.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | 294 | Copyright (C) 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 311 | 312 | If the program is interactive, make it output a short notice like this 313 | when it starts in an interactive mode: 314 | 315 | Gnomovision version 69, Copyright (C) year name of author 316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 317 | This is free software, and you are welcome to redistribute it 318 | under certain conditions; type `show c' for details. 319 | 320 | The hypothetical commands `show w' and `show c' should show the appropriate 321 | parts of the General Public License. Of course, the commands you use may 322 | be called something other than `show w' and `show c'; they could even be 323 | mouse-clicks or menu items--whatever suits your program. 324 | 325 | You should also get your employer (if you work as a programmer) or your 326 | school, if any, to sign a "copyright disclaimer" for the program, if 327 | necessary. Here is a sample; alter the names: 328 | 329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 330 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 331 | 332 | , 1 April 1989 333 | Ty Coon, President of Vice 334 | 335 | This General Public License does not permit incorporating your program into 336 | proprietary programs. If your program is a subroutine library, you may 337 | consider it more useful to permit linking proprietary applications with the 338 | library. If this is what you want to do, use the GNU Lesser General 339 | Public License instead of this License. 340 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Paper 2 | 3 | This is the code for paper "Correlation-aware Cooperative Multigroup Broadcast 360° Video Delivery Network: A Hierarchical Deep Reinforcement Learning Approach" 4 | For any usage, please cite this paper. 5 | 6 | ## Usage 7 | Directly run bash.sh 8 | 9 | ## t-SNE 10 | The Fig. 9 in this paper is plotted via the code in tsne.ipy file. 11 | 12 | ## Res_net 13 | https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 14 | 15 | ## A3C 16 | https://github.com/sweetice/Deep-reinforcement-learning-with-pytorch/blob/master/Char03%20Actor-Critic/AC_CartPole-v0.py \ 17 | https://github.com/rlcode/reinforcement-learning/blob/master/3-atari/1-breakout/breakout_a3c.py 18 | 19 | ## HAC 20 | https://github.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/blob/master/HAC.py \ 21 | https://github.com/skumar9876/Hierarchical-DQN/blob/master/hierarchical_dqn.py \ 22 | https://github.com/ifiaposto/Hierarchical-Deep-Reinforcement-Learning/blob/master/hdqn.py 23 | 24 | ## Rainbow 25 | https://github.com/Kaixhin/Rainbow 26 | -------------------------------------------------------------------------------- /accesspoint.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperflight/Hierarchical-Multi-agent-DRL-with-Federated-Learning/a34213418ed4542a1e2a2641bb9ea26a91ce051b/accesspoint.pyc -------------------------------------------------------------------------------- /bash.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | #SBATCH --output=/mnt/lustre/users/%u/%j.out 3 | #SBATCH --job-name=alphavr 4 | # #SBATCH --gres=gpu 5 | #SBATCH --ntasks=10 6 | #SBATCH --mem=20000 7 | #SBATCH --time=6-12:00 8 | # #SBATCH --constrain=v100 9 | #SBATCH --constrain=skylake 10 | 11 | ulimit -n 4096 12 | git status 13 | cat global_parameters.py 14 | 15 | # module load libs/cuda 16 | python ./rainbow_hac/train.py --id='hac_fed_large_15_variance' --active-scheduler --active-accesspoint --previous-action-observable --history-length-accesspoint=2 --history-length-scheduler=1 --architecture='canonical_4uav_61obv_3x3_mix' --action-selection='greedy' --data-reinforce --evaluation-interval=500 --evaluation-episodes=20000 --federated-round='20' -------------------------------------------------------------------------------- /center_server.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperflight/Hierarchical-Multi-agent-DRL-with-Federated-Learning/a34213418ed4542a1e2a2641bb9ea26a91ce051b/center_server.pyc -------------------------------------------------------------------------------- /global_parameters.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | LENGTH_OF_FIELD = 80 4 | WIDTH_OF_FIELD = 80 # must be diviable by square step 5 | ACCESS_POINTS_FIELD = 61 # must be odd 6 | REWARD_CAL_RANGE = 1 # reward calculation range for each accesspoint (range = RCR*ACCESS_FIELD) 7 | DENSE_OF_ACCESSPOINT = 30 8 | ACCESS_POINT_PER_EDGE = int(1 + LENGTH_OF_FIELD // DENSE_OF_ACCESSPOINT) 9 | 10 | DENSE_OF_USERS = 120 11 | VARIANCE_OF_USERS = 0 12 | 13 | MAX_USERS_MOBILITY = 1 14 | 15 | NUM_OF_UAV = 4 16 | MAX_UAV_MOBILITY = 2 17 | 18 | # PCP parameters 19 | NUM_OF_CLUSTER = 1 20 | DENSE_OF_USERS_PCP = int(DENSE_OF_USERS / NUM_OF_UAV / NUM_OF_CLUSTER) 21 | UE_SCALE = 20 22 | VARIANCE_OF_SCALE = 0 23 | CLUSTER_SCALE = 1 24 | 25 | 26 | def REFRESH_SCALE(ue_scale): 27 | global DENSE_OF_USERS, VARIANCE_OF_SCALE, VARIANCE_OF_USERS, DENSE_OF_USERS_PCP, UE_SCALE 28 | DENSE_OF_USERS = int((1 + ((np.random.rand() - 0.5) * VARIANCE_OF_USERS)) * DENSE_OF_USERS) 29 | DENSE_OF_USERS_PCP = int(DENSE_OF_USERS / NUM_OF_UAV / NUM_OF_CLUSTER) 30 | UE_SCALE = np.ceil((1 + ((np.random.rand() - 0.5) * VARIANCE_OF_SCALE)) * ue_scale) 31 | 32 | 33 | FRAME_RATE = 90 34 | TILE_SIZE: float = 30 * 30 * 8 * 60 * 3 / 150 # 8640 35 | GOP: int = 5 36 | GOP_TILE_SIZE: list = [TILE_SIZE, 0.7 * TILE_SIZE, 0.7 * TILE_SIZE, 37 | 0.7 * TILE_SIZE, 0.7 * TILE_SIZE, 0.7 * TILE_SIZE, 0.7 * TILE_SIZE] 38 | GOP_INDEX: int = 0 39 | GOP_SIZE_CONSTANT = False 40 | 41 | UAV_FIELD_OF_VIEW = [6, 12] 42 | TOTAL_NUM_TILES = UAV_FIELD_OF_VIEW[0] * UAV_FIELD_OF_VIEW[1] 43 | USER_FIELD_OF_VIEW = [5, 7] # 150 verti * 210 hori 44 | 45 | LINK_THRESHOLD = 1e-7 46 | CORRELATION_THRESHOLD = 0.00125 47 | 48 | UAV_TRANSMISSION_CENTER_FREUENCY = 5e9 49 | UAV_INTERFERENCE = False 50 | AP_TRANSMISSION_CENTER_FREUENCY = 5e9 51 | 52 | SPEED_OF_LIGHT = 3e8 53 | DRONE_HEIGHT = 40 54 | EXCESSIVE_NLOS_ATTENUATION = pow(10, 20 / 10) 55 | 56 | ACCESS_POINT_TRANSMISSION_EIRP = pow(10, 78 / 10) # 78 dBm 57 | ACCESS_POINT_TRANSMISSION_BANDWIDTH = 50e6 # Hz 58 | UAV_TRANSMISSION_EIRP = pow(10, 78 / 10) # 78 dBm 59 | UAV_TRANSMISSION_BANDWIDTH = 50e6 # Hz 60 | NOISE_THETA = pow(10, -91 / 10) # -91 hertz 61 | UAV_AP_ALPHA = -2 62 | AP_UE_ALPHA = -4 63 | NAKAGAMI_M = 2 64 | RAYLEIGH = 2 65 | # https://arxiv.org/pdf/1704.02540.pdf 66 | 67 | DEFAULT_RESOURCE_BLOCKNUM = 28 # NUM of blocks for association change 68 | DEFAULT_RESOURCE_ALLOCATION = [1 / (FRAME_RATE * DEFAULT_RESOURCE_BLOCKNUM) 69 | for ind in range(0, DEFAULT_RESOURCE_BLOCKNUM * GOP)] 70 | DEFAULT_NUM_OF_RB_PER_RES = 10 # num of transmission tile in each decision slot 71 | DEFAULT_NUM_OF_RB = 1 # num of transmission tile in each transmit slot 72 | # transmit 5 slot, each with 10 tiles 73 | # 60Hz --- 0.016666666667s 74 | # into 20 slots 75 | 76 | CLUSTERING_METHOD = "PrivotingBK_greedy" # PrivotingBK_greedy/PrivotingBK 77 | 78 | LOG_LEVEL = 0 # 0: nothing, 1: text_only, 2: rich text, 3: even detail+figure, 4: save figure 79 | AP_COLOR = 'red' 80 | UE_COLOR = 'green' 81 | CS_COLOR = 'blue' 82 | 83 | PLOT_FADING_RANGE_LOG = [-150., 0.] 84 | IMAGE_PER_ROW = 4 85 | IMAGE_SIZE = (5, 4) 86 | 87 | 88 | 89 | # neural network parameters 90 | LR = 1e-3 91 | NUM_RES_BLOCKS = 8 92 | NUM_CHANNELS = 128 93 | DROP_OUT = 0.3 94 | EPOCHS = 10 95 | BATCH_SIZE = 64 96 | 97 | # mcts parameters 98 | MCTS_NUM = 25 99 | C_PUCT = 1 100 | 101 | # training parameters 102 | # observation square step 103 | SQUARE_STEP = 2 104 | # larger than this number. perform pure greedy, 105 | # should smaller than DEFAULT_RESOURCE_BLOCKNUM 106 | # is the v_resign in the paper 107 | ITERATION_NUM = 1000 108 | HISTORY_MAX_LEN = 200000 # learning iteration numbers 109 | EPS_NUM = 60 # episode in each iter 110 | EPS_GREEDY_NUM = 55 # default policy starts from step xx, this should be smaller than maximum steps 111 | TOTAL_HISTORY_LEN_IN_ITERATION = 50 # maximum learning history holded by deque 112 | REPLAY_FILE_PATH = "./replay/" # maximum length of total history in the number of iteration 113 | UPDATE_THRESHOLD = 0.55 # Arena update threshold 114 | ARENA_MATCH_NUM = 40 # Arena match num 115 | LOAD_HISTORY_EXAMPLES_PATH = ('./temp', 'best.pth.tar') # model and history load path 116 | USER_CLUSTER_INDICATOR_LENGTH = 8 117 | # indicate the number of clusters inside the figure 0 - 1 - step - step^2... 118 | # if step=2 : 0, 1, 2, 4, 8 119 | USER_CLUSTER_INDICATOR_STEP = 2 # scale the length indicator to reduce the states num 120 | # TODO: Observation version 1-3: 3, 4: 4, 5: 5, 6: 5, 7: 5, 8: 4 121 | OBSERVATION_DIMS = 3 # each cluster has three observations: uav position, user position, num of cluster 122 | REWARD_STAGE = [10, 15, 20] # reward stage, correspoinding to -1, 0, 1, 1.5 123 | BIAS_SIGMA = 0.25 # bias sigma for ensure all moves may be tried 124 | DIRICHLET = 0.03 # Bias parameter for dirichlet 125 | # TODO: when selecting observation 4 and 5, change the observation dims too 126 | OBSERVATION_VERSION = 7 # 1: observation v1, 2 observation v2 127 | # for details look into game.get_observation_vx() function 128 | NULL_UNAVALIABLE_UAV = False 129 | 130 | 131 | # training and loading parameters 132 | ENABLE_MODEL_RELOAD = False 133 | ENABLE_MEMORY_RELOAD = False 134 | ENABLE_EARLY_STOP = False 135 | ENABLE_EARLY_STOP_THRESHOLD = 0.5 136 | LOAD_MODE = False 137 | PARALLEL_EXICUSION = True 138 | ALLOCATED_CORES = 4 139 | 140 | # What should notice when running a new job: 141 | # OBSERVATION_VERSION : decide your observation type 142 | # ALLOCATED_CORES : num of parallel cores based on your computer build 143 | # LOAD_MODE: load previous model and playback in /temp/ or not 144 | # EPS_GREEDY_NUM : MCTS search steps 145 | # EPS_NUM : episodes for getting playback with new model 146 | # ARENA_MATCH_NUM : the number of arena match for comparing two models 147 | # TOTAL_HISTORY_LEN_IN_ITERATION : maximum stored history 148 | -------------------------------------------------------------------------------- /myplotlib.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import matplotlib.colors as mcolors 4 | import matplotlib.gridspec as gridspec 5 | import numpy as np 6 | from termcolor import colored 7 | from tabulate import tabulate 8 | import global_parameters as gp 9 | 10 | 11 | class MyFig: 12 | def __init__(self, num_of_subfig: list, figsize): 13 | self.fig = plt.figure(figsize=figsize, tight_layout=True) 14 | self.ax_list = [] 15 | for index_i in range(1, int(num_of_subfig[0] * num_of_subfig[1] + 1)): 16 | self.ax_list.append(self.fig.add_subplot(num_of_subfig[0], num_of_subfig[1], index_i)) 17 | self.num_of_subfig = num_of_subfig 18 | self.im_list = [None for _ in range(0, int(num_of_subfig[0] * num_of_subfig[1]))] 19 | self.data_list = [np.array([]) for _ in range(0, int(num_of_subfig[0] * num_of_subfig[1]))] 20 | self.index_h = 0 21 | self.index_v = 0 22 | self.index = 0 23 | 24 | def close(self): 25 | plt.close(self.fig) 26 | 27 | def next_figure(self): 28 | if self.index_v >= self.num_of_subfig[1] and self.index_h >= self.num_of_subfig[0]: 29 | raise IndexError("No next figure") 30 | if self.index_h == self.num_of_subfig[0] - 1: 31 | self.index_h = 0 32 | self.index_v += 1 33 | else: 34 | self.index_h += 1 35 | self.index += 1 36 | 37 | def reset_index(self): 38 | self.index_h = 0 39 | self.index_v = 0 40 | self.index = 0 41 | 42 | def plot_grid(self, data: np.ndarray, max_min: (int, int, int, int), step_size: int, 43 | range_value: list, title: str): 44 | # max_min: vmax, vmin, hmax, hmin 45 | vmax, vmin, hmax, hmin = max_min 46 | self.data_list[self.index] = data 47 | data[data == 0] = gp.PLOT_FADING_RANGE_LOG[1] 48 | norm = mcolors.Normalize(vmin=range_value[0], vmax=range_value[1]) 49 | # see note above: this makes all pcolormesh calls consistent: 50 | pc_kwargs = {'cmap': 'pink', 'norm': norm, 'edgecolors': 'k', 'linewidths': 2} 51 | self.ax_list[self.index].set_title(title) 52 | self.ax_list[self.index].grid(color='k', linestyle='-', linewidth=1) 53 | self.im_list[self.index] = self.ax_list[self.index].pcolor(self.data_list[self.index], 54 | vmin=gp.PLOT_FADING_RANGE_LOG[0], 55 | vmax=gp.PLOT_FADING_RANGE_LOG[1], 56 | **pc_kwargs) 57 | plt.setp(self.ax_list[self.index], xticks=np.arange(0, hmax - hmin + 2), 58 | xticklabels=np.around(np.arange(hmin * step_size, (hmax + 2) * step_size, 59 | step=step_size), decimals=1), 60 | yticks=np.arange(0, vmax - vmin + 2), 61 | yticklabels=np.around(np.arange(vmin * step_size, (vmax + 2) * step_size + step_size, 62 | step=step_size), decimals=1)) 63 | # +2 because 2-6 has 6 - 2 + 1 = 5 elements, range(2, 7, 1) count 6 (+1) 64 | # some element has position larger than 6 * step, so we need 7 also (+1) 65 | self.fig.colorbar(self.im_list[self.index], ax=self.ax_list[self.index], shrink=1, extend='min') 66 | 67 | def get_color(self, val: float, cmap): 68 | # Return the data color of an index. 69 | if int(val) > 0: 70 | raise ValueError("Fading can't larger than 0") 71 | return cmap(1 - abs(val / np.min(self.data_list[self.index]))) 72 | 73 | def draw_text_label(self, data, position: (int, int), idex: int): 74 | facecolor = self.get_color(data, self.im_list[self.index].get_cmap()) 75 | self.ax_list[self.index].text(position[0], position[1], "<{id}>".format(id=idex), 76 | color='white', ha='center', va='center', 77 | bbox={'boxstyle': 'square', 'facecolor': facecolor}) 78 | 79 | def draw_text_block(self, rotation: int, position: (float, float), hori_axis: str, verti_axis: str, content: str): 80 | # hori_axis/verti_axis = 'left' 'right' 'center' 81 | bbox_kwargs = {'fc': 'w', 'alpha': .75, 'boxstyle': "round4"} 82 | ann_kwargs = {'xycoords': 'axes fraction', 'textcoords': 'offset points', 'bbox': bbox_kwargs} 83 | self.ax_list[self.index].annotate(content, xy=position, xytext=(0, 0), 84 | ha=hori_axis, va=verti_axis, rotation=rotation, **ann_kwargs) 85 | 86 | def save_figure(self, time, apid): 87 | title_str = "clustering_result at time {ti} for access point {apid}".format(ti=time, apid=apid) 88 | self.fig.suptitle(title_str, fontsize=16) 89 | plt.savefig("./fig/cluster_result/" + title_str + ".eps") 90 | 91 | 92 | def plot_observation(observation: np.ndarray, group_size: int, num_of_ap: int, title: str, save=True): 93 | fig = plt.figure(figsize=(gp.IMAGE_SIZE[0] * observation.shape[2], 94 | gp.IMAGE_SIZE[1] * 1), tight_layout=True) 95 | ax_list = [] 96 | for index_i in range(observation.shape[2]): 97 | ax_list.append(fig.add_subplot(1, observation.shape[2], index_i + 1)) 98 | im_list = [None for _ in range(observation.shape[2])] 99 | data_list = [observation[:, :, _] for _ in range(observation.shape[2])] 100 | 101 | for index, data in enumerate(data_list): 102 | pc_kwargs = {'cmap': 'binary', 'edgecolors': 'k', 'linewidths': 2} 103 | ax_list[index].set_title( 104 | "Observation {a} for group {b}".format(a=index % group_size, b=index // group_size + 1)) 105 | ax_list[index].grid(color='k', linestyle='-', linewidth=1) 106 | im_list[index] = ax_list[index].pcolor(data_list[index], vmin=0, vmax=1, **pc_kwargs) 107 | plt.setp(ax_list[index]) 108 | 109 | bbox_kwargs = {'fc': 'w', 'alpha': .75, 'boxstyle': "round4"} 110 | ann_kwargs = {'xycoords': 'axes fraction', 'textcoords': 'offset points', 'bbox': bbox_kwargs} 111 | ap_len = np.floor(np.sqrt(num_of_ap)).astype(int) 112 | min_space = (1 / ap_len) / 2 113 | space = 1 / ap_len 114 | for ind in range(num_of_ap): 115 | ax_list[index].annotate(str(ind), xy=(min_space + ind % 2 * space, min_space + ind // 2 * space), xytext=(0, 0), 116 | ha='center', va='center', rotation=0, **ann_kwargs) 117 | 118 | if save: 119 | fig.suptitle(title, fontsize=16) 120 | plt.savefig("./fig/decision/" + title + ".eps") 121 | 122 | 123 | def table_print_color(table: np.ndarray, title: str, color): 124 | indi_r = np.indices([table.shape[0]]) 125 | print(colored(title, color)) 126 | if table.shape.__len__() == 1: 127 | if table.dtype == np.complex: 128 | temp_table = np.zeros((table.shape[0], 1, 2)) 129 | temp_table[:, :, 0] = [np.real(table)] 130 | temp_table[:, :, 1] = [np.imag(table)] 131 | print(colored(tabulate([temp_table], headers=[str(k) for k in indi_r[0]], tablefmt="grid"), color)) 132 | else: 133 | print(colored(tabulate([table], headers=[str(k) for k in indi_r[0]], tablefmt="grid"), color)) 134 | else: 135 | indi_c = np.indices([table.shape[1]]) 136 | if table.dtype == np.complex: 137 | temp_table = np.zeros((table.shape[0], table.shape[1], 2)) 138 | temp_table[:, :, 0] = np.real(table) 139 | temp_table[:, :, 1] = np.imag(table) 140 | print(colored(tabulate(np.insert(temp_table, 0, np.expand_dims(indi_r, axis=2), axis=1), 141 | headers=['ID'] + [str(k) for k in indi_c[0]], tablefmt="grid"), color)) 142 | else: 143 | print(colored(tabulate(np.insert(table, 0, indi_r, axis=1), 144 | headers=['ID'] + [str(k) for k in indi_c[0]], tablefmt="grid"), color)) 145 | 146 | 147 | def print_all_nodes_information(target): 148 | if gp.LOG_LEVEL >= 2: 149 | result_dic = [] 150 | for ues in target.users_list: 151 | result_dic.append([ues.id, ues.position, ues.sphere_id, ues.resource, ues.transmission_mask]) 152 | print(colored("UE LIST", gp.UE_COLOR)) 153 | print(colored(tabulate(result_dic, headers=['ID', 'Position', 'Sphere ID', 'Resource', 'Transmission Mask'], 154 | tablefmt="grid"), gp.UE_COLOR)) 155 | 156 | result_dic = [] 157 | for aps in target.accesspoint_list: 158 | result_dic.append([aps.id, aps.position]) 159 | print(colored("AP LIST", gp.AP_COLOR)) 160 | print(colored(tabulate(result_dic, headers=['ID', 'Position'], tablefmt="grid"), gp.AP_COLOR)) 161 | 162 | result_dic = [] 163 | for uavs in target.uav_list: 164 | result_dic.append([uavs.id, uavs.position, uavs.resource]) 165 | print(colored("UAV LIST", gp.CS_COLOR)) 166 | print(colored(tabulate(result_dic, headers=['ID', 'Position', 'Resources'], tablefmt="grid"), gp.CS_COLOR)) 167 | 168 | 169 | def print_users_information(target, transmitted_list): 170 | if gp.LOG_LEVEL >= 2: 171 | result_dic = [] 172 | for ues in target.users_list: 173 | if ues.sphere_id in target.stack_of_popular: 174 | # ue's cluster may not be selected 175 | if transmitted_list[ues.sphere_id].__len__() != 0: 176 | result_dic.append([ues.id, ues.position, ues.sphere_id, transmitted_list[ues.sphere_id], 177 | [res for key, res in enumerate(ues.resource) if ues.transmission_mask[key]]]) 178 | else: 179 | result_dic.append([ues.id, ues.position, ues.sphere_id, -1, 180 | [res for key, res in enumerate(ues.resource) if ues.transmission_mask[key]]]) 181 | print(colored("UE LIST", gp.AP_COLOR)) 182 | print(colored(tabulate(result_dic, headers=['ID', 'Position', 'Sphere ID', 'Current Transmitted', 183 | 'Remained Resources'], 184 | tablefmt="grid"), gp.AP_COLOR)) 185 | -------------------------------------------------------------------------------- /rainbow_hac/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperflight/Hierarchical-Multi-agent-DRL-with-Federated-Learning/a34213418ed4542a1e2a2641bb9ea26a91ce051b/rainbow_hac/.DS_Store -------------------------------------------------------------------------------- /rainbow_hac/ap_agent.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import os 4 | import numpy as np 5 | import torch 6 | from torch import optim 7 | from torch.nn.utils import clip_grad_norm_ 8 | 9 | from rainbow_hac.basic_block import DQN 10 | 11 | 12 | class Agent: 13 | def __init__(self, args, env, index): 14 | self.active = args.active_accesspoint 15 | if not self.active: 16 | return 17 | self.action_space = env.get_action_size() 18 | self.action_type = args.action_selection 19 | self.atoms = args.atoms 20 | self.Vmin = args.V_min 21 | self.Vmax = args.V_max 22 | self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(device=args.device) # Support (range) of z 23 | self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1) 24 | self.batch_size = args.batch_size 25 | self.n = args.multi_step_accesspoint 26 | self.discount = args.discount 27 | self.device = args.device 28 | 29 | self.online_net = DQN(args, self.action_space).to(device=args.device) 30 | if args.model: # Load pretrained model if provided 31 | self.model_path = os.path.join(args.model, "model" + str(index) + ".pth") 32 | if os.path.isfile(self.model_path): 33 | state_dict = torch.load(self.model_path, 34 | map_location='cpu') # Always load tensors onto CPU by default, will shift to GPU if necessary 35 | if 'conv1.weight' in state_dict.keys(): 36 | for old_key, new_key in (('conv1.weight', 'convs.0.weight'), ('conv1.bias', 'convs.0.bias'), 37 | ('conv2.weight', 'convs.2.weight'), ('conv2.bias', 'convs.2.bias'), 38 | ('conv3.weight', 'convs.4.weight'), ('conv3.bias', 'convs.4.bias')): 39 | state_dict[new_key] = state_dict[old_key] # Re-map state dict for old pretrained models 40 | del state_dict[old_key] # Delete old keys for strict load_state_dict 41 | self.online_net.load_state_dict(state_dict) 42 | print("Loading pretrained model: " + self.model_path) 43 | else: # Raise error if incorrect model path provided 44 | raise FileNotFoundError(self.model_path) 45 | 46 | self.online_net.train() 47 | 48 | self.target_net = DQN(args, self.action_space).to(device=args.device) 49 | 50 | self.online_dict = self.online_net.state_dict() 51 | self.target_dict = self.target_net.state_dict() 52 | 53 | self.update_target_net() 54 | self.target_net.train() 55 | for param in self.target_net.parameters(): 56 | param.requires_grad = False 57 | 58 | self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.learning_rate, eps=args.adam_eps) 59 | 60 | def reload_step_state_dict(self, better=True): 61 | if not self.active: 62 | return 63 | if better: 64 | self.online_dict = self.online_net.state_dict() 65 | self.target_dict = self.target_net.state_dict() 66 | else: 67 | self.online_net.load_state_dict(self.online_dict) 68 | self.target_net.load_state_dict(self.target_dict) 69 | 70 | def get_state_dict(self): 71 | if not self.active: 72 | return 73 | return self.online_net.state_dict() 74 | 75 | def set_state_dict(self, new_state_dict): 76 | if not self.active: 77 | return 78 | if 'conv1.weight' in new_state_dict.keys(): 79 | for old_key, new_key in (('conv1.weight', 'convs.0.weight'), ('conv1.bias', 'convs.0.bias'), 80 | ('conv2.weight', 'convs.2.weight'), ('conv2.bias', 'convs.2.bias'), 81 | ('conv3.weight', 'convs.4.weight'), ('conv3.bias', 'convs.4.bias')): 82 | new_state_dict[new_key] = new_state_dict[old_key] # Re-map state dict for old pretrained models 83 | del new_state_dict[old_key] # Delete old keys for strict load_state_dict 84 | self.online_net.load_state_dict(new_state_dict) 85 | return 86 | 87 | # Resets noisy weights in all linear layers (of online net only) 88 | def reset_noise(self): 89 | if not self.active: 90 | return 91 | self.online_net.reset_noise() 92 | 93 | # Acts based on single state (no batch) 94 | def act(self, state): 95 | if not self.active: 96 | return 97 | with torch.no_grad(): 98 | return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).argmax(1).item() 99 | 100 | # Acts with an ε-greedy policy (used for evaluation only) 101 | def act_e_greedy(self, state, epsilon=0.3, action_type='greedy'): # High ε can reduce evaluation scores drastically 102 | if not self.active: 103 | return 104 | if action_type == 'greedy': 105 | return np.random.randint(0, self.action_space) if np.random.random() < epsilon else self.act(state) 106 | elif action_type == 'boltzmann': 107 | return self.act_boltzmann(state) 108 | 109 | # Acts with an ε-greedy policy (used for evaluation only) 110 | def act_boltzmann(self, state): # High ε can reduce evaluation scores drastically 111 | if not self.active: 112 | return 113 | res_policy = (self.online_net(state.unsqueeze(0)) * self.support).sum(2).detach().numpy() 114 | return self.boltzmann(res_policy) 115 | 116 | def boltzmann(self, res_policy): 117 | sizeofres = res_policy.shape 118 | res = [] 119 | for i in range(sizeofres[0]): 120 | count = 0 121 | action_probs = [] 122 | for _ in range(self.action_space): 123 | try: 124 | val = np.exp((res_policy[i][_].item()) * 3) 125 | except OverflowError: 126 | res.append(_) 127 | break 128 | action_probs.append(val) 129 | count += val 130 | if len(action_probs) == self.action_space: 131 | action_probs = [x / count for x in action_probs] 132 | res.append(np.random.choice(self.action_space, p=action_probs)) 133 | if sizeofres[0] == 1: 134 | return int(res[0]) 135 | return np.array(res) 136 | 137 | def lookup_server(self, list_of_pipe): 138 | if not self.active: 139 | return 140 | num_pro = len(list_of_pipe) 141 | list_pro = np.ones(num_pro, dtype=bool) 142 | with torch.no_grad(): 143 | while list_pro.any(): 144 | for key, pipes in enumerate(list_of_pipe): 145 | if not pipes.closed and pipes.readable: 146 | obs = pipes.recv() 147 | if len(obs) == 1: 148 | if not obs[0]: 149 | pipes.close() 150 | list_pro[key] = False 151 | continue 152 | if not self.active: 153 | pipes.send(False) 154 | else: 155 | pipes.send(self.act(obs).numpy()) 156 | # convert back to numpy or cpu-tensor, or it will cause error since cuda try to run in 157 | # another thread. Keep the gpu resource inside main thread 158 | 159 | def lookup_server_loop(self, list_of_pipe): 160 | if not self.active: 161 | return False 162 | num_pro = len(list_of_pipe) 163 | list_pro = np.ones(num_pro, dtype=bool) 164 | for key, pipes in enumerate(list_of_pipe): 165 | if not pipes.closed and pipes.readable: 166 | if pipes.poll(): 167 | obs = pipes.recv() 168 | if type(obs) is np.ndarray: 169 | pipes.close() 170 | list_pro[key] = False 171 | continue 172 | if not self.active: 173 | pipes.send(False) 174 | else: 175 | pipes.send(self.act(obs)) 176 | else: 177 | list_pro[key] = False 178 | # convert back to numpy or cpu-tensor, or it will cause error since cuda try to run in 179 | # another thread. Keep the gpu resource inside main thread 180 | return list_pro.any() 181 | 182 | def learn(self, mem): 183 | if not self.active: 184 | return 185 | # Sample transitions 186 | idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size) 187 | 188 | # Calculate current state probabilities (online network noise already sampled) 189 | log_ps = self.online_net(states, log=True) # Log probabilities log p(s_t, ·; θonline) 190 | log_ps_a = log_ps[range(self.batch_size), actions] # log p(s_t, a_t; θonline) 191 | 192 | with torch.no_grad(): 193 | # Calculate nth next state probabilities 194 | pns = self.online_net(next_states) # Probabilities p(s_t+n, ·; θonline) 195 | dns = self.support.expand_as(pns) * pns # Distribution d_t+n = (z, p(s_t+n, ·; θonline)) 196 | if self.action_type == 'greedy': 197 | argmax_indices_ns = dns.sum(2).argmax(1) # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))] 198 | elif self.action_type == 'boltzmann': 199 | argmax_indices_ns = self.boltzmann(dns.sum(2)) 200 | self.target_net.reset_noise() # Sample new target net noise 201 | pns = self.target_net(next_states) # Probabilities p(s_t+n, ·; θtarget) 202 | pns_a = pns[range(self.batch_size), argmax_indices_ns] 203 | # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) 204 | 205 | # Compute Tz (Bellman operator T applied to z) 206 | Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze( 207 | 0) # Tz = R^n + (γ^n)z (accounting for terminal states) 208 | Tz = Tz.clamp(min=self.Vmin, max=self.Vmax) # Clamp between supported values 209 | # Compute L2 projection of Tz onto fixed support z 210 | b = (Tz - self.Vmin) / self.delta_z # b = (Tz - Vmin) / Δz 211 | l, u = b.floor().long(), b.ceil().long() 212 | # Fix disappearing probability mass when l = b = u (b is int) 213 | l[(u > 0) * (l == u)] -= 1 214 | u[(l < (self.atoms - 1)) * (l == u)] += 1 215 | 216 | # Distribute probability of Tz 217 | m = states.new_zeros(self.batch_size, self.atoms) 218 | offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand( 219 | self.batch_size, self.atoms).to(actions) 220 | m.view(-1).index_add_(0, (l + offset).view(-1), 221 | (pns_a * (u.float() - b)).view(-1)) # m_l = m_l + p(s_t+n, a*)(u - b) 222 | m.view(-1).index_add_(0, (u + offset).view(-1), 223 | (pns_a * (b - l.float())).view(-1)) # m_u = m_u + p(s_t+n, a*)(b - l) 224 | 225 | loss = -torch.sum(m * log_ps_a, 1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) 226 | self.online_net.zero_grad() 227 | (weights * loss).mean().backward() # Backpropagate importance-weighted minibatch loss 228 | # clip_grad_norm_(self.online_net.parameters(), 1.0, norm_type=1) 229 | self.optimiser.step() 230 | 231 | mem.update_priorities(idxs, loss.detach().cpu().numpy()) # Update priorities of sampled transitions 232 | 233 | def update_target_net(self): 234 | if not self.active: 235 | return 236 | self.target_net.load_state_dict(self.online_net.state_dict()) 237 | 238 | # Save model parameters on current device (don't move model between devices) 239 | def save(self, path, index=-1, name='model.pth'): 240 | if not self.active: 241 | return 242 | if index == -1: 243 | torch.save(self.online_net.state_dict(), os.path.join(path, name)) 244 | else: 245 | torch.save(self.online_net.state_dict(), os.path.join(path, name[0:-4] + str(index) + name[-4:])) 246 | 247 | # Evaluates Q-value based on single state (no batch) 248 | def evaluate_q(self, state): 249 | if not self.active: 250 | return 0 251 | with torch.no_grad(): 252 | return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).max(1)[0].item() 253 | 254 | def train(self): 255 | if not self.active: 256 | return 257 | self.online_net.train() 258 | 259 | def eval(self): 260 | if not self.active: 261 | return 262 | self.online_net.eval() 263 | -------------------------------------------------------------------------------- /rainbow_hac/basic_block.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import math 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | import global_parameters as gp 8 | 9 | 10 | # Factorised NoisyLinear layer with bias 11 | class NoisyLinear(nn.Module): 12 | def __init__(self, in_features, out_features, std_init=0.5): 13 | super(NoisyLinear, self).__init__() 14 | self.in_features = in_features 15 | self.out_features = out_features 16 | self.std_init = std_init 17 | self.weight_mu = nn.Parameter(torch.empty(out_features, in_features)) 18 | self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features)) 19 | self.register_buffer('weight_epsilon', torch.empty(out_features, in_features)) 20 | self.bias_mu = nn.Parameter(torch.empty(out_features)) 21 | self.bias_sigma = nn.Parameter(torch.empty(out_features)) 22 | self.register_buffer('bias_epsilon', torch.empty(out_features)) 23 | self.reset_parameters() 24 | self.reset_noise() 25 | 26 | def reset_parameters(self): 27 | mu_range = 1 / math.sqrt(self.in_features) 28 | self.weight_mu.data.uniform_(-mu_range, mu_range) 29 | self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.in_features)) 30 | self.bias_mu.data.uniform_(-mu_range, mu_range) 31 | self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.out_features)) 32 | 33 | def _scale_noise(self, size): 34 | x = torch.randn(size) 35 | return x.sign().mul_(x.abs().sqrt_()) 36 | 37 | def reset_noise(self): 38 | epsilon_in = self._scale_noise(self.in_features) 39 | epsilon_out = self._scale_noise(self.out_features) 40 | self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in)) 41 | self.bias_epsilon.copy_(epsilon_out) 42 | 43 | def forward(self, inputs): 44 | if self.training: 45 | return F.linear(inputs, self.weight_mu + self.weight_sigma * self.weight_epsilon, 46 | self.bias_mu + self.bias_sigma * self.bias_epsilon) 47 | else: 48 | return F.linear(inputs, self.weight_mu, self.bias_mu) 49 | 50 | 51 | class DQN(nn.Module): 52 | def __init__(self, args, action_space): 53 | super(DQN, self).__init__() 54 | self.atoms = args.atoms 55 | self.action_space = action_space 56 | self.archit = args.architecture 57 | 58 | if 'canonical' in args.architecture and '61obv' in args.architecture and '2uav' in args.architecture: 59 | self.convs = nn.Sequential(nn.Conv2d(args.history_length_accesspoint, 16, 8, stride=3, padding=2), nn.LeakyReLU(), 60 | nn.Conv2d(16, 32, 4, stride=2, padding=1), nn.BatchNorm2d(32), nn.LeakyReLU(), 61 | nn.Conv2d(32, 64, 3, stride=1, padding=0), nn.BatchNorm2d(64), nn.LeakyReLU(), 62 | nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.BatchNorm2d(64), nn.LeakyReLU(), 63 | nn.Dropout2d(0.2)) 64 | self.conv_output_size = 2368 # 41: 2: 1600 # 61: 2: 2368 3: 3200 4: 4288 # 4 uav: 4992 65 | elif 'canonical' in args.architecture and '41obv' in args.architecture and '2uav' in args.architecture: 66 | self.convs = nn.Sequential(nn.Conv2d(args.history_length_accesspoint, 16, 8, stride=3, padding=2), nn.LeakyReLU(), 67 | nn.Conv2d(16, 32, 4, stride=2, padding=1), nn.BatchNorm2d(32), nn.LeakyReLU(), 68 | nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(), 69 | nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.BatchNorm2d(64), nn.LeakyReLU(), 70 | nn.Dropout2d(0.2)) 71 | self.conv_output_size = 1600 # 41: 2: 1600 72 | elif 'canonical' in args.architecture and '61obv' in args.architecture and '4uav' in args.architecture: 73 | self.convs = nn.Sequential(nn.Conv2d(args.history_length_accesspoint, 16, 8, stride=3, padding=2), nn.LeakyReLU(), 74 | nn.Conv2d(16, 32, 4, stride=2, padding=1), nn.BatchNorm2d(32), nn.LeakyReLU(), 75 | nn.Conv2d(32, 64, 3, stride=1, padding=0), nn.BatchNorm2d(64), nn.LeakyReLU(), 76 | nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.BatchNorm2d(64), nn.LeakyReLU(), 77 | # nn.MaxPool2d((args.dense_of_uav, 1)), 78 | nn.Dropout2d(0.2)) 79 | self.conv_output_size = 3648 80 | # 41: 2: 1600 # 61: 2: 2368 3: 3200 4: 4288 # 4 uav: 4992 /pooling 1216/ 3dim obs 3648 81 | elif args.architecture == 'canonical_3d': 82 | self.convs = nn.Sequential(nn.Conv3d(1, 32, (gp.OBSERVATION_DIMS, 8, 8), stride=(gp.OBSERVATION_DIMS, 3, 3), 83 | padding=(0, 2, 2)), nn.LeakyReLU(), 84 | nn.Conv3d(32, 64, (gp.NUM_OF_UAV, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)), nn.LeakyReLU(), 85 | nn.Conv3d(64, 64, (1, 3, 3), stride=1, padding=0), nn.LeakyReLU()) 86 | self.conv_output_size = 12160 # 2: 12160 3: 3200 4: 4288 87 | elif args.architecture == 'resnet8': 88 | net_args = { 89 | "indim": gp.OBSERVATION_DIMS * gp.NUM_OF_UAV, 90 | "block": ResidualBlock, 91 | "layers": [2, 2, 2, 2] 92 | } 93 | self.convs = ResNet(**net_args) 94 | self.conv_output_size = 64 * 4 * 4 95 | elif 'data-efficient' in args.architecture and '61obv' in args.architecture and '4uav' in args.architecture: 96 | self.convs = nn.Sequential(nn.Conv2d(args.history_length_accesspoint, 16, 5, stride=3, padding=2, dilation= 2), nn.LeakyReLU(), 97 | nn.Conv2d(16, 32, 3, stride=1, padding=0), nn.BatchNorm2d(32), nn.LeakyReLU(), nn.MaxPool2d(2), 98 | nn.Conv2d(32, 32, 3, stride=1, padding=0), nn.BatchNorm2d(32), nn.LeakyReLU(), 99 | nn.MaxPool2d((2, 1)), 100 | nn.Dropout2d(0.2)) 101 | self.conv_output_size = 1248 # 2: 12160 3: 3200 4: 4288 102 | else: 103 | raise TypeError('No such strucure') 104 | # TODO: Calculate the output_size carefully!!! 105 | # if args.architecture == 'canonical': 106 | # self.convs = nn.Sequential(nn.Conv2d(args.state_dims, 32, 3, stride=1, padding=1), nn.ReLU(), 107 | # nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.ReLU(), 108 | # nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU()) 109 | # self.conv_output_size = 576 110 | # elif args.architecture == 'data-efficient': 111 | # self.convs = nn.Sequential(nn.Conv2d(args.history_length, 32, 3, stride=1, padding=0), nn.ReLU(), 112 | # nn.Conv2d(32, 64, 3, stride=1, padding=0), nn.ReLU()) 113 | # self.conv_output_size = 576 114 | self.fc = nn.Sequential(nn.Linear(self.conv_output_size, args.hidden_size), nn.Dropout(0.2), nn.LeakyReLU()) 115 | self.fc_h_v = NoisyLinear(args.hidden_size, args.hidden_size, std_init=args.noisy_std) 116 | self.fc_h_a = NoisyLinear(args.hidden_size, args.hidden_size, std_init=args.noisy_std) 117 | self.fc_z_v = NoisyLinear(args.hidden_size, self.atoms, std_init=args.noisy_std) 118 | self.fc_z_a = NoisyLinear(args.hidden_size, action_space * self.atoms, std_init=args.noisy_std) 119 | 120 | def forward(self, x, log=False): 121 | # if x.shape[1] != 1: 122 | # list_x = torch.split(x, 1, dim=1) 123 | # x = torch.cat(list_x, dim=3) 124 | if '3d' in self.archit: 125 | x = x.unsqueeze(1) 126 | if 'resnet' in self.archit: 127 | x = x.squeeze(1) 128 | x = self.convs(x.float()) 129 | x = self.fc(x.view(x.size(0), -1)) 130 | v = self.fc_z_v(F.relu(F.dropout(self.fc_h_v(x), p=0.2))) # Value stream 131 | a = self.fc_z_a(F.relu(F.dropout(self.fc_h_a(x), p=0.2))) # Advantage stream 132 | # v = self.fc_z_v(F.relu(self.fc_h_v(x))) # Value stream 133 | # a = self.fc_z_a(F.relu(self.fc_h_a(x))) # Advantage stream 134 | v, a = v.view(-1, 1, self.atoms), a.view(-1, self.action_space, self.atoms) 135 | q = v + a - a.mean(1, keepdim=True) # Combine streams 136 | if log: # Use log softmax for numerical stability 137 | q = F.log_softmax(q, dim=-1) # Log probabilities with action over second dimension 138 | else: 139 | q = F.softmax(q, dim=-1) # Probabilities with action over second dimension 140 | return q 141 | 142 | def reset_noise(self): 143 | for name, module in self.named_children(): 144 | if 'fc' in name and name != 'fc': 145 | module.reset_noise() 146 | -------------------------------------------------------------------------------- /rainbow_hac/basic_block_center.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import math 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.autograd import Variable 8 | import math 9 | import global_parameters as gp 10 | 11 | 12 | # Factorised NoisyLinear layer with bias 13 | class NoisyLinear(nn.Module): 14 | def __init__(self, in_features, out_features, std_init=0.5): 15 | super(NoisyLinear, self).__init__() 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | self.std_init = std_init 19 | self.weight_mu = nn.Parameter(torch.empty(out_features, in_features)) 20 | self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features)) 21 | self.register_buffer('weight_epsilon', torch.empty(out_features, in_features)) 22 | self.bias_mu = nn.Parameter(torch.empty(out_features)) 23 | self.bias_sigma = nn.Parameter(torch.empty(out_features)) 24 | self.register_buffer('bias_epsilon', torch.empty(out_features)) 25 | self.reset_parameters() 26 | self.reset_noise() 27 | 28 | def reset_parameters(self): 29 | mu_range = 1 / math.sqrt(self.in_features) 30 | self.weight_mu.data.uniform_(-mu_range, mu_range) 31 | self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.in_features)) 32 | self.bias_mu.data.uniform_(-mu_range, mu_range) 33 | self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.out_features)) 34 | 35 | def _scale_noise(self, size): 36 | x = torch.randn(size) 37 | return x.sign().mul_(x.abs().sqrt_()) 38 | 39 | def reset_noise(self): 40 | epsilon_in = self._scale_noise(self.in_features) 41 | epsilon_out = self._scale_noise(self.out_features) 42 | self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in)) 43 | self.bias_epsilon.copy_(epsilon_out) 44 | 45 | def forward(self, inputs): 46 | if self.training: 47 | return F.linear(inputs, self.weight_mu + self.weight_sigma * self.weight_epsilon, 48 | self.bias_mu + self.bias_sigma * self.bias_epsilon) 49 | else: 50 | return F.linear(inputs, self.weight_mu, self.bias_mu) 51 | 52 | 53 | class DQN(nn.Module): 54 | def __init__(self, args, action_space): 55 | super(DQN, self).__init__() 56 | self.atoms = args.atoms_sche 57 | self.action_space = action_space 58 | self.archit = args.architecture 59 | 60 | if gp.GOP >= 2: 61 | self.split = gp.UAV_FIELD_OF_VIEW[0] * 2 * math.ceil(gp.LENGTH_OF_FIELD / gp.DENSE_OF_ACCESSPOINT) 62 | else: 63 | self.split = gp.UAV_FIELD_OF_VIEW[0] * math.ceil(gp.LENGTH_OF_FIELD / gp.DENSE_OF_ACCESSPOINT) 64 | 65 | if 'canonical' in args.architecture and '2x2' in args.architecture: 66 | # self.convs = nn.Sequential(nn.Conv2d(args.history_length_scheduler, 32, 5, stride=2, padding=2, dilation=2), nn.LeakyReLU(), 67 | # nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.LeakyReLU(), 68 | # nn.Conv2d(64, 128, 3, stride=1, padding=0), nn.LeakyReLU()) 69 | # # nn.Conv2d(128, 256, 3, stride=1, padding=0), nn.LeakyReLU()) 70 | # self.conv_output_size = 5120 71 | self.convs = nn.Sequential(nn.Conv2d(args.history_length_scheduler, 16, 5, stride=2, padding=4, dilation=2), 72 | nn.LeakyReLU(), 73 | nn.Conv2d(16, 32, 3, stride=1, padding=0), nn.BatchNorm2d(32), nn.LeakyReLU(), 74 | nn.Conv2d(32, 32, 3, stride=1, padding=0), nn.BatchNorm2d(32), nn.LeakyReLU(), 75 | nn.Conv2d(32, 64, 3, stride=1, padding=0), nn.BatchNorm2d(64), nn.LeakyReLU(), 76 | nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.BatchNorm2d(64), nn.LeakyReLU(), 77 | nn.Dropout(0.2)) 78 | self.conv_output_size = 8192 79 | elif 'canonical' in args.architecture and '3x3' in args.architecture: 80 | self.convs = nn.Sequential(nn.Conv2d(args.history_length_scheduler, 16, 5, stride=2, padding=2, dilation=2), 81 | nn.LeakyReLU(), 82 | nn.Conv2d(16, 32, 4, stride=1, padding=0), nn.BatchNorm2d(32), nn.LeakyReLU(), 83 | nn.Conv2d(32, 32, 3, stride=1, padding=0), nn.BatchNorm2d(32), nn.LeakyReLU(), 84 | nn.Conv2d(32, 64, 3, stride=1, padding=0), nn.BatchNorm2d(64), nn.LeakyReLU(), 85 | nn.Conv2d(64, 64, 3, stride=2, padding=0), nn.BatchNorm2d(64), nn.LeakyReLU(), 86 | nn.Dropout(0.2)) 87 | self.conv_output_size = 1024 # 3x3 2: 3328 4: 7936 88 | # TODO: adding UAV requires pooling to reduce the number of parameters 89 | elif args.architecture == 'canonical_3d': 90 | self.convs = nn.Sequential(nn.Conv3d(1, 32, (2, 5, 5), stride=1, padding=4, dilation=2), nn.LeakyReLU(), 91 | nn.Conv3d(32, 64, (2, 3, 3), stride=1, padding=0), nn.LeakyReLU(), 92 | nn.Conv3d(64, 64, (1, 3, 3), stride=1, padding=0), nn.LeakyReLU(), 93 | nn.Conv3d(64, 128, (1, 3, 3), stride=1, padding=0), nn.LeakyReLU(), 94 | nn.Conv3d(128, 128, (1, 3, 3), stride=1, padding=0), nn.LeakyReLU(), 95 | nn.Dropout(0.2)) 96 | self.conv_output_size = 8192 97 | elif args.architecture == 'data-efficient': 98 | self.convs = nn.Sequential(nn.Conv2d(args.history_length, 32, 5, stride=5, padding=0), nn.ReLU(), 99 | nn.Conv2d(32, 64, 5, stride=5, padding=0), nn.ReLU()) 100 | self.conv_output_size = 576 101 | else: 102 | raise TypeError('No such strucure') 103 | # TODO: Calculate the output_size carefully!!! 104 | # if args.architecture == 'canonical': 105 | # self.convs = nn.Sequential(nn.Conv2d(args.state_dims, 32, 3, stride=1, padding=1), nn.ReLU(), 106 | # nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.ReLU(), 107 | # nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU()) 108 | # self.conv_output_size = 576 109 | # elif args.architecture == 'data-efficient': 110 | # self.convs = nn.Sequential(nn.Conv2d(args.history_length, 32, 3, stride=1, padding=0), nn.ReLU(), 111 | # nn.Conv2d(32, 64, 3, stride=1, padding=0), nn.ReLU()) 112 | # self.conv_output_size = 576 113 | 114 | self.fc = nn.Sequential(nn.Linear(self.conv_output_size * gp.NUM_OF_UAV, args.dense_of_uav * args.hidden_size), 115 | nn.Dropout(0.2), nn.LeakyReLU(), 116 | nn.Linear(self.conv_output_size * gp.NUM_OF_UAV, args.dense_of_uav * args.hidden_size), 117 | nn.Dropout(0.2), nn.LeakyReLU()) 118 | 119 | self.fc_h_v = NoisyLinear(args.dense_of_uav * args.hidden_size, args.dense_of_uav * args.hidden_size, std_init=args.noisy_std) 120 | self.fc_h_a = NoisyLinear(args.dense_of_uav * args.hidden_size, args.dense_of_uav * args.hidden_size, std_init=args.noisy_std) 121 | self.fc_z_v = NoisyLinear(args.dense_of_uav * args.hidden_size, self.atoms, std_init=args.noisy_std) 122 | self.fc_z_a = NoisyLinear(args.dense_of_uav * args.hidden_size, action_space * self.atoms, std_init=args.noisy_std) 123 | 124 | def forward(self, x, log=False): 125 | # if x.shape[1] != 1: 126 | # list_x = torch.split(x, 1, dim=1) 127 | # x = torch.cat(list_x, dim=2) 128 | x = torch.split(x, self.split, dim=3) 129 | res = [] 130 | for index, each in enumerate(x): 131 | each = self.convs(each.float()) 132 | res.append(each.view(each.size(0), -1)) 133 | x = self.fc(torch.cat(res, dim=1)) 134 | 135 | v = self.fc_z_v(F.relu(F.dropout(self.fc_h_v(x), p=0.2))) # Value stream 136 | a = self.fc_z_a(F.relu(F.dropout(self.fc_h_a(x), p=0.2))) # Advantage stream 137 | # v = self.fc_z_v(F.relu(self.fc_h_v(x))) # Value stream 138 | # a = self.fc_z_a(F.relu(self.fc_h_a(x))) # Advantage stream 139 | v, a = v.view(-1, 1, self.atoms), a.view(-1, self.action_space, self.atoms) 140 | q = v + a - a.mean(1, keepdim=True) # Combine streams 141 | if log: # Use log softmax for numerical stability 142 | q = F.log_softmax(q, dim=-1) # Log probabilities with action over second dimension 143 | else: 144 | q = F.softmax(q, dim=-1) # Probabilities with action over second dimension 145 | return q 146 | 147 | def reset_noise(self): 148 | for name, module in self.named_children(): 149 | if 'fc' in name and name != 'fc': 150 | module.reset_noise() 151 | -------------------------------------------------------------------------------- /rainbow_hac/basic_block_center_mix.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import math 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.autograd import Variable 8 | 9 | 10 | # Factorised NoisyLinear layer with bias 11 | class NoisyLinear(nn.Module): 12 | def __init__(self, in_features, out_features, std_init=0.5): 13 | super(NoisyLinear, self).__init__() 14 | self.in_features = in_features 15 | self.out_features = out_features 16 | self.std_init = std_init 17 | self.weight_mu = nn.Parameter(torch.empty(out_features, in_features)) 18 | self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features)) 19 | self.register_buffer('weight_epsilon', torch.empty(out_features, in_features)) 20 | self.bias_mu = nn.Parameter(torch.empty(out_features)) 21 | self.bias_sigma = nn.Parameter(torch.empty(out_features)) 22 | self.register_buffer('bias_epsilon', torch.empty(out_features)) 23 | self.reset_parameters() 24 | self.reset_noise() 25 | 26 | def reset_parameters(self): 27 | mu_range = 1 / math.sqrt(self.in_features) 28 | self.weight_mu.data.uniform_(-mu_range, mu_range) 29 | self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.in_features)) 30 | self.bias_mu.data.uniform_(-mu_range, mu_range) 31 | self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.out_features)) 32 | 33 | def _scale_noise(self, size): 34 | x = torch.randn(size) 35 | return x.sign().mul_(x.abs().sqrt_()) 36 | 37 | def reset_noise(self): 38 | epsilon_in = self._scale_noise(self.in_features) 39 | epsilon_out = self._scale_noise(self.out_features) 40 | self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in)) 41 | self.bias_epsilon.copy_(epsilon_out) 42 | 43 | def forward(self, inputs): 44 | if self.training: 45 | return F.linear(inputs, self.weight_mu + self.weight_sigma * self.weight_epsilon, 46 | self.bias_mu + self.bias_sigma * self.bias_epsilon) 47 | else: 48 | return F.linear(inputs, self.weight_mu, self.bias_mu) 49 | 50 | 51 | class DQN(nn.Module): 52 | def __init__(self, args, action_space): 53 | super(DQN, self).__init__() 54 | self.atoms = args.atoms_sche 55 | self.action_space = action_space 56 | self.archit = args.architecture 57 | 58 | if 'canonical' in args.architecture and '2x2' in args.architecture: 59 | # self.convs = nn.Sequential(nn.Conv2d(args.history_length_scheduler, 32, 5, stride=2, padding=2, dilation=2), nn.LeakyReLU(), 60 | # nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.LeakyReLU(), 61 | # nn.Conv2d(64, 128, 3, stride=1, padding=0), nn.LeakyReLU()) 62 | # # nn.Conv2d(128, 256, 3, stride=1, padding=0), nn.LeakyReLU()) 63 | # self.conv_output_size = 5120 64 | self.convs = nn.Sequential(nn.Conv2d(args.history_length_scheduler, 16, 5, stride=2, padding=4, dilation=2), 65 | nn.LeakyReLU(), 66 | nn.Conv2d(16, 32, 3, stride=1, padding=0), nn.BatchNorm2d(32), nn.LeakyReLU(), 67 | nn.Conv2d(32, 32, 3, stride=1, padding=0), nn.BatchNorm2d(32), nn.LeakyReLU(), 68 | nn.Conv2d(32, 64, 3, stride=1, padding=0), nn.BatchNorm2d(64), nn.LeakyReLU(), 69 | nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.BatchNorm2d(64), nn.LeakyReLU(), 70 | nn.Dropout2d(0.2)) 71 | if '2uav' in args.architecture: 72 | self.conv_output_size = 4096 # 3x3 2: 3328 4: 7936 73 | elif '4uav' in args.architecture: 74 | self.conv_output_size = 8192 75 | elif 'canonical' in args.architecture and '3x3' in args.architecture: 76 | self.convs = nn.Sequential(nn.Conv2d(args.history_length_scheduler, 16, 5, stride=2, padding=2, dilation=2), 77 | nn.LeakyReLU(), 78 | nn.Conv2d(16, 32, 4, stride=1, padding=0), nn.BatchNorm2d(32), nn.LeakyReLU(), 79 | nn.Conv2d(32, 32, 3, stride=1, padding=0), nn.BatchNorm2d(32), nn.LeakyReLU(), 80 | nn.Conv2d(32, 64, 3, stride=1, padding=0), nn.BatchNorm2d(64), nn.LeakyReLU(), 81 | nn.Conv2d(64, 64, 3, stride=2, padding=0), nn.BatchNorm2d(64), nn.LeakyReLU(), 82 | nn.Dropout2d(0.2)) 83 | if '2uav' in args.architecture: 84 | self.conv_output_size = 3328 # 3x3 2: 3328 4: 7936 85 | elif '4uav' in args.architecture: 86 | self.conv_output_size = 7936 87 | # TODO: adding UAV requires pooling to reduce the number of parameters 88 | elif args.architecture == 'canonical_3d': 89 | self.convs = nn.Sequential(nn.Conv3d(1, 32, (2, 5, 5), stride=1, padding=4, dilation=2), nn.LeakyReLU(), 90 | nn.Conv3d(32, 64, (2, 3, 3), stride=1, padding=0), nn.LeakyReLU(), 91 | nn.Conv3d(64, 64, (1, 3, 3), stride=1, padding=0), nn.LeakyReLU(), 92 | nn.Conv3d(64, 128, (1, 3, 3), stride=1, padding=0), nn.LeakyReLU(), 93 | nn.Conv3d(128, 128, (1, 3, 3), stride=1, padding=0), nn.LeakyReLU()) 94 | self.conv_output_size = 8192 95 | elif args.architecture == 'data-efficient': 96 | self.convs = nn.Sequential(nn.Conv2d(args.history_length, 32, 5, stride=5, padding=0), nn.ReLU(), 97 | nn.Conv2d(32, 64, 5, stride=5, padding=0), nn.ReLU()) 98 | self.conv_output_size = 576 99 | else: 100 | raise TypeError('No such strucure') 101 | # TODO: Calculate the output_size carefully!!! 102 | # if args.architecture == 'canonical': 103 | # self.convs = nn.Sequential(nn.Conv2d(args.state_dims, 32, 3, stride=1, padding=1), nn.ReLU(), 104 | # nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.ReLU(), 105 | # nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU()) 106 | # self.conv_output_size = 576 107 | # elif args.architecture == 'data-efficient': 108 | # self.convs = nn.Sequential(nn.Conv2d(args.history_length, 32, 3, stride=1, padding=0), nn.ReLU(), 109 | # nn.Conv2d(32, 64, 3, stride=1, padding=0), nn.ReLU()) 110 | # self.conv_output_size = 576 111 | 112 | # self.fc_h_v = nn.Sequential(nn.Linear(self.conv_output_size, 2 * args.hidden_size), 113 | # NoisyLinear(2 * args.hidden_size, args.hidden_size, std_init=args.noisy_std)) 114 | # self.fc_h_a = nn.Sequential(nn.Linear(self.conv_output_size, 2 * args.hidden_size), 115 | # NoisyLinear(2 * args.hidden_size, args.hidden_size, std_init=args.noisy_std)) 116 | # self.fc_z_v = NoisyLinear(args.hidden_size, self.atoms, std_init=args.noisy_std) 117 | # self.fc_z_a = NoisyLinear(args.hidden_size, action_space * self.atoms, std_init=args.noisy_std) 118 | 119 | # self.pre_h_v = nn.Linear(self.conv_output_size, 2 * args.hidden_size) 120 | # self.pre_z_v = nn.Linear(self.conv_output_size, 2 * args.hidden_size) 121 | # 122 | # self.fc_h_v = NoisyLinear(2 * args.hidden_size, args.hidden_size, std_init=args.noisy_std) 123 | # self.fc_h_a = NoisyLinear(2 * args.hidden_size, args.hidden_size, std_init=args.noisy_std) 124 | # self.fc_z_v = NoisyLinear(args.hidden_size, self.atoms, std_init=args.noisy_std) 125 | # self.fc_z_a = NoisyLinear(args.hidden_size, action_space * self.atoms, std_init=args.noisy_std) 126 | self.fc = nn.Sequential(nn.Linear(self.conv_output_size, args.dense_of_uav * args.hidden_size), nn.Dropout(0.2), nn.LeakyReLU()) 127 | 128 | self.fc_h_v = NoisyLinear(args.dense_of_uav * args.hidden_size, args.dense_of_uav * args.hidden_size, std_init=args.noisy_std) 129 | self.fc_h_a = NoisyLinear(args.dense_of_uav * args.hidden_size, args.dense_of_uav * args.hidden_size, std_init=args.noisy_std) 130 | self.fc_z_v = NoisyLinear(args.dense_of_uav * args.hidden_size, self.atoms, std_init=args.noisy_std) 131 | self.fc_z_a = NoisyLinear(args.dense_of_uav * args.hidden_size, action_space * self.atoms, std_init=args.noisy_std) 132 | 133 | def forward(self, x, log=False): 134 | # if x.shape[1] != 1: 135 | # list_x = torch.split(x, 1, dim=1) 136 | # x = torch.cat(list_x, dim=2) 137 | if '3d' in self.archit: 138 | x = x.unsqueeze(1) 139 | x = self.convs(x.float()) 140 | x = self.fc(x.view(x.size(0), -1)) 141 | 142 | # v = self.fc_z_v(F.relu(self.fc_h_v(x))) # Value stream 143 | # a = self.fc_z_a(F.relu(self.fc_h_a(x))) # Advantage stream 144 | v = self.fc_z_v(F.relu(F.dropout(self.fc_h_v(x), p=0.2))) # Value stream 145 | a = self.fc_z_a(F.relu(F.dropout(self.fc_h_a(x), p=0.2))) # Advantage stream 146 | v, a = v.view(-1, 1, self.atoms), a.view(-1, self.action_space, self.atoms) 147 | q = v + a - a.mean(1, keepdim=True) # Combine streams 148 | if log: # Use log softmax for numerical stability 149 | q = F.log_softmax(q, dim=-1) # Log probabilities with action over second dimension 150 | else: 151 | q = F.softmax(q, dim=-1) # Probabilities with action over second dimension 152 | return q 153 | 154 | def reset_noise(self): 155 | for name, module in self.named_children(): 156 | if 'fc' in name and name != 'fc': 157 | module.reset_noise() 158 | -------------------------------------------------------------------------------- /rainbow_hac/center_agent.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import os 4 | import numpy as np 5 | import torch 6 | from torch import optim 7 | from torch.nn.utils import clip_grad_norm_ 8 | 9 | from rainbow_hac.basic_block_center import DQN 10 | from rainbow_hac.basic_block_center_mix import DQN as DQN_M 11 | import global_parameters as gp 12 | 13 | 14 | class CT_Agent: 15 | def __init__(self, args, env): 16 | self.active = args.active_scheduler 17 | if not self.active: 18 | return 19 | self.env = env 20 | self.action_space, self.square_resource = env.get_resource_action_space() 21 | self.uav_num = len(env.uav_list) 22 | self.atoms = args.atoms_sche 23 | self.Vmin = args.V_min 24 | self.Vmax = args.V_max 25 | self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(device=args.device) # Support (range) of z 26 | self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1) 27 | self.batch_size = args.batch_size 28 | self.n = args.multi_step_scheduler 29 | self.discount = args.discount 30 | self.device = args.device 31 | self.net_type = args.architecture 32 | 33 | if 'mix' in self.net_type: 34 | self.online_net = DQN_M(args, self.action_space).to(device=args.device) 35 | else: 36 | self.online_net = DQN(args, self.action_space).to(device=args.device) 37 | if args.model: # Load pretrained model if provided 38 | self.model_path = os.path.join(args.model, "scheduler_model.pth") 39 | if os.path.isfile(self.model_path): 40 | state_dict = torch.load(self.model_path, 41 | map_location='cpu') # Always load tensors onto CPU by default, will shift to GPU if necessary 42 | if 'conv1.weight' in state_dict.keys(): 43 | for old_key, new_key in (('conv1.weight', 'convs.0.weight'), ('conv1.bias', 'convs.0.bias'), 44 | ('conv2.weight', 'convs.2.weight'), ('conv2.bias', 'convs.2.bias'), 45 | ('conv3.weight', 'convs.4.weight'), ('conv3.bias', 'convs.4.bias')): 46 | state_dict[new_key] = state_dict[old_key] # Re-map state dict for old pretrained models 47 | del state_dict[old_key] # Delete old keys for strict load_state_dict 48 | self.online_net.load_state_dict(state_dict) 49 | print("Loading pretrained model: " + self.model_path) 50 | else: # Raise error if incorrect model path provided 51 | raise FileNotFoundError(self.model_path) 52 | 53 | self.online_net.train() 54 | 55 | if 'mix' in self.net_type: 56 | self.target_net = DQN_M(args, self.action_space).to(device=args.device) 57 | else: 58 | self.target_net = DQN(args, self.action_space).to(device=args.device) 59 | 60 | self.online_dict = self.online_net.state_dict() 61 | self.target_dict = self.target_net.state_dict() 62 | 63 | self.update_target_net() 64 | self.target_net.train() 65 | for param in self.target_net.parameters(): 66 | param.requires_grad = False 67 | 68 | self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.learning_rate, eps=args.adam_eps) 69 | 70 | def reload_step_state_dict(self, better=True): 71 | if not self.active: 72 | return 73 | 74 | if better: 75 | self.online_dict = self.online_net.state_dict() 76 | self.target_dict = self.target_net.state_dict() 77 | else: 78 | self.online_net.load_state_dict(self.online_dict) 79 | self.target_net.load_state_dict(self.target_dict) 80 | 81 | def get_state_dict(self): 82 | if not self.active: 83 | return 84 | 85 | return self.online_net.state_dict() 86 | 87 | def set_state_dict(self, new_state_dict): 88 | if not self.active: 89 | return 90 | 91 | self.online_net.load_state_dict(new_state_dict) 92 | return 93 | 94 | # Resets noisy weights in all linear layers (of online net only) 95 | def reset_noise(self): 96 | if not self.active: 97 | return 98 | 99 | self.online_net.reset_noise() 100 | 101 | # Acts based on single state (no batch) 102 | def act(self, state): 103 | if not self.active: 104 | return 105 | with torch.no_grad(): 106 | prob = (self.online_net(state.unsqueeze(0)) * self.support).sum(2) 107 | return prob.numpy()[0] 108 | 109 | def lookup_server(self, list_of_pipe): 110 | if not self.active: 111 | return 112 | num_pro = len(list_of_pipe) 113 | list_pro = np.ones(num_pro, dtype=bool) 114 | with torch.no_grad(): 115 | while list_pro.any(): 116 | for key, pipes in enumerate(list_of_pipe): 117 | if not pipes.closed and pipes.readable: 118 | obs = pipes.recv() 119 | if len(obs) == 1: 120 | if not obs[0]: 121 | pipes.close() 122 | list_pro[key] = False 123 | continue 124 | if not self.active: 125 | pipes.send(False) 126 | else: 127 | pipes.send(self.act(obs).numpy()) 128 | # convert back to numpy or cpu-tensor, or it will cause error since cuda try to run in 129 | # another thread. Keep the gpu resource inside main thread 130 | 131 | def lookup_server_loop(self, list_of_pipe): 132 | num_pro = len(list_of_pipe) 133 | list_pro = np.ones(num_pro, dtype=bool) 134 | for key, pipes in enumerate(list_of_pipe): 135 | if not pipes.closed and pipes.readable: 136 | if pipes.poll(): 137 | obs = pipes.recv() 138 | if type(obs) is np.ndarray: 139 | pipes.close() 140 | list_pro[key] = False 141 | continue 142 | if not self.active: 143 | pipes.send(False) 144 | else: 145 | pipes.send(self.act(obs)) 146 | else: 147 | list_pro[key] = False 148 | # convert back to numpy or cpu-tensor, or it will cause error since cuda try to run in 149 | # another thread. Keep the gpu resource inside main thread 150 | return list_pro.any() 151 | 152 | # Acts with an ε-greedy policy (used for evaluation only) 153 | def act_e_greedy(self, state, epsilon=0.3): # High ε can reduce evaluation scores drastically 154 | if not self.active: 155 | return 156 | return np.random.rand(self.action_space) if np.random.random() < epsilon else self.act(state) 157 | 158 | def convert_result_prob_to_popularity(self, result_prob, state): 159 | request_avaliable = state[-1].numpy() 160 | request_avaliable = np.split(request_avaliable, gp.NUM_OF_UAV, axis=1) 161 | result = [] 162 | for requests in request_avaliable: 163 | each_uav_request = np.zeros(int(self.action_space / self.uav_num)) 164 | temp = np.split(requests, (int(np.ceil(gp.LENGTH_OF_FIELD / gp.DENSE_OF_ACCESSPOINT))), axis=1) 165 | for each_column in temp: 166 | each_uav_request += \ 167 | np.sum(each_column.reshape((int(np.ceil(gp.LENGTH_OF_FIELD / gp.DENSE_OF_ACCESSPOINT))), -1), axis=0) 168 | result.append(each_uav_request) 169 | request_avaliable = np.array(result) 170 | request_avaliable[request_avaliable > 0] = 1 171 | each_prob = np.split(result_prob, self.uav_num) 172 | res_prob = np.multiply(request_avaliable, each_prob) 173 | result = np.argsort(res_prob, axis=1)[:, -gp.DEFAULT_NUM_OF_RB * gp.DEFAULT_NUM_OF_RB_PER_RES:][:, ::-1] 174 | if result.shape[0] != self.uav_num: 175 | raise ValueError("Dimension Not match") 176 | 177 | index_list = [] 178 | for index, array in enumerate(list(result)): 179 | array = array + index * int(self.action_space / self.uav_num) 180 | index_list.extend(list(array)) 181 | return index_list 182 | 183 | def learn(self, mem): 184 | if not self.active: 185 | return 186 | # Sample transitions 187 | idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size) 188 | # for index in range(returns.shape[1]): 189 | # index = np.random.randint(0, returns.shape[1]) 190 | # actions_temp = actions[:, index:returns.shape[1] * self.uav_num:returns.shape[1]] 191 | # Calculate current state probabilities (online network noise already sampled) 192 | log_ps = self.online_net(states, log=True) # Log probabilities log p(s_t, ·; θonline) 193 | log_ps_a = log_ps[[[x] for x in range(self.batch_size)], actions, :] # log p(s_t, a_t; θonline) 194 | log_ps_a = torch.reshape(log_ps_a, (self.batch_size, returns.shape[1], -1, self.atoms)).mean(dim=2) 195 | # log_ps_a = torch.mean(log_ps_a, 1) 196 | 197 | with torch.no_grad(): 198 | # Calculate nth next state probabilities 199 | pns = self.online_net(next_states) # Probabilities p(s_t+n, ·; θonline) 200 | dns = self.support.expand_as(pns) * pns # Distribution d_t+n = (z, p(s_t+n, ·; θonline)) 201 | dns = dns.sum(2) 202 | # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))] 203 | argmax_indices_ns = torch.tensor([self.convert_result_prob_to_popularity(temp.numpy(), state) 204 | for temp, state in zip(dns, next_states)], 205 | dtype=torch.int64, device=self.device) 206 | self.target_net.reset_noise() # Sample new target net noise 207 | # argmax_indices_ns = argmax_indices_ns[:, index:returns.shape[1] * self.uav_num:returns.shape[1]] 208 | pns = self.target_net(next_states) # Probabilities p(s_t+n, ·; θtarget) 209 | pns_a = pns[[[x] for x in range(self.batch_size)], actions, :] 210 | pns_a = torch.reshape(pns_a, (self.batch_size, returns.shape[1], -1, self.atoms)).mean(dim=2) 211 | # pns_a = torch.mean(pns_a, 1) 212 | # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) 213 | 214 | # Compute Tz (Bellman operator T applied to z) 215 | Tz = returns.unsqueeze(2) + nonterminals.unsqueeze(2) * (self.discount ** self.n) * (self.support.unsqueeze(0)).unsqueeze(0) 216 | # Tz = R^n + (γ^n)z (accounting for terminal states) 217 | Tz = Tz.clamp(min=self.Vmin, max=self.Vmax) # Clamp between supported values 218 | # Compute L2 projection of Tz onto fixed support z 219 | b = (Tz - self.Vmin) / self.delta_z # b = (Tz - Vmin) / Δz 220 | l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) 221 | # Fix disappearing probability mass when l = b = u (b is int) 222 | l[(u > 0) * (l == u)] -= 1 223 | u[(l < (self.atoms - 1)) * (l == u)] += 1 224 | 225 | # Distribute probability of Tz 226 | m = states.new_zeros(self.batch_size, 1, self.atoms, dtype=torch.float32) 227 | offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand( 228 | self.batch_size, self.atoms).to(actions).unsqueeze(1) 229 | m.view(-1).index_add_(0, (l + offset).view(-1), 230 | (pns_a * (u.float() - b)).view(-1)) # m_l = m_l + p(s_t+n, a*)(u - b) 231 | m.view(-1).index_add_(0, (u + offset).view(-1), 232 | (pns_a * (b - l.float())).view(-1)) # m_u = m_u + p(s_t+n, a*)(b - l) 233 | # m = m.unsqueeze_(1).expand(self.batch_size, select_length, self.atoms) 234 | 235 | loss = -torch.sum(m * log_ps_a, 2).mean(dim=1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) 236 | self.online_net.zero_grad() 237 | (weights * loss).mean().backward() # Backpropagate importance-weighted minibatch loss 238 | self.optimiser.step() 239 | 240 | mem.update_priorities(idxs, loss.detach().cpu().numpy()) 241 | # Update priorities of sampled transitions 242 | 243 | def learn_single(self, mem): 244 | if not self.active: 245 | return 246 | # Sample transitions 247 | idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size) 248 | total_loss = 0 249 | # for index in range(returns.shape[1]): 250 | index = np.random.randint(0, returns.shape[1]) 251 | actions_temp = actions[:, index:returns.shape[1] * self.uav_num:returns.shape[1]] 252 | # Calculate current state probabilities (online network noise already sampled) 253 | log_ps = self.online_net(states, log=True) # Log probabilities log p(s_t, ·; θonline) 254 | log_ps_a = log_ps[[[x] for x in range(self.batch_size)], actions_temp, :] # log p(s_t, a_t; θonline) 255 | log_ps_a = torch.mean(log_ps_a, 1) 256 | 257 | with torch.no_grad(): 258 | # Calculate nth next state probabilities 259 | pns = self.online_net(next_states) # Probabilities p(s_t+n, ·; θonline) 260 | dns = self.support.expand_as(pns) * pns # Distribution d_t+n = (z, p(s_t+n, ·; θonline)) 261 | dns = dns.sum(2) 262 | # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))] 263 | argmax_indices_ns = torch.tensor([self.convert_result_prob_to_popularity(temp.numpy(), state) 264 | for temp, state in zip(dns, next_states)], 265 | dtype=torch.int64, device=self.device) 266 | self.target_net.reset_noise() # Sample new target net noise 267 | argmax_indices_ns = argmax_indices_ns[:, index:returns.shape[1] * self.uav_num:returns.shape[1]] 268 | pns = self.target_net(next_states) # Probabilities p(s_t+n, ·; θtarget) 269 | pns_a = pns[[[x] for x in range(self.batch_size)], argmax_indices_ns, :] 270 | pns_a = torch.mean(pns_a, 1) 271 | # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) 272 | 273 | # Compute Tz (Bellman operator T applied to z) 274 | Tz = returns[:, index].unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze( 275 | 0) # Tz = R^n + (γ^n)z (accounting for terminal states) 276 | Tz = Tz.clamp(min=self.Vmin, max=self.Vmax) # Clamp between supported values 277 | # Compute L2 projection of Tz onto fixed support z 278 | b = (Tz - self.Vmin) / self.delta_z # b = (Tz - Vmin) / Δz 279 | l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) 280 | # Fix disappearing probability mass when l = b = u (b is int) 281 | l[(u > 0) * (l == u)] -= 1 282 | u[(l < (self.atoms - 1)) * (l == u)] += 1 283 | 284 | # Distribute probability of Tz 285 | m = states.new_zeros(self.batch_size, self.atoms, dtype=torch.float32) 286 | offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand( 287 | self.batch_size, self.atoms).to(actions) 288 | m.view(-1).index_add_(0, (l + offset).view(-1), 289 | (pns_a * (u.float() - b)).view(-1)) # m_l = m_l + p(s_t+n, a*)(u - b) 290 | m.view(-1).index_add_(0, (u + offset).view(-1), 291 | (pns_a * (b - l.float())).view(-1)) # m_u = m_u + p(s_t+n, a*)(b - l) 292 | # m = m.unsqueeze_(1).expand(self.batch_size, select_length, self.atoms) 293 | 294 | loss = -torch.sum(m * log_ps_a, 1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) 295 | self.online_net.zero_grad() 296 | clip_grad_norm_(self.online_net.parameters(), 1.0, norm_type=1) 297 | (weights * loss).mean().backward() # Backpropagate importance-weighted minibatch loss 298 | self.optimiser.step() 299 | 300 | mem.update_priorities(idxs, loss.detach().cpu().numpy()) 301 | # Update priorities of sampled transitions 302 | 303 | def update_target_net(self): 304 | if not self.active: 305 | return 306 | self.target_net.load_state_dict(self.online_net.state_dict()) 307 | 308 | # Save model parameters on current device (don't move model between devices) 309 | def save(self, path, name='scheduler_model.pth'): 310 | if not self.active: 311 | return 312 | torch.save(self.online_net.state_dict(), os.path.join(path, name)) 313 | 314 | # Evaluates Q-value based on single state (no batch) 315 | def evaluate_q(self, state): 316 | if not self.active: 317 | return 0 318 | with torch.no_grad(): 319 | dns = (self.online_net(state.unsqueeze(0)) * self.support).sum(2)[0] 320 | index = self.convert_result_prob_to_popularity(dns.numpy(), state) 321 | return torch.mean(dns[index]).item() 322 | 323 | def train(self): 324 | if not self.active: 325 | return 326 | self.online_net.train() 327 | 328 | def eval(self): 329 | if not self.active: 330 | return 331 | self.online_net.eval() 332 | -------------------------------------------------------------------------------- /rainbow_hac/memory.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | from collections import namedtuple 4 | import numpy as np 5 | import torch 6 | import global_parameters as gp 7 | 8 | Transition = namedtuple('Transition', ('timestep', 'state', 'action', 'reward', 'nonterminal')) 9 | if gp.OBSERVATION_VERSION <= 7: 10 | blank_trans_aps = Transition(0, torch.zeros([int(np.ceil(gp.ACCESS_POINTS_FIELD / gp.SQUARE_STEP) * gp.OBSERVATION_DIMS * gp.NUM_OF_UAV), 11 | int(np.ceil(gp.ACCESS_POINTS_FIELD / gp.SQUARE_STEP))], dtype=torch.float32), None, 0, False) 12 | if gp.OBSERVATION_VERSION == 8: 13 | blank_trans_aps = Transition(0, torch.zeros([gp.OBSERVATION_DIMS * gp.NUM_OF_UAV, int(np.ceil(gp.ACCESS_POINTS_FIELD / gp.SQUARE_STEP)), 14 | int(np.ceil(gp.ACCESS_POINTS_FIELD / gp.SQUARE_STEP))], dtype=torch.float32), None, 0, False) 15 | if gp.GOP >= 2: 16 | blank_trans_sche = Transition(0, torch.zeros([gp.UAV_FIELD_OF_VIEW[0] * 2 * gp.ACCESS_POINT_PER_EDGE, 17 | gp.UAV_FIELD_OF_VIEW[1] * gp.ACCESS_POINT_PER_EDGE * gp.NUM_OF_UAV], 18 | dtype=torch.float32), None, 0, False) 19 | else: 20 | blank_trans_sche = Transition(0, torch.zeros( 21 | [gp.UAV_FIELD_OF_VIEW[0] * gp.ACCESS_POINT_PER_EDGE, 22 | gp.UAV_FIELD_OF_VIEW[1] * gp.ACCESS_POINT_PER_EDGE * gp.NUM_OF_UAV], 23 | dtype=torch.float32), None, 0, False) 24 | # TODO: Change the size of black_trans which should match with observation 25 | 26 | 27 | # Segment tree data structure where parent node values are sum/max of children node values 28 | class SegmentTree: 29 | def __init__(self, size): 30 | self.index = 0 31 | self.size = size 32 | self.full = False # Used to track actual capacity 33 | self.sum_tree = np.zeros((2 * size - 1,), 34 | dtype=np.float32) # Initialise fixed size tree with all (priority) zeros 35 | self.data = np.array([None] * size) # Wrap-around cyclic buffer 36 | self.max = 1 # Initial max value to return (1 = 1^ω) 37 | 38 | # Propagates value up tree given a tree index 39 | def _propagate(self, index, value): 40 | parent = (index - 1) // 2 41 | left, right = 2 * parent + 1, 2 * parent + 2 42 | self.sum_tree[parent] = self.sum_tree[left] + self.sum_tree[right] 43 | if parent != 0: 44 | self._propagate(parent, value) 45 | 46 | # Updates value given a tree index 47 | def update(self, index, value): 48 | self.sum_tree[index] = value # Set new value 49 | self._propagate(index, value) # Propagate value 50 | self.max = max(value, self.max) 51 | 52 | def append(self, data, value): 53 | self.data[self.index] = data # Store data in underlying data structure 54 | self.update(self.index + self.size - 1, value) # Update tree 55 | self.index = (self.index + 1) % self.size # Update index 56 | self.full = self.full or self.index == 0 # Save when capacity reached 57 | self.max = max(value, self.max) 58 | 59 | # Searches for the location of a value in sum tree 60 | def _retrieve(self, index, value): 61 | left, right = 2 * index + 1, 2 * index + 2 62 | if left >= len(self.sum_tree): 63 | return index 64 | elif value <= self.sum_tree[left]: 65 | return self._retrieve(left, value) 66 | else: 67 | return self._retrieve(right, value - self.sum_tree[left]) 68 | 69 | # Searches for a value in sum tree and returns value, data index and tree index 70 | def find(self, value): 71 | index = self._retrieve(0, value) # Search for index of item from root 72 | data_index = index - self.size + 1 73 | return (self.sum_tree[index], data_index, index) # Return value, data index, tree index 74 | 75 | # Returns data given a data index 76 | def get(self, data_index): 77 | return self.data[data_index % self.size] 78 | 79 | def total(self): 80 | return self.sum_tree[0] 81 | 82 | 83 | class ReplayMemory: 84 | def __init__(self, args, capacity, typeof_black: bool, remove_function=None): 85 | """ 86 | :parameter typeof_black: True: use ap_black, False: use scheduler_black 87 | """ 88 | self.device = args.device 89 | self.typeof_black = typeof_black 90 | self.previous_action_obs_ap = args.previous_action_observable 91 | self.remove_function = remove_function 92 | self.capacity = capacity 93 | self.memory_capacity_accesspoint = args.memory_capacity_accesspoint 94 | if typeof_black: 95 | self.history = args.history_length_accesspoint 96 | self.n = args.multi_step_accesspoint 97 | else: 98 | self.history = args.history_length_scheduler 99 | self.n = args.multi_step_scheduler 100 | self.discount = args.discount 101 | self.priority_weight = args.priority_weight 102 | # Initial importance sampling weight β, annealed to 1 over course of training 103 | self.priority_exponent = args.priority_exponent 104 | self.t = 0 # Internal episode timestep counter 105 | self.transitions = SegmentTree(capacity) 106 | # Store transitions in a wrap-around cyclic buffer within a sum tree for querying priorities 107 | 108 | # Adds state and action at time t, reward and terminal at time t + 1 109 | def append(self, state, action, reward, terminal): 110 | state = state[-1].to(dtype=torch.float32, device=torch.device('cpu')) 111 | # Only store last frame and discretise to save memory 112 | self.transitions.append(Transition(self.t, state, action, reward, not terminal), 113 | self.transitions.max) # Store new transition with maximum priority 114 | self.t = 0 if terminal else self.t + 1 # Start new episodes with t = 0 115 | 116 | # Returns a transition with blank states where appropriate 117 | def _get_transition(self, idx): 118 | transition = np.array([None] * (self.history + self.n)) 119 | transition[self.history - 1] = self.transitions.get(idx) 120 | 121 | # --------------- make black_trans when the last frame is terminal frame --------- 122 | for t in range(self.history - 2, -1, -1): # e.g. 2 1 0 123 | if transition[t + 1].timestep == 0: 124 | transition[t] = blank_trans_aps if self.typeof_black else blank_trans_sche 125 | # If future frame has timestep 0 126 | else: 127 | transition[t] = self.transitions.get(idx - self.history + 1 + t) 128 | for t in range(self.history, self.history + self.n): # e.g. 4 5 6 129 | if transition[t - 1].nonterminal: 130 | transition[t] = self.transitions.get(idx - self.history + 1 + t) 131 | else: 132 | transition[t] = blank_trans_aps if self.typeof_black else blank_trans_sche 133 | # If prev (next) frame is terminal 134 | return transition 135 | 136 | # Returns a valid sample from a segment 137 | def _get_sample_from_segment(self, segment, i): 138 | prob, idx, tree_idx, valid = None, None, None, False 139 | while not valid: 140 | sample = np.random.uniform(i * segment, 141 | (i + 1) * segment) # Uniformly sample an element from within a segment 142 | prob, idx, tree_idx = self.transitions.find( 143 | sample) # Retrieve sample from tree with un-normalised probability 144 | # Resample if transition straddled current index or probablity 0 145 | if (self.transitions.index - idx) % self.capacity > self.n and ( 146 | idx - self.transitions.index) % self.capacity >= self.history and prob != 0: 147 | valid = True # Note that conditions are valid but extra conservative around buffer index 0 148 | 149 | # Retrieve all required transition data (from t - h to t + n) 150 | transition = self._get_transition(idx) 151 | # Create un-discretised state and nth next state, if number-step is 1, don't need to add another dims 152 | state = torch.stack([trans.state for trans in transition[:self.history]]).to(device=self.device).to( 153 | dtype=torch.float32) 154 | next_state = torch.stack([trans.state for trans in transition[self.n:self.n + self.history]]).to( 155 | device=self.device).to(dtype=torch.float32) 156 | if self.typeof_black and self.previous_action_obs_ap: 157 | state[-1] = self.remove_function(state[-1]) 158 | next_state[-1] = self.remove_function(next_state[-1]) 159 | # Discrete action to be used as index 160 | action = torch.tensor([transition[self.history - 1].action], dtype=torch.int64, device=self.device) 161 | # Calculate truncated n-step discounted return R^n = Σ_k=0->n-1 (γ^k)R_t+k+1 (note that invalid nth next states have reward 0) 162 | R = torch.stack([sum(self.discount ** n * transition[self.history + n - 1].reward 163 | for n in range(self.n))]).to(device=self.device).to(dtype=torch.float32) 164 | # Mask for non-terminal nth next states 165 | nonterminal = torch.tensor([transition[self.history + self.n - 1].nonterminal], dtype=torch.float32, 166 | device=self.device) 167 | 168 | return prob, idx, tree_idx, state, action, R, next_state, nonterminal 169 | 170 | def sample(self, batch_size): 171 | p_total = self.transitions.total() 172 | # Retrieve sum of all priorities (used to create a normalised probability distribution) 173 | segment = p_total / batch_size # Batch size number of segments, based on sum over all probabilities 174 | batch = [self._get_sample_from_segment(segment, i) for i in range(batch_size)] # Get batch of valid samples 175 | probs, idxs, tree_idxs, states, actions, returns, next_states, nonterminals = zip(*batch) 176 | states, next_states = torch.stack(states), torch.stack(next_states) 177 | actions, returns, nonterminals = torch.cat(actions), torch.cat(returns), torch.stack(nonterminals) 178 | probs = np.array(probs, dtype=np.float32) / p_total # Calculate normalised probabilities 179 | capacity = self.capacity if self.transitions.full else self.transitions.index 180 | weights = (capacity * probs) ** -self.priority_weight # Compute importance-sampling weights w 181 | weights = torch.tensor(weights / weights.max(), dtype=torch.float32, 182 | device=self.device) # Normalise by max importance-sampling weight from batch 183 | return tree_idxs, states, actions, returns, next_states, nonterminals, weights 184 | 185 | def update_priorities(self, idxs, priorities): 186 | priorities = np.power(priorities, self.priority_exponent) 187 | [self.transitions.update(idx, priority) for idx, priority in zip(idxs, priorities)] 188 | 189 | # Set up internal state for iterator 190 | def __iter__(self): 191 | self.current_idx = 0 192 | return self 193 | 194 | # Return valid states for validation 195 | def __next__(self): 196 | if self.current_idx == self.capacity: 197 | raise StopIteration 198 | if self.history <= 1: 199 | state = self.transitions.data[self.current_idx].state.unsqueeze(0) 200 | self.current_idx += 1 201 | return state 202 | # Create stack of states 203 | state_stack = [None] * self.history 204 | state_stack[-1] = self.transitions.data[self.current_idx].state 205 | if self.typeof_black and self.previous_action_obs_ap: 206 | state_stack[-1] = self.remove_function(state_stack[-1]) 207 | prev_timestep = self.transitions.data[self.current_idx].timestep 208 | for t in reversed(range(self.history - 1)): 209 | if prev_timestep == 0: 210 | state_stack[t] = blank_trans_aps.state if self.typeof_black else blank_trans_sche.state 211 | # If future frame has timestep 0 212 | else: 213 | state_stack[t] = self.transitions.data[self.current_idx + t - self.history + 1].state 214 | prev_timestep -= 1 215 | state = torch.stack(state_stack, 0).to(dtype=torch.float32, device=self.device) 216 | # Agent will turn into batch 217 | self.current_idx += 1 218 | return state 219 | 220 | next = __next__ # Alias __next__ for Python 2 compatibility 221 | -------------------------------------------------------------------------------- /rainbow_hac/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import os 4 | import plotly 5 | import multiprocessing 6 | import copy as cp 7 | import math 8 | import global_parameters as gp 9 | 10 | from plotly.graph_objs import Scatter 11 | from plotly.graph_objs.scatter import Line 12 | import torch 13 | import numpy as np 14 | 15 | from rainbow_hac.game import Decentralized_Game as Env 16 | 17 | 18 | def test_parallel(new_game, c_pipe, c_pipe_controller, train_history_sche, train_history_aps, eps): 19 | train_examples_sche = [] 20 | train_examples_aps = [] 21 | for index in range(len(new_game.accesspoint_list)): 22 | train_examples_aps.append([]) 23 | reward_sum_aps = [] 24 | 25 | sche_state, done = None, True 26 | for _ in range(eps): 27 | while True: 28 | if done: 29 | sche_state, reward_sum_aps, reward_sum_sche, done = new_game.reset(), [], [], False 30 | 31 | sche_pack, aps_pack, done = new_game.step_p(c_pipe_controller, c_pipe) # Step 32 | 33 | reward_sum_aps.append(aps_pack[2]) 34 | if done: 35 | train_examples_sche.append(new_game.get_finial_reward()) 36 | reward_sum_aps = np.mean(reward_sum_aps, axis=0) 37 | for index in range(len(new_game.accesspoint_list)): 38 | train_examples_aps[index].append(reward_sum_aps[index]) 39 | break 40 | train_history_sche.append(train_examples_sche) 41 | train_history_aps.append(train_examples_aps) 42 | 43 | for index in range(len(new_game.accesspoint_list)): 44 | c_pipe[index].send((np.array([False]), np.array([False]))) 45 | c_pipe[index].close() 46 | c_pipe_controller.send(np.array([False])) 47 | c_pipe_controller.close() 48 | 49 | del new_game 50 | 51 | 52 | # test whole system 53 | def test(args, T, controller, dqn, val_mem_aps, metrics_aps, val_mem_sche, metrics_sche, results_dir, evaluate=False): 54 | env = Env(args) 55 | env.reset() 56 | 57 | T_rewards_sche, T_Qs_sche = [], [] 58 | metrics_sche['steps'].append(T) 59 | T_rewards_aps, T_Qs_aps = [], [] 60 | for _ in range(len(env.accesspoint_list)): 61 | metrics_aps[_]['steps'].append(T) 62 | T_rewards_aps.append([]) 63 | T_Qs_aps.append([]) 64 | 65 | # Test performance over several episodes 66 | sche_state, reward_sum, done = None, [], True 67 | for _ in range(args.evaluation_episodes): 68 | while True: 69 | if done: 70 | sche_state, reward_sum, done = env.reset(), [], False 71 | 72 | sche_pack, aps_pack, done = env.step(controller, dqn) 73 | 74 | reward_sum.append(aps_pack[2]) 75 | if done: 76 | T_rewards_sche.append(env.get_finial_reward()) 77 | reward_sum = np.mean(reward_sum, axis=0) 78 | print(np.mean(T_rewards_sche)) 79 | for index in range(len(env.accesspoint_list)): 80 | T_rewards_aps[index].append(reward_sum[index]) 81 | break 82 | env.close() 83 | 84 | # Test Q-values over validation memory 85 | for state in val_mem_sche: # Iterate over valid states 86 | T_Qs_sche.append(controller.evaluate_q(state)) 87 | # Test Q-values over validation memory 88 | for index, val_mems in enumerate(val_mem_aps): 89 | for state in val_mems: # Iterate over valid states 90 | T_Qs_aps[index].append(dqn[index].evaluate_q(state)) 91 | 92 | avg_reward = sum(T_rewards_sche) / len(T_rewards_sche) 93 | avg_Q = sum(T_Qs_sche) / len(T_Qs_sche) 94 | 95 | avg_reward_aps, avg_Q_aps = [], [] 96 | for _ in range(len(env.accesspoint_list)): 97 | avg_reward_aps.append(sum(T_rewards_aps[_]) / len(T_rewards_aps[_])) 98 | avg_Q_aps.append(sum(T_Qs_aps[_]) / len(T_Qs_aps[_])) 99 | 100 | better = True 101 | if not evaluate: 102 | # Save model parameters if improved 103 | if avg_reward > metrics_sche['best_avg_reward'] * args.better_indicator: 104 | metrics_sche['best_avg_reward'] = avg_reward 105 | controller.save(results_dir) 106 | # reload the state dict if obtain a better model 107 | 108 | # Append to results and save metrics 109 | metrics_sche['rewards'].append(T_rewards_sche) 110 | metrics_sche['Qs'].append(T_Qs_sche) 111 | torch.save(metrics_sche, os.path.join(results_dir, 'eval_metrics.pth')) 112 | 113 | # Plot 114 | _plot_line(metrics_sche['steps'], metrics_sche['rewards'], 'eval_Reward', path=results_dir) 115 | _plot_line(metrics_sche['steps'], metrics_sche['Qs'], 'eval_Q', path=results_dir) 116 | 117 | better_aps = True 118 | if not evaluate: 119 | # Save model parameters if improved 120 | better_vote = np.array([False] * len(env.accesspoint_list), dtype=np.int32) 121 | worse_vote = np.array([False] * len(env.accesspoint_list), dtype=np.int32) 122 | for _ in range(len(env.accesspoint_list)): 123 | if avg_reward_aps[_] > metrics_aps[_]['best_avg_reward'] * args.better_indicator: 124 | metrics_aps[_]['best_avg_reward'] = avg_reward_aps[_] 125 | dqn[_].save(results_dir, _) 126 | better_vote[_] = True 127 | elif avg_reward_aps[_] * args.better_indicator > metrics_aps[_]['best_avg_reward']: 128 | worse_vote[_] = True 129 | if np.sum(better_vote) >= np.ceil(len(env.accesspoint_list) / 3 * 2): 130 | if not np.sum(worse_vote) >= np.ceil(len(env.accesspoint_list) / 3 * 2): 131 | for _ in range(len(env.accesspoint_list)): 132 | dqn[_].reload_step_state_dict() 133 | else: 134 | if gp.ENABLE_MODEL_RELOAD: 135 | for _ in range(len(env.accesspoint_list)): 136 | dqn[_].reload_step_state_dict(False) 137 | better_aps = False 138 | # reload the state dict if obtain a better model 139 | 140 | for _ in range(len(env.accesspoint_list)): 141 | # Append to results and save metrics 142 | metrics_aps[_]['rewards'].append(T_rewards_aps[_]) 143 | metrics_aps[_]['Qs'].append(T_Qs_aps[_]) 144 | torch.save(metrics_aps[_], os.path.join(results_dir, 'metrics' + str(_) + '.pth')) 145 | 146 | for _ in range(len(env.accesspoint_list)): 147 | # Plot 148 | _plot_line(metrics_aps[_]['steps'], metrics_aps[_]['rewards'], 'Reward' + str(_), path=results_dir) 149 | _plot_line(metrics_aps[_]['steps'], metrics_aps[_]['Qs'], 'Q' + str(_), path=results_dir) 150 | 151 | # Return average reward and Q-value 152 | return (avg_reward, avg_Q, better), (avg_reward_aps, avg_Q_aps, better_aps) 153 | 154 | 155 | # Test DQN 156 | def test_p(args, T, controller, dqn, val_mem_aps, metrics_aps, val_mem_sche, metrics_sche, results_dir, evaluate=False): 157 | env = Env(args) 158 | env.reset() 159 | T_rewards_sche, T_Qs_sche = [], [] 160 | metrics_sche['steps'].append(T) 161 | T_rewards_aps, T_Qs_aps = [], [] 162 | for _ in range(len(env.accesspoint_list)): 163 | metrics_aps[_]['steps'].append(T) 164 | T_rewards_aps.append([]) 165 | T_Qs_aps.append([]) 166 | 167 | num_cores = math.floor(min(multiprocessing.cpu_count(), gp.ALLOCATED_CORES) - 1) 168 | num_eps = math.ceil(math.ceil(args.evaluation_episodes / num_cores) / 169 | (gp.GOP * gp.DEFAULT_RESOURCE_BLOCKNUM)) 170 | # make sure each subprocess can finish all the game (end with done) 171 | with multiprocessing.Manager() as manager: 172 | train_history_sche = manager.list() 173 | train_history_aps = manager.list() 174 | 175 | p_pipe_list1 = [] 176 | c_pipe_list1 = [] 177 | for _ in range(num_cores): 178 | p_pipe, c_pipe = multiprocessing.Pipe() 179 | p_pipe_list1.append(p_pipe) 180 | c_pipe_list1.append(c_pipe) 181 | 182 | p_pipe_list2 = [] 183 | c_pipe_list2 = [] 184 | for _ in range(num_cores): 185 | temp1, temp2 = [], [] 186 | for temp in range(len(env.accesspoint_list)): 187 | p_pipe, c_pipe = multiprocessing.Pipe() 188 | temp1.append(p_pipe) 189 | temp2.append(c_pipe) 190 | p_pipe_list2.append(temp1) 191 | c_pipe_list2.append(temp2) 192 | p_pipe_list2 = np.array(p_pipe_list2) 193 | 194 | process_list = [] 195 | for _ in range(num_cores): 196 | process = multiprocessing.Process(target=test_parallel, 197 | args=(cp.deepcopy(env), c_pipe_list2[_], c_pipe_list1[_], 198 | train_history_sche, train_history_aps, num_eps)) 199 | process_list.append(process) 200 | 201 | for pro in process_list: 202 | pro.start() 203 | 204 | on_off1 = True 205 | on_off2 = True 206 | while on_off1 or on_off2: 207 | on_off1 = controller.lookup_server_loop(p_pipe_list1) 208 | temp = np.ones(len(env.accesspoint_list), dtype=bool) 209 | for index in range(len(env.accesspoint_list)): 210 | temp[index] = dqn[index].lookup_server_loop(p_pipe_list2[:, index]) 211 | on_off2 = temp.any() 212 | 213 | for pro in process_list: 214 | pro.join() 215 | pro.terminate() 216 | 217 | for res in train_history_sche: 218 | for reward in res: 219 | T_rewards_sche.append(reward) 220 | for res in train_history_aps: 221 | for index, memerys in enumerate(res): 222 | for reward in memerys: 223 | T_rewards_aps[index].append(reward) 224 | 225 | # Test Q-values over validation memory 226 | for state in val_mem_sche: # Iterate over valid states 227 | T_Qs_sche.append(controller.evaluate_q(state)) 228 | # Test Q-values over validation memory 229 | for index, val_mems in enumerate(val_mem_aps): 230 | for state in val_mems: # Iterate over valid states 231 | T_Qs_aps[index].append(dqn[index].evaluate_q(state)) 232 | 233 | avg_reward_sche = sum(T_rewards_sche) / len(T_rewards_sche) 234 | avg_Q_sche = sum(T_Qs_sche) / len(T_Qs_sche) 235 | 236 | avg_reward_aps, avg_Q_aps = [], [] 237 | for _ in range(len(env.accesspoint_list)): 238 | avg_reward_aps.append(sum(T_rewards_aps[_]) / len(T_rewards_aps[_])) 239 | avg_Q_aps.append(sum(T_Qs_aps[_]) / len(T_Qs_aps[_])) 240 | 241 | better_sche = True 242 | if not evaluate: 243 | # Save model parameters if improved 244 | if avg_reward_sche >= metrics_sche['best_avg_reward'] * args.better_indicator: 245 | if avg_reward_sche > metrics_sche['best_avg_reward']: 246 | metrics_sche['best_avg_reward'] = avg_reward_sche 247 | controller.save(results_dir) 248 | controller.reload_step_state_dict() 249 | else: 250 | if gp.ENABLE_MODEL_RELOAD: 251 | controller.reload_step_state_dict(False) 252 | better_sche = False 253 | # reload the state dict if obtain a better model 254 | 255 | # Append to results and save metrics 256 | metrics_sche['rewards'].append(T_rewards_sche) 257 | metrics_sche['Qs'].append(T_Qs_sche) 258 | torch.save(metrics_sche, os.path.join(results_dir, 'scheduler_metrics.pth')) 259 | 260 | # Plot 261 | _plot_line(metrics_sche['steps'], metrics_sche['rewards'], 'scheduler_Reward', path=results_dir) 262 | _plot_line(metrics_sche['steps'], metrics_sche['Qs'], 'scheduler_Q', path=results_dir) 263 | 264 | better_aps = True 265 | if not evaluate: 266 | # Save model parameters if improved 267 | better_vote = np.array([False] * len(env.accesspoint_list), dtype=np.int32) 268 | worse_vote = np.array([False] * len(env.accesspoint_list), dtype=np.int32) 269 | for _ in range(len(env.accesspoint_list)): 270 | if avg_reward_aps[_] >= metrics_aps[_]['best_avg_reward'] * args.better_indicator or better_sche: 271 | if avg_reward_aps[_] > metrics_aps[_]['best_avg_reward']: 272 | metrics_aps[_]['best_avg_reward'] = avg_reward_aps[_] 273 | dqn[_].save(results_dir, _) 274 | better_vote[_] = True 275 | elif avg_reward_aps[_] * args.better_indicator > metrics_aps[_]['best_avg_reward'] or better_sche: 276 | worse_vote[_] = True 277 | if np.sum(better_vote) >= np.ceil(len(env.accesspoint_list) / 3 * 2): 278 | if not np.sum(worse_vote) >= np.ceil(len(env.accesspoint_list) / 3 * 2): 279 | for _ in range(len(env.accesspoint_list)): 280 | dqn[_].reload_step_state_dict() 281 | else: 282 | if gp.ENABLE_MODEL_RELOAD: 283 | for _ in range(len(env.accesspoint_list)): 284 | dqn[_].reload_step_state_dict(False) 285 | better_aps = False 286 | # reload the state dict if obtain a better model 287 | 288 | for _ in range(len(env.accesspoint_list)): 289 | # Append to results and save metrics 290 | metrics_aps[_]['rewards'].append(T_rewards_aps[_]) 291 | metrics_aps[_]['Qs'].append(T_Qs_aps[_]) 292 | torch.save(metrics_aps[_], os.path.join(results_dir, 'metrics' + str(_) + '.pth')) 293 | 294 | for _ in range(len(env.accesspoint_list)): 295 | # Plot 296 | _plot_line(metrics_aps[_]['steps'], metrics_aps[_]['rewards'], 'Reward' + str(_), path=results_dir) 297 | _plot_line(metrics_aps[_]['steps'], metrics_aps[_]['Qs'], 'Q' + str(_), path=results_dir) 298 | 299 | # Return average reward and Q-value 300 | return (avg_reward_sche, avg_Q_sche, better_sche), (avg_reward_aps, avg_Q_aps, better_aps) 301 | 302 | 303 | # Plots min, max and mean + standard deviation bars of a population over time 304 | def _plot_line(xs, ys_population, title, path=''): 305 | max_colour, mean_colour, std_colour, transparent = 'rgb(0, 132, 180)', 'rgb(0, 172, 237)', 'rgba(29, 202, 255, 0.2)', 'rgba(0, 0, 0, 0)' 306 | 307 | ys = torch.tensor(ys_population, dtype=torch.float32) 308 | ys_min, ys_max, ys_mean, ys_std = ys.min(1)[0].squeeze(), ys.max(1)[0].squeeze(), ys.mean(1).squeeze(), ys.std( 309 | 1).squeeze() 310 | ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std 311 | 312 | trace_max = Scatter(x=xs, y=ys_max.numpy(), line=Line(color=max_colour, dash='dash'), name='Max') 313 | trace_upper = Scatter(x=xs, y=ys_upper.numpy(), line=Line(color=transparent), name='+1 Std. Dev.', showlegend=False) 314 | trace_mean = Scatter(x=xs, y=ys_mean.numpy(), fill='tonexty', fillcolor=std_colour, line=Line(color=mean_colour), 315 | name='Mean') 316 | trace_lower = Scatter(x=xs, y=ys_lower.numpy(), fill='tonexty', fillcolor=std_colour, line=Line(color=transparent), 317 | name='-1 Std. Dev.', showlegend=False) 318 | trace_min = Scatter(x=xs, y=ys_min.numpy(), line=Line(color=max_colour, dash='dash'), name='Min') 319 | 320 | plotly.offline.plot({ 321 | 'data': [trace_upper, trace_mean, trace_lower, trace_min, trace_max], 322 | 'layout': dict(title=title, xaxis={'title': 'Step'}, yaxis={'title': title}) 323 | }, filename=os.path.join(path, title + '.html'), auto_open=False) 324 | -------------------------------------------------------------------------------- /rainbow_hac/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #from __future__ import division 3 | import argparse 4 | import bz2 5 | from datetime import datetime 6 | import os 7 | import sys 8 | 9 | sys.path.append('../..') 10 | sys.path.append('./') 11 | 12 | import pickle 13 | import global_parameters as gp 14 | 15 | import numpy as np 16 | import math 17 | import copy 18 | import torch 19 | from tqdm import trange 20 | 21 | import multiprocessing 22 | import torch.multiprocessing 23 | # torch.multiprocessing.set_sharing_strategy('file_system') 24 | # TODO: When running in server, uncomment this line if needed 25 | import copy as cp 26 | 27 | from rainbow_hac.ap_agent import Agent 28 | from rainbow_hac.center_agent import CT_Agent as Controller 29 | from rainbow_hac.game import Decentralized_Game as Env 30 | from rainbow_hac.memory import ReplayMemory 31 | from rainbow_hac.test import test, test_p 32 | 33 | # Note that hyperparameters may originally be reported in ATARI game frames instead of agent steps 34 | parser = argparse.ArgumentParser(description='Rainbow') 35 | parser.add_argument('--id', type=str, default='default', help='Experiment ID') 36 | parser.add_argument('--seed', type=int, default=123, help='Random seed') 37 | parser.add_argument('--active-scheduler', action='store_false', help='Active scheduler') 38 | parser.add_argument('--active-accesspoint', action='store_false', help='Active AP') 39 | parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA') 40 | # parser.add_argument('--game', type=str, default='transmit-vr', choices=atari_py.list_games(), help='Environment game') 41 | parser.add_argument('--T-max', type=int, default=int(50e6), metavar='STEPS', 42 | help='Number of training steps (4x number of frames)') 43 | parser.add_argument('--max-episode-length', type=int, default=int(108e3), metavar='LENGTH', 44 | help='Max episode length in game frames (0 to disable)') 45 | # TODO: Note that the change of UAV numbers should also change the history-length variable 46 | parser.add_argument('--previous-action-observable', action='store_true', help='Observe previous action? (AP)') 47 | parser.add_argument('--history-length-accesspoint', type=int, default=2, metavar='T', 48 | help='Total number of history state') 49 | parser.add_argument('--history-length-scheduler', type=int, default=1, metavar='T', 50 | help='Total number of history state') 51 | parser.add_argument('--state-dims', type=int, default=gp.NUM_OF_UAV * gp.OBSERVATION_DIMS, metavar='S', 52 | help='Total number of dims in consecutive states processed, UAV * 3 for current version') 53 | parser.add_argument('--dense-of-uav', type=int, default=gp.NUM_OF_UAV, metavar='UAV', 54 | help='Total number of UAVs') 55 | parser.add_argument('--user-cluster-scale', type=int, default=gp.UE_SCALE, metavar='UAV', 56 | help='Total number of UAVs') 57 | parser.add_argument('--architecture', type=str, default='canonical_4uav_61obv_3x3_mix', 58 | choices=['canonical_2uav_61obv_3x3_mix', 'canonical_4uav_61obv_3x3_mix', 59 | 'canonical_2uav_61obv_2x2_mix', 'canonical_4uav_61obv_2x2_mix', 60 | 'canonical_2uav_61obv_3x3', 'canonical_4uav_61obv_3x3', 61 | 'canonical_2uav_61obv_2x2', 'canonical_4uav_61obv_2x2', 62 | 'canonical_2uav_41obv_3x3_mix', 'canonical_4uav_41obv_3x3_mix', 63 | 'canonical_2uav_41obv_2x2_mix', 'canonical_4uav_41obv_2x2_mix', 64 | 'canonical_2uav_41obv_3x3', 'canonical_4uav_41obv_3x3', 65 | 'canonical_2uav_41obv_2x2', 'canonical_4uav_41obv_2x2', 66 | 'data-efficient', 'data-efficient_4uav_61obv_3x3_mix'], 67 | metavar='ARCH', help='Network architecture') 68 | # TODO: if select resnet8, obs v8 and dims 4 should be set in gp 69 | parser.add_argument('--hidden-size', type=int, default=512, metavar='SIZE', help='Network hidden size') 70 | parser.add_argument('--noisy-std', type=float, default=0.5, metavar='σ', 71 | help='Initial standard deviation of noisy linear layers') 72 | parser.add_argument('--noisy-std-controller-exploration', type=float, default=0.5, metavar='σ', 73 | help='Initial standard deviation of noisy linear layers') 74 | parser.add_argument('--atoms', type=int, default=51, metavar='C', help='Discretised size of value distribution') 75 | parser.add_argument('--atoms-sche', type=int, default=21, metavar='C', help='Discretised size of value distribution') 76 | parser.add_argument('--V-min', type=float, default=-2, metavar='V', help='Minimum of value distribution support') 77 | parser.add_argument('--V-max', type=float, default=2, metavar='V', help='Maximum of value distribution support') 78 | # TODO: Make sure the value located inside V_min and V_max 79 | parser.add_argument('--epsilon-min', type=float, default=0.0, metavar='ep_d', help='Minimum of epsilon') 80 | parser.add_argument('--epsilon-max', type=float, default=0.0, metavar='ep_u', help='Maximum of epsilon') 81 | parser.add_argument('--epsilon-delta', type=float, default=0.0001, metavar='ep_d', help='Decreasing step of epsilon') 82 | # TODO: Set the ep carefully 83 | parser.add_argument('--action-selection', type=str, default='greedy', metavar='action_type', 84 | choices=['greedy', 'boltzmann', 'no_limit'], 85 | help='Type of action selection algorithm, 1: greedy, 2: boltzmann') 86 | parser.add_argument('--model', type=str, default=None, metavar='PARAM', help='Pretrained model (state dict)') 87 | parser.add_argument('--memory-capacity-accesspoint', type=int, default=int(12e3), metavar='CAPACITY', 88 | help='Experience replay memory capacity') 89 | parser.add_argument('--memory-capacity-scheduler', type=int, default=int(12e3), metavar='CAPACITY', 90 | help='Experience replay memory capacity') 91 | parser.add_argument('--replay-frequency', type=int, default=4, metavar='k', help='Frequency of sampling from memory') 92 | parser.add_argument('--replay-frequency-scheduler', type=int, default=4, metavar='k', 93 | help='Frequency of sampling from memory') 94 | parser.add_argument('--priority-exponent', type=float, default=0.5, metavar='ω', 95 | help='Prioritised experience replay exponent (originally denoted α)') 96 | parser.add_argument('--priority-weight', type=float, default=0.4, metavar='β', 97 | help='Initial prioritised experience replay importance sampling weight') 98 | parser.add_argument('--multi-step-accesspoint', type=int, default=3, metavar='n', 99 | help='Number of steps for multi-step return') 100 | parser.add_argument('--multi-step-scheduler', type=int, default=3, metavar='n', 101 | help='Number of steps for multi-step return') 102 | parser.add_argument('--discount', type=float, default=0.9, metavar='γ', help='Discount factor') 103 | parser.add_argument('--target-update', type=int, default=int(8000), metavar='τ', 104 | help='Number of steps after which to update target network') 105 | parser.add_argument('--reward-clip', type=int, default=1, metavar='VALUE', help='Reward clipping (0 to disable)') 106 | parser.add_argument('--learning-rate', type=float, default=0.0000625, metavar='η', help='Learning rate') 107 | parser.add_argument('--adam-eps', type=float, default=1.5e-4, metavar='ε', help='Adam epsilon') 108 | parser.add_argument('--batch-size', type=int, default=32, metavar='SIZE', help='Batch size') 109 | parser.add_argument('--better-indicator', type=float, default=1.0, metavar='b', 110 | help='The new model should be b times of old performance to be recorded') 111 | # TODO: Switch interval should not be large 112 | parser.add_argument('--learn-start', type=int, default=int(1000), metavar='STEPS', 113 | help='Number of steps before starting training') 114 | parser.add_argument('--learn-start-scheduler', type=int, default=int(1000), metavar='STEPS', 115 | help='Number of steps before starting training') 116 | parser.add_argument('--evaluate', action='store_true', help='Evaluate only') 117 | parser.add_argument('--data-reinforce', action='store_true', help='DataReinforcement') 118 | # TODO: Change this after debug 119 | parser.add_argument('--evaluation-interval', type=int, default=500, metavar='STEPS', 120 | help='Number of training steps between evaluations') 121 | parser.add_argument('--evaluation-episodes', type=int, default=20000, metavar='N', 122 | help='Number of evaluation episodes to average over') 123 | # TODO: Note that DeepMind's evaluation method is running the latest agent for 500K frames ever every 1M steps 124 | # TODO: Change this after debug 125 | parser.add_argument('--evaluation-size', type=int, default=20, metavar='N', 126 | help='Number of transitions to use for validating Q') 127 | # TODO: This evaluation-size is used for Q value evaluation, can be small if Q is not important 128 | parser.add_argument('--render', action='store_false', help='Display screen (testing only)') 129 | parser.add_argument('--enable-cudnn', action='store_true', help='Enable cuDNN (faster but nondeterministic)') 130 | parser.add_argument('--checkpoint-interval', default=0, 131 | help='How often to checkpoint the model, defaults to 0 (never checkpoint)') 132 | parser.add_argument('--memory', help='Path to save/load the memory from') 133 | parser.add_argument('--disable-bzip-memory', action='store_false', 134 | help='Don\'t zip the memory file. Not recommended (zipping is a bit slower and much, much smaller)') 135 | # TODO: Change federated round each time 136 | parser.add_argument('--federated-round', type=int, default=20, metavar='F', 137 | help='Rounds to perform global combination, set a negative number to disable federated aggregation') 138 | 139 | # Setup 140 | args = parser.parse_args() 141 | 142 | gp.UE_SCALE = args.user_cluster_scale 143 | 144 | print(' ' * 26 + 'Options') 145 | for k, v in vars(args).items(): 146 | print(' ' * 26 + k + ': ' + str(v)) 147 | results_dir = os.path.join('./results', args.id) 148 | if not os.path.exists(results_dir): 149 | os.makedirs(results_dir) 150 | 151 | metrics = {'steps': [], 'rewards': [], 'Qs': [], 'best_avg_reward': -float('inf')} 152 | np.random.seed(args.seed) 153 | torch.manual_seed(np.random.randint(1, 10000)) 154 | # if torch.cuda.is_available() and not args.disable_cuda: 155 | # args.device = torch.device('cuda') 156 | # torch.cuda.manual_seed(np.random.randint(1, 10000)) 157 | # torch.backends.cudnn.enabled = args.enable_cudnn 158 | # else: 159 | # args.device = torch.device('cpu') 160 | args.device = torch.device('cpu') 161 | 162 | 163 | # Simple ISO 8601 timestamped logger 164 | def log(s): 165 | print('[' + str(datetime.now().strftime('%Y-%m-%dT%H:%M:%S')) + '] ' + s) 166 | 167 | 168 | def average_weights(list_of_weight): 169 | """aggregate all weights""" 170 | averga_w = copy.deepcopy(list_of_weight[0]) 171 | for key in averga_w.keys(): 172 | for ind in range(1, len(list_of_weight)): 173 | averga_w[key] += list_of_weight[ind][key] 174 | averga_w[key] = torch.div(averga_w[key], len(list_of_weight)) 175 | return averga_w 176 | 177 | 178 | def load_memory(memory_path, disable_bzip): 179 | if disable_bzip: 180 | with open(memory_path, 'rb') as pickle_file: 181 | return pickle.load(pickle_file) 182 | else: 183 | with bz2.open(memory_path, 'rb') as zipped_pickle_file: 184 | return pickle.load(zipped_pickle_file) 185 | 186 | 187 | def save_memory(memory, memory_path, disable_bzip, scheduller_or_ap, index=-1): 188 | if not scheduller_or_ap: 189 | # save ap mem 190 | memory_path = memory_path[0:-4] + '_aps_' + str(index) + memory_path[-4:] 191 | else: 192 | memory_path = memory_path[0:-4] + '_sche' + memory_path[-4:] 193 | if disable_bzip: 194 | with open(memory_path, 'wb') as pickle_file: 195 | pickle.dump(memory, pickle_file) 196 | else: 197 | with bz2.open(memory_path, 'wb') as zipped_pickle_file: 198 | pickle.dump(memory, zipped_pickle_file) 199 | 200 | 201 | def run_game_once_parallel_random(new_game, train_history_sche_parallel, train_history_aps_parallel, episode): 202 | train_examples_aps = [] 203 | train_examples_sche = [] 204 | for _ in range(number_of_aps): 205 | train_examples_aps.append([]) 206 | eps, _, done_pp, _ = 0, None, True, None 207 | while eps < episode: 208 | if done_pp: 209 | _, done_pp = new_game.reset(), False 210 | 211 | sche_pack_p, aps_pack_p, done_pp = new_game.step(np.random.rand(scheduling_size[0]), [np.random.randint(0, action_space) 212 | for _ in range(number_of_aps)], True) # Step 213 | 214 | for index_p, ele_p in enumerate(aps_pack_p): 215 | train_examples_aps[index_p].append((ele_p, None, None, done_pp)) 216 | 217 | train_examples_sche.append((sche_pack_p, None, None, done_pp)) 218 | eps += 1 219 | train_history_aps_parallel.append(train_examples_aps) 220 | train_history_sche_parallel.append(train_examples_sche) 221 | 222 | 223 | # Environment 224 | env = Env(args) 225 | env.reset() 226 | action_space = env.get_action_size() 227 | scheduling_size = env.get_resource_action_space() 228 | number_of_aps = len(env.accesspoint_list) 229 | 230 | # Controller 231 | controller = Controller(args, env) 232 | 233 | # Agent 234 | dqn = [] 235 | matric = [] 236 | for _ in range(number_of_aps): 237 | # dqn.append(temp) 238 | dqn.append(Agent(args, env, _)) 239 | matric.append(copy.deepcopy(metrics)) 240 | 241 | if args.federated_round > 0: 242 | global_model = Agent(args, env, "Global_") 243 | 244 | # If a model is provided, and evaluate is fale, presumably we want to resume, so try to load memory 245 | if args.model is not None and not args.evaluate: 246 | if not args.memory: 247 | raise ValueError('Cannot resume training without memory save path. Aborting...') 248 | elif not os.path.exists(args.memory): 249 | raise ValueError('Could not find memory file at {path}. Aborting...'.format(path=args.memory)) 250 | 251 | mem_aps = [] 252 | for index in range(number_of_aps): 253 | path = os.path.join(args.memory, ('metrics_aps' + str(index) + '.pth')) 254 | mem_aps.append(load_memory(path, args.disable_bzip_memory)) 255 | path = os.path.join(args.memory, ('metrics_sche' + '.pth')) 256 | mem_sche = load_memory(path, args.disable_bzip_memory) 257 | 258 | else: 259 | mem_aps = [] 260 | for _ in range(number_of_aps): 261 | mem_aps.append(ReplayMemory(args, args.memory_capacity_accesspoint, True, env.remove_previous_action)) 262 | mem_sche = ReplayMemory(args, args.memory_capacity_scheduler, False) 263 | 264 | priority_weight_increase = (1 - args.priority_weight) / (args.T_max - args.learn_start) 265 | 266 | # Construct validation memory 267 | val_mem_aps = [] 268 | val_mem_sche = ReplayMemory(args, args.evaluation_size, False) 269 | for _ in range(number_of_aps): 270 | val_mem_aps.append(ReplayMemory(args, args.evaluation_size, True, env.remove_previous_action)) 271 | if not gp.PARALLEL_EXICUSION: 272 | T, done = 0, True 273 | while T < args.evaluation_size: 274 | if done: 275 | _, done = env.reset(), False 276 | 277 | sche_pack, aps_pack, done = env.step(np.random.rand(scheduling_size[0]), 278 | [np.random.randint(0, action_space) 279 | for _ in range(number_of_aps)], True) 280 | val_mem_sche.append(sche_pack, None, None, done) 281 | for index, ele in enumerate(aps_pack): 282 | val_mem_aps[index].append(ele, None, None, done) 283 | T += 1 284 | else: 285 | num_cores = min(multiprocessing.cpu_count(), gp.ALLOCATED_CORES) - 1 286 | num_eps = math.ceil(math.ceil(args.evaluation_size / num_cores) / 287 | (gp.GOP * gp.DEFAULT_RESOURCE_BLOCKNUM)) * (gp.GOP * gp.DEFAULT_RESOURCE_BLOCKNUM) 288 | # make sure each subprocess can finish all the game (end with done) 289 | with multiprocessing.Manager() as manager: 290 | train_history_sche = manager.list() 291 | train_history_aps = manager.list() 292 | 293 | process_list = [] 294 | for _ in range(num_cores): 295 | process = multiprocessing.Process(target=run_game_once_parallel_random, 296 | args=(cp.deepcopy(env), train_history_sche, 297 | train_history_aps, num_eps)) 298 | process_list.append(process) 299 | 300 | for pro in process_list: 301 | pro.start() 302 | for pro in process_list: 303 | pro.join() 304 | pro.terminate() 305 | 306 | for res in train_history_aps: 307 | for index, memerys in enumerate(res): 308 | for state, _, _, done in memerys: 309 | val_mem_aps[index].append(state, None, None, done) 310 | for memorys in train_history_sche: 311 | for state, _, _, done in memorys: 312 | val_mem_sche.append(state, None, None, done) 313 | 314 | if args.evaluate: 315 | controller.eval() 316 | for index in range(number_of_aps): 317 | dqn[index].eval() # Set DQN (online network) to evaluation mode 318 | avg_reward, avg_Q = test(args, 0, controller, dqn, val_mem_aps, matric, val_mem_sche, metrics, results_dir, evaluate=True) # Test 319 | for index in range(number_of_aps): 320 | print('Avg. reward for ap' + str(index) + ': ' + str(avg_reward[index]) + ' | Avg. Q: ' + str(avg_Q[index])) 321 | else: 322 | # Training loop 323 | T, aps_state, done, sche_state, epsilon = 0, None, True, None, args.epsilon_max 324 | reinforce_ap = [] 325 | for i in range(len(env.accesspoint_list)): 326 | temp = [] 327 | for j in range(3): 328 | temp.append([]) 329 | reinforce_ap.append(temp) 330 | reinforce_sche = [] 331 | for i in range(3): 332 | reinforce_sche.append([]) 333 | for T in trange(1, args.T_max + 1): 334 | # training loop 335 | if done: 336 | if T > 2: 337 | print(env.get_finial_reward()) 338 | sche_state, done = env.reset(), False 339 | if T > 1 and args.data_reinforce: 340 | for sche_pair in reinforce_sche: 341 | for sche_ele in sche_pair: 342 | mem_sche.append(sche_ele[0], sche_ele[1], sche_ele[2], sche_ele[3]) 343 | for index, ap_rein in enumerate(reinforce_ap): 344 | for ap_pair in ap_rein: 345 | for ap_ele in ap_pair: 346 | mem_aps[index].append(ap_ele[0], ap_ele[1], ap_ele[2], ap_ele[3]) 347 | reinforce_ap = [] 348 | for i in range(len(env.accesspoint_list)): 349 | temp = [] 350 | for j in range(3): 351 | temp.append([]) 352 | reinforce_ap.append(temp) 353 | reinforce_sche = [] 354 | for i in range(3): 355 | reinforce_sche.append([]) 356 | 357 | if T % args.replay_frequency == 0: 358 | controller.reset_noise() # Draw a new set of noisy weights 359 | for _ in range(number_of_aps): 360 | dqn[_].reset_noise() 361 | 362 | sche_pack, aps_pack, done = env.step(controller, dqn, False, epsilon) # Step 363 | epsilon = epsilon - args.epsilon_delta 364 | epsilon = np.clip(epsilon, a_min=args.epsilon_min, a_max=args.epsilon_max) 365 | 366 | if gp.ENABLE_EARLY_STOP: 367 | if env.center_server.clock % gp.DEFAULT_RESOURCE_BLOCKNUM == (gp.DEFAULT_RESOURCE_BLOCKNUM - 1): 368 | if env.center_server.obtain_centerlized_linear_reward() < gp.ENABLE_EARLY_STOP_THRESHOLD: 369 | done = True 370 | 371 | reward_sche = sche_pack[2] 372 | if args.reward_clip > 0: 373 | reward_sche = torch.clamp(reward_sche, max=args.reward_clip, min=-args.reward_clip) # Clip rewards 374 | mem_sche.append(sche_pack[0], sche_pack[1], reward_sche, done) # Append transition to memory 375 | 376 | reward_aps = aps_pack[2] 377 | for _ in range(number_of_aps): 378 | if args.reward_clip > 0: 379 | reward_aps[_] = torch.clamp(reward_aps[_], max=args.reward_clip, min=-args.reward_clip) # Clip rewards 380 | if not aps_pack[1][_] == -1: 381 | mem_aps[_].append(aps_pack[0][_], aps_pack[1][_], reward_aps[_], done) # Append transition to memory 382 | for direction in range(3): 383 | obs = aps_pack[0][_] 384 | if gp.OBSERVATION_VERSION <= 7: 385 | res = [] 386 | rot_obs = torch.split(obs, int(obs.shape[1] / (gp.OBSERVATION_DIMS * gp.NUM_OF_UAV)), dim=1) 387 | for index, ele in enumerate(rot_obs): 388 | res.append(torch.rot90(ele, direction+1, (1, 2))) 389 | obs = torch.cat(res, dim=1) 390 | if gp.OBSERVATION_VERSION == 8: 391 | obs = torch.rot90(obs, direction, (2, 3)) 392 | if not aps_pack[1][_] == -1: 393 | reinforce_ap[_][direction].append((obs, aps_pack[1][_], reward_aps[_], done)) 394 | # append rotated observation for data reinforcement 395 | obs = sche_pack[0][-1] 396 | rot_obs = list(torch.split(obs, gp.UAV_FIELD_OF_VIEW[1], dim=1)) 397 | res = [] 398 | for ele in rot_obs: 399 | if gp.GOP >= 2: 400 | res.append(torch.stack(torch.split(ele, gp.UAV_FIELD_OF_VIEW[0] * 2, dim=0))) 401 | else: 402 | res.append(torch.stack(torch.split(ele, gp.UAV_FIELD_OF_VIEW[0], dim=0))) 403 | res = torch.stack(res) 404 | res = torch.split(res, int(math.ceil(gp.LENGTH_OF_FIELD / gp.DENSE_OF_ACCESSPOINT)), dim=0) 405 | for direction in range(3): 406 | result = [] 407 | for ele in res: 408 | temp = torch.rot90(ele, direction+1, (0, 1)).reshape( 409 | int(math.ceil(gp.LENGTH_OF_FIELD / gp.DENSE_OF_ACCESSPOINT)), -1, gp.UAV_FIELD_OF_VIEW[1]) 410 | result.append(torch.cat([temp[each, :, :] for each in range(temp.shape[0])], dim=1)) 411 | result = torch.cat(result, dim=1).unsqueeze(0) 412 | reinforce_sche[direction].append((result, sche_pack[1], reward_sche, done)) 413 | # append rotated observation for data reinforcement 414 | 415 | # Train and test 416 | if T >= args.learn_start_scheduler: 417 | mem_sche.priority_weight = min(mem_sche.priority_weight + priority_weight_increase, 1) 418 | # Anneal importance sampling weight β to 1 419 | 420 | if T % args.replay_frequency_scheduler == 0: 421 | controller.learn(mem_sche) # Train with n-step distributional double-Q learning 422 | 423 | # If memory path provided, save it 424 | if args.memory is not None: 425 | save_memory(mem_sche, args.memory, args.disable_bzip_memory, True) 426 | 427 | # Update target network 428 | if T % args.target_update == 0: 429 | controller.update_target_net() 430 | 431 | # Checkpoint the network 432 | if (args.checkpoint_interval != 0) and (T % args.checkpoint_interval == 0): 433 | controller.save(results_dir, 'checkpoint_controller' + '.pth') 434 | 435 | if T >= args.learn_start: 436 | for index in range(number_of_aps): 437 | mem_aps[index].priority_weight = min(mem_aps[index].priority_weight + priority_weight_increase, 1) 438 | # Anneal importance sampling weight β to 1 439 | 440 | if T % args.replay_frequency == 0: 441 | for index in range(number_of_aps): 442 | dqn[index].learn(mem_aps[index]) # Train with n-step distributional double-Q learning 443 | 444 | if T % args.federated_round == 0 and 0 < args.federated_round: 445 | global_weight = average_weights([model.get_state_dict() for model in dqn]) 446 | global_model.set_state_dict(global_weight) 447 | log('T = ' + str(T) + ' / ' + str(args.T_max) + ' Global averaging starts') 448 | global_model.save(results_dir, 'Global_') 449 | for models in dqn: 450 | models.set_state_dict(global_weight) 451 | 452 | # If memory path provided, save it 453 | for index in range(number_of_aps): 454 | if args.memory is not None: 455 | save_memory(mem_aps[index], args.memory, args.disable_bzip_memory, False, index) 456 | 457 | # Update target network 458 | if T % args.target_update == 0: 459 | for index in range(number_of_aps): 460 | dqn[index].update_target_net() 461 | 462 | # Checkpoint the network 463 | if (args.checkpoint_interval != 0) and (T % args.checkpoint_interval == 0): 464 | for index in range(number_of_aps): 465 | dqn[index].save(results_dir, 'checkpoint' + str(index) + '.pth') 466 | 467 | if T % args.evaluation_interval == 0 and T > args.learn_start_scheduler and T > args.learn_start: 468 | controller.eval() # Set DQN (online network) to evaluation mode 469 | for index in range(number_of_aps): 470 | dqn[index].eval() # Set DQN (online network) to evaluation mode 471 | 472 | if gp.PARALLEL_EXICUSION: 473 | sche_pack, aps_pack = test_p(args, T, controller, dqn, val_mem_aps, matric, val_mem_sche, 474 | metrics, results_dir) # Test 475 | else: 476 | sche_pack, aps_pack = test(args, T, controller, dqn, val_mem_aps, matric, val_mem_sche, 477 | metrics, results_dir) # Test 478 | if sche_pack[2]: 479 | log('T = ' + str(T) + ' / ' + str(args.T_max) + ' Better model, accepted.') 480 | else: 481 | # mem_sche.expand_memory() 482 | log('T = ' + str(T) + ' / ' + str(args.T_max) + ' Worse model, reject.') 483 | log('T = ' + str(T) + ' / ' + str(args.T_max) + ' For controller' 484 | + ' | Avg. reward: ' + str(sche_pack[0]) + ' | Avg. Q: ' + str(sche_pack[1])) 485 | 486 | if aps_pack[2]: 487 | log('T = ' + str(T) + ' / ' + str(args.T_max) + ' Better model, accepted.') 488 | else: 489 | log('T = ' + str(T) + ' / ' + str(args.T_max) + ' Worse model, reject.') 490 | for index in range(number_of_aps): 491 | log('T = ' + str(T) + ' / ' + str(args.T_max) + ' For ap' + str(index) + 492 | ' | Avg. reward: ' + str(aps_pack[0][index]) + ' | Avg. Q: ' + str(aps_pack[1][index])) 493 | 494 | controller.train() # Set DQN (online network) back to training mode 495 | for index in range(number_of_aps): 496 | dqn[index].train() # Set DQN (online network) back to training mode 497 | 498 | env.close() 499 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch~=1.4.0 2 | numpy~=1.18.1 3 | tensorflow~=2.1.0 4 | tensorflow-gpu~=2.0.0 5 | tqdm~=4.41.1 6 | joblib~=0.14.1 7 | matplotlib~=3.1.3 8 | scipy~=1.4.1 9 | termcolor~=1.1.0 10 | tabulate~=0.8.3 11 | gym~=0.15.6 12 | plotly~=4.4.1 13 | pip~=20.0.2 14 | zlib~=1.2.11 15 | wheel~=0.34.2 16 | openssl~=1.1.1d 17 | cryptography~=2.5 18 | py~=1.8.1 -------------------------------------------------------------------------------- /user_correlation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import typing 3 | import global_parameters as gp 4 | import copy as cp 5 | 6 | 7 | # not progressive raw video 8 | 9 | class VR_Sphere: 10 | __slots__ = ['sphere_id', 'num_of_tile', 'tiles', 'transmission_mask', 'resource', 'its_resource'] 11 | 12 | def __init__(self, sphere_id: int, tiles: list, num_of_tile: int, resources: np.ndarray): 13 | self.sphere_id: int = int(sphere_id) 14 | self.num_of_tile = num_of_tile 15 | self.tiles = tiles # size of fov [x, y] 16 | self.transmission_mask = np.ones(self.num_of_tile, dtype=bool) 17 | if resources.size == 0: 18 | self.resource = np.arange(0, self.tiles[0] * self.tiles[1]) 19 | else: 20 | if type(resources).__module__ != np.__name__: 21 | raise TypeError("Resource input should be numpy array") 22 | if self.num_of_tile != resources.__len__(): 23 | raise ValueError("Size of resources not equal") 24 | if len(resources.shape) != 1: 25 | raise ValueError("Resource dimension should be 1") 26 | self.resource = resources 27 | self.its_resource = self.resource 28 | 29 | def __eq__(self, other): 30 | if self.sphere_id == other.sphere_id: 31 | if (np.equal(self.resource, other.resource)).all() and \ 32 | (np.equal(self.transmission_mask, other.transmission_mask)).all(): 33 | return True 34 | return False 35 | 36 | def __sub__(self, other): 37 | if self.sphere_id != other.sphere_id: 38 | return 0 39 | return np.intersect1d(self.resource[self.transmission_mask], 40 | other.resource[other.transmission_mask]).__len__() 41 | 42 | def __add__(self, other): 43 | if self.sphere_id != other.sphere_id: 44 | return 0 45 | return np.union1d(self.resource[self.transmission_mask], 46 | other.resource[other.transmission_mask]).__len__() 47 | 48 | def __truediv__(self, other): 49 | if self.sphere_id != other.sphere_id: 50 | return self 51 | self.resource = np.setdiff1d(self.resource[self.transmission_mask], 52 | other.resource[other.transmission_mask]) 53 | self.num_of_tile -= np.intersect1d(self.resource[self.transmission_mask], 54 | other.resource[other.transmission_mask]).__len__() 55 | 56 | def __mul__(self, other): 57 | if self.sphere_id != other.sphere_id: 58 | return self 59 | self.resource = np.union1d(self.resource[self.transmission_mask], 60 | other.resource[other.transmission_mask]) 61 | self.num_of_tile = self.resource.__len__() 62 | 63 | def __mod__(self, other): 64 | self.sphere_id = other.sphere_id 65 | self.resource = other.resource 66 | self.num_of_tile = other.resource.__len__() 67 | self.transmission_mask = other.transmission_mask 68 | 69 | def __str__(self): 70 | return "VR Resource at " + str(self.sphere_id) + "\n Resources: " + str(self.resource[self.transmission_mask]) 71 | 72 | 73 | class Field_of_View(VR_Sphere): 74 | __slots__ = ['sphere_id', 'num_of_tile', 'tiles', 'transmission_mask', 'resource', 75 | 'center_resource', 'its_resource'] 76 | 77 | def __init__(self, center_resource, resource_list, sphere_id, tiles): 78 | super(Field_of_View, self).__init__(sphere_id, tiles, resource_list.__len__(), resource_list) 79 | self.center_resource = center_resource 80 | 81 | def __mod__(self, other): 82 | super(Field_of_View, self).__mod__(other) 83 | self.center_resource = other.center_resource 84 | 85 | def __eq__(self, other): 86 | if super(Field_of_View, self).__eq__(other) and self.center_resource == other.center_resource: 87 | return True 88 | return False 89 | 90 | 91 | class Clustering: 92 | __slots__ = ['correlation_matrix', 'cluster_method', 'cluster_num', 'cluster_threshold', 'cluster_result'] 93 | 94 | def __init__(self, correlation_matrix, cluster_method=None, cluster_threshold=None, cluster_num=None): 95 | # correlation_matrix should be a two dim narray 96 | self.correlation_matrix = np.asarray(correlation_matrix) 97 | if correlation_matrix.shape[0] != correlation_matrix.shape[1]: 98 | raise ValueError("correlation matrix should be square") 99 | if cluster_method is None or cluster_threshold is None: 100 | self.cluster_method = "k-mean" 101 | self.cluster_num = 4 102 | self.cluster_threshold = 0.05 # 1 correlation with 20 meters distance 103 | else: 104 | self.cluster_method = cluster_method 105 | self.cluster_num = cluster_num 106 | self.cluster_threshold = cluster_threshold 107 | self.cluster_method = cluster_method 108 | self.correlation_matrix = correlation_matrix 109 | for index in range(0, self.correlation_matrix.shape[0]): 110 | self.correlation_matrix[index][index] = 0 111 | self.cluster_result = [] 112 | 113 | def cluster(self): 114 | self.correlation_matrix = np.where(self.correlation_matrix <= 115 | self.cluster_threshold, 0, self.correlation_matrix) 116 | max_clique = [] 117 | temp_node_list = np.array([index for index in range(0, self.correlation_matrix.shape[0]) 118 | if not (self.correlation_matrix[index] == 0).all()]) 119 | clusted_user_list = cp.copy(temp_node_list) 120 | if self.cluster_method == "PrivotingBK": 121 | while temp_node_list.shape[0] != 0: 122 | result = [] 123 | degency = self.degeneracy_ordering(temp_node_list) 124 | self.degeneracy_bk(degency, result) 125 | local_max_cli: np.ndarray = max(result, key=lambda p: p.shape[0]) 126 | max_clique.append(local_max_cli) 127 | temp = local_max_cli 128 | result = [ele for ele in result if (np.intersect1d(ele, temp)).shape[0] == 0] 129 | while len(result) != 0: 130 | local_max_cli = max(result, key=lambda p: p.shape[0]) 131 | temp = np.union1d(temp, local_max_cli) 132 | result = [ele for ele in result if (np.intersect1d(ele, temp)).shape[0] == 0] 133 | max_clique.append(local_max_cli) 134 | temp_node_list = np.setdiff1d(temp_node_list, temp) 135 | self.cluster_result = max_clique 136 | if self.cluster_method == "PrivotingBK_greedy": 137 | while temp_node_list.shape[0] != 0: 138 | result = self.greedy_bk(temp_node_list) 139 | temp_node_list = np.setdiff1d(temp_node_list, result) 140 | max_clique.append(result) 141 | self.cluster_result = max_clique 142 | if self.cluster_method == "k-mean": 143 | raise TypeError("Haven't implement k-mean!!") 144 | return clusted_user_list 145 | # Return node list which is being clustered, for additional clustering calculation. 146 | # in case of existing solo node which shows all 0 in correlation matrix 147 | # Clustering users 148 | 149 | def update_correlation_matrix(self, new_matrix): 150 | self.correlation_matrix = new_matrix 151 | 152 | def degeneracy_ordering(self, list_of_node): 153 | degeneracy = [] 154 | degree = np.negative(np.ones(self.correlation_matrix.shape[0])) 155 | for ues in range(0, self.correlation_matrix.shape[0]): 156 | if ues in list_of_node: 157 | degree[ues] = np.sum(np.where(self.correlation_matrix[ues, :] != 0, 1, 0)) 158 | else: 159 | degree[ues] = np.Inf 160 | for _ in list_of_node: 161 | minimum_degree = np.argmin(degree) 162 | if minimum_degree not in list_of_node: 163 | raise ValueError("Out of list range.") 164 | degree[np.nonzero(np.where(self.correlation_matrix[minimum_degree, :] != 0, 1, 0))] -= 1 165 | degeneracy.append(minimum_degree) 166 | degree[minimum_degree] = np.Inf 167 | return degeneracy 168 | 169 | def degeneracy_bk(self, degeneracy, result): 170 | p = degeneracy 171 | x = np.array([]) 172 | for index, vertex in enumerate(degeneracy): 173 | neighbor = np.nonzero(np.where(self.correlation_matrix[vertex, :] != 0, 1, 0)) 174 | r = np.array([vertex]) 175 | self.privoting_bk(r, np.intersect1d(p, neighbor), np.intersect1d(x, neighbor), result) 176 | p = np.setdiff1d(p, vertex) 177 | x = np.union1d(x, vertex) 178 | 179 | def privoting_bk(self, r, p, x, result): 180 | if len(p) == 0 and len(x) == 0: 181 | result.append(r) 182 | return 183 | maximum_neighbor_num = 0 184 | maximum_neighbor = np.array([]) 185 | for index in p: 186 | num_neighbor = np.sum(np.where(self.correlation_matrix[index, :] != 0, 1, 0)) 187 | if num_neighbor > maximum_neighbor_num: 188 | maximum_neighbor_num = num_neighbor 189 | maximum_neighbor = np.nonzero(np.where(self.correlation_matrix[index, :] != 0, 1, 0)) 190 | for vertex in np.setdiff1d(p, maximum_neighbor): 191 | self.privoting_bk(np.union1d(r, vertex), 192 | np.intersect1d(p, np.nonzero(np.where(self.correlation_matrix[vertex, :] != 0, 1, 0))), 193 | np.intersect1d(x, np.nonzero(np.where(self.correlation_matrix[vertex, :] != 0, 1, 0))), 194 | result) 195 | p = np.setdiff1d(p, vertex) 196 | x = np.union1d(x, vertex) 197 | 198 | def greedy_bk(self, list_of_nodes): 199 | clique = np.array([], dtype=np.int) 200 | vertices = list_of_nodes 201 | rand = np.random.randint(len(vertices), size=1) 202 | clique = np.append(clique, vertices[rand]) 203 | neighbor = [] 204 | for index in range(0, self.correlation_matrix.shape[0]): 205 | neighbor.append(np.nonzero(np.where(self.correlation_matrix[index, :] != 0, 1, 0))[0]) 206 | for v in vertices: 207 | if v in clique: 208 | continue 209 | is_next = True 210 | for u in clique: 211 | if u in neighbor[v]: 212 | continue 213 | else: 214 | is_next = False 215 | break 216 | if is_next: 217 | clique = np.append(clique, v) 218 | return np.sort(clique) 219 | 220 | def get_cluster_result(self): 221 | return self.cluster_result 222 | 223 | 224 | class User_VR(Field_of_View): 225 | __slots__ = ['sphere_id', 'num_of_tile', 'tiles', 'transmission_mask', 'resource', 'center_resource', 226 | 'id', 'position', 'mobility_range', 'original_source', 'sizeof_fov', 'resource_size', 'its_resource', 227 | 'in_range_ap'] 228 | 229 | def __init__(self, position: np.ndarray, mobility_range, user_index, center_resource: np.ndarray, 230 | original_source: VR_Sphere, ap_list, sizeof_fov=None, resource=None): 231 | # def __init__(self, center_resource, resource_list, sphere_id, tiles): 232 | # tiles: size of tiles [x,y], list 233 | # center_resource: postion of center tile: [x,y], list 234 | # resource_list: resource of list: numpy array (x,) 235 | self.id: int = int(user_index) 236 | self.position = position 237 | self.mobility_range = mobility_range 238 | self.original_source = original_source 239 | self.in_range_ap = [] 240 | self.calculate_range(ap_list) 241 | 242 | if center_resource[0] * center_resource[1] >= original_source.num_of_tile: 243 | raise ValueError("No such resource") 244 | self.center_resource = center_resource 245 | 246 | if sizeof_fov is None: 247 | self.sizeof_fov = gp.USER_FIELD_OF_VIEW # if no specific fov, use default fov 7x5 with 30x30 degree block 248 | elif sizeof_fov[0] % 2 == 0 or sizeof_fov[1] % 2 == 0: 249 | raise ValueError("FoV should be odd numbers") 250 | else: 251 | self.sizeof_fov = sizeof_fov 252 | 253 | if resource is None: 254 | resource_list = self.distribute_resource(original_source) 255 | else: 256 | resource_list = np.array(resource) 257 | 258 | super(User_VR, self).__init__(self.center_resource, resource_list, original_source.sphere_id, self.sizeof_fov) 259 | self.resource_size = np.ones(self.num_of_tile) * gp.TILE_SIZE 260 | 261 | # def __deepcopy__(self, memo): 262 | # copied = User_VR(self.position, self.mobility_range, self.id, self.center_resource, self.original_source, 263 | # self.sizeof_fov, self.resource) 264 | # copied.transmission_mask = self.transmission_mask.copy() 265 | # copied.resource_size = np.copy(self.resource_size) 266 | # return copied 267 | 268 | def clock_tiktok(self, gop): 269 | res_resource = self.resource[self.transmission_mask] 270 | res_resource_size = np.array([]) 271 | if self.transmission_mask.any(): 272 | res_resource_index = np.concatenate(np.argwhere(self.transmission_mask), axis=0) 273 | res_resource_size = self.resource_size[res_resource_index] 274 | res_resource += gp.TOTAL_NUM_TILES 275 | self.resource = np.concatenate((self.its_resource, res_resource)) 276 | self.num_of_tile = len(self.resource) 277 | self.transmission_mask = np.ones(self.num_of_tile, dtype=bool) 278 | 279 | self.resource_size = np.concatenate((np.ones(len(self.its_resource)) * gp.GOP_TILE_SIZE[gop], 280 | res_resource_size)) 281 | 282 | def get_resource_uav_id(self): 283 | return self.original_source.sphere_id 284 | 285 | def dist(self, other): 286 | return np.sqrt(np.power(self.position[0] - other.position[0], 2) + 287 | np.power(self.position[1] - other.position[1], 2)) 288 | 289 | def correlation(self, other): 290 | if self.sphere_id != other.sphere_id: 291 | return 0 292 | if not self.transmission_mask.any() and not other.transmission_mask.any(): 293 | return 0 294 | if self.dist(other) == 0: 295 | return abs(self - other) / abs(self + other) 296 | return 1 / (self.dist(other)) * abs(self - other) / abs(self + other) 297 | 298 | def distribute_resource(self, original_source): 299 | resource = [] 300 | 301 | width_list = [] 302 | width_list_temp = list(range(0, original_source.tiles[1])) 303 | if self.center_resource[1] < int((self.sizeof_fov[1] - 1) / 2): 304 | width_list = list(width_list_temp[int(self.center_resource[1] - (self.sizeof_fov[1] - 1) / 2):]) 305 | width_list.extend(width_list_temp[0:self.center_resource[1] + int((self.sizeof_fov[1] + 1) / 2)]) 306 | elif int((self.sizeof_fov[1] - 1) / 2) <= self.center_resource[1] <= \ 307 | int(original_source.tiles[1] - int((self.sizeof_fov[1] + 1) / 2)): 308 | width_list = [self.center_resource[1] + x for x in range(-int((self.sizeof_fov[1] - 1) / 2), 309 | int((self.sizeof_fov[1] + 1) / 2))] 310 | elif self.center_resource[1] > int(original_source.tiles[1] - int((self.sizeof_fov[1] + 1) / 2)): 311 | width_list = list(range(0, self.center_resource[1] + int((self.sizeof_fov[1] + 1) / 2 - 312 | original_source.tiles[1]))) 313 | width_list.extend(list(range(self.center_resource[1] - int((self.sizeof_fov[1] - 1) / 2), 314 | original_source.tiles[1]))) 315 | 316 | # central and edge region 317 | if int((self.sizeof_fov[0] - 1) / 2) <= self.center_resource[0] <= \ 318 | int(original_source.tiles[0] - int((self.sizeof_fov[0] + 1) / 2)): 319 | height_list = [self.center_resource[0] + j for j in range(-int((self.sizeof_fov[0] - 1) / 2), 320 | int((self.sizeof_fov[0] + 1) / 2))] 321 | for m in width_list: 322 | for n in height_list: 323 | resource.append(original_source.resource.reshape(original_source.tiles)[n][m]) 324 | 325 | elif self.center_resource[0] < int((self.sizeof_fov[0] - 1) / 2): # up polar region 326 | for m in range(0, int((self.sizeof_fov[0] - 1) / 2) - self.center_resource[0]): 327 | for n in range(0, original_source.tiles[1]): 328 | resource.append(original_source.resource.reshape(original_source.tiles)[m][n]) 329 | for m in range(int((self.sizeof_fov[0] - 1) / 2) - self.center_resource[0], 330 | self.center_resource[0] + int((self.sizeof_fov[0] + 1) / 2)): 331 | for n in width_list: 332 | resource.append(original_source.resource.reshape(original_source.tiles)[m][n]) 333 | 334 | elif self.center_resource[0] > int(original_source.tiles[0] - int((self.sizeof_fov[0] + 1) / 2)): 335 | # down polar region 336 | for m in range(int(original_source.tiles[0] -((self.sizeof_fov[0] + 1) / 2 - 337 | (original_source.tiles[0] - self.center_resource[0]))), original_source.tiles[0]): 338 | for n in range(0, original_source.tiles[1]): 339 | resource.append(original_source.resource.reshape(original_source.tiles)[m][n]) 340 | for m in range(int(self.center_resource[0] - (self.sizeof_fov[0] - 1) / 2), 341 | int(original_source.tiles[0] - ((self.sizeof_fov[0] + 1) / 2 - 342 | (original_source.tiles[0] - self.center_resource[0])))): 343 | for n in width_list: 344 | resource.append(original_source.resource.reshape(original_source.tiles)[m][n]) 345 | # temp = np.sort(np.array(resource)) 346 | # show = np.ones(original_source.tiles[0] * original_source.tiles[1]) 347 | # show[temp] = 0 348 | # show = np.reshape(show, original_source.tiles) 349 | return np.sort(np.array(resource)) 350 | 351 | @staticmethod 352 | def limit_center_range(input_shape, shape_range): 353 | if input_shape[0] < 0: 354 | input_shape[0] = 0 355 | elif input_shape[0] >= shape_range[0]: 356 | input_shape[0] = shape_range[0] - 1 357 | if input_shape[1] < 0: 358 | input_shape[1] = 0 359 | elif input_shape[1] >= shape_range[1]: 360 | input_shape[1] = shape_range[1] - 1 361 | return input_shape 362 | 363 | # resize the center resource inside current resource range 364 | 365 | def moving_fov(self, delta, size_of_fov=None, new_source=None, new_center=None): 366 | # if entering new_source, inter field moving, else, intra moving 367 | if size_of_fov is not None: 368 | self.sizeof_fov = size_of_fov 369 | if new_source is None: 370 | self.center_resource[0] += delta[0] 371 | self.center_resource[1] += delta[1] 372 | self.center_resource = self.limit_center_range(self.sizeof_fov, self.original_source.tiles) 373 | new_fov_list = self.distribute_resource(self.original_source) 374 | new_fov = Field_of_View(self.center_resource, new_fov_list, self.original_source.sphere_id, self.sizeof_fov) 375 | self % new_fov 376 | else: 377 | self.center_resource = new_center 378 | self.original_source = new_source 379 | self.center_resource = self.limit_center_range(self.sizeof_fov, self.original_source.tiles) 380 | new_fov_list = self.distribute_resource(self.original_source) 381 | new_fov = Field_of_View(self.center_resource, new_fov_list, self.original_source.sphere_id, self.sizeof_fov) 382 | self % new_fov 383 | 384 | def calculate_range(self, ap_list): 385 | position = np.array([aps.position for aps in ap_list]) 386 | compare_res = np.abs(position - np.array(self.position)) <= ((gp.ACCESS_POINTS_FIELD - 1) / 2 * gp.REWARD_CAL_RANGE) 387 | index = np.logical_and(compare_res[:, 0], compare_res[:, 1]).astype(np.bool) 388 | self.in_range_ap = [aps for aps in np.nonzero(index)[0].astype(int)] 389 | 390 | def mobility(self, delta, ap_list): 391 | self.position[0] += delta[0] 392 | self.position[1] += delta[1] 393 | self.calculate_range(ap_list) 394 | self.position = self.limit_center_range(self.position, self.mobility_range) 395 | 396 | def merge_into(self, count_dict: typing.Dict[int, int], incresing_step=1): 397 | # [resource id, count number] 398 | for key in enumerate(self.resource): 399 | if self.transmission_mask[key[0]] != 0: 400 | temp = key[1] 401 | if key[1] >= gp.TOTAL_NUM_TILES: 402 | increasing = 0 403 | while temp >= gp.TOTAL_NUM_TILES: 404 | increasing += incresing_step 405 | temp -= gp.TOTAL_NUM_TILES 406 | count_dict[temp] += increasing 407 | # this part of code is for PF scheduling with time +1 408 | count_dict[temp] += incresing_step 409 | return count_dict 410 | 411 | def transmitted(self, resource_id: np.ndarray, transmission_amount): 412 | amount = transmission_amount 413 | my_resource = np.in1d(self.resource % gp.TOTAL_NUM_TILES, resource_id) 414 | if my_resource.any(): 415 | ind = np.max(np.where(my_resource)[0]) 416 | if not self.transmission_mask[ind]: 417 | return True, False, False 418 | if amount >= np.sum(self.resource_size[ind]) / len(resource_id) * np.sum(my_resource): 419 | self.transmission_mask[ind] = 0 420 | return True, True, True 421 | else: 422 | return True, True, False 423 | return False, False, False # tile exist? transmitted? sucessfull transmitted? 424 | # if requires that resource, receive, if not, wait and waste this turn 425 | 426 | def obtain_psnr(self): 427 | # must call after the change of fov 428 | # reset parameters 429 | temp_transmission_mask = self.transmission_mask[0:len(self.its_resource)].copy() 430 | for key, ele in enumerate(self.resource[len(self.its_resource)::]): 431 | if self.transmission_mask[key]: 432 | temp_transmission_mask[np.where(self.its_resource == ele % gp.TOTAL_NUM_TILES)[0]] = True 433 | if np.sum(temp_transmission_mask) == 0: 434 | current_psnr = 10 * np.log10(self.its_resource.size * 2) 435 | else: 436 | current_psnr = 10 * np.log10(1/(1/self.its_resource.size * 437 | (np.sum(temp_transmission_mask[0:len(self.its_resource)])))) 438 | return current_psnr 439 | 440 | def obtain_psnr_linear(self): 441 | # must call after the change of fov 442 | # reset parameters 443 | temp_transmission_mask = self.transmission_mask[0:len(self.its_resource)].copy() 444 | for key, ele in enumerate(self.resource[len(self.its_resource)::]): 445 | if self.transmission_mask[key]: 446 | temp_transmission_mask[np.where(self.its_resource == ele % gp.TOTAL_NUM_TILES)[0]] = True 447 | if np.sum(temp_transmission_mask) == 0: 448 | current_psnr = 1 449 | else: 450 | current_psnr = 1 - (1 / self.its_resource.size * 451 | (np.sum(temp_transmission_mask[0:len(self.its_resource)]))) 452 | return current_psnr 453 | 454 | 455 | class UAV(VR_Sphere): 456 | __slots__ = ['sphere_id', 'num_of_tile', 'tiles', 'transmission_mask', 'resource', 'its_resource', 457 | 'id', 'mobility_range', 'position'] 458 | 459 | def __init__(self, uav_index, position: np.ndarray, mobility_range, tiles: list): 460 | self.tiles = tiles 461 | self.num_of_tile = tiles[0] * tiles[1] 462 | self.position = position 463 | self.id: int = int(uav_index) 464 | self.mobility_range = mobility_range 465 | super(UAV, self).__init__(uav_index, self.tiles, self.num_of_tile, np.array([])) 466 | 467 | def __str__(self): 468 | return "UAV " + str(self.id) + " in " + str(self.position) + " with " + str(self.tiles) 469 | 470 | # def __deepcopy__(self, memo): 471 | # copied = UAV(self.id, self.position, self.mobility_range, self.tiles) 472 | # return copied 473 | 474 | @staticmethod 475 | def limit_center_range(input_shape, shape_range): 476 | if input_shape[0] < 0: 477 | input_shape[0] = 0 478 | elif input_shape[0] >= shape_range[0]: 479 | input_shape[0] = shape_range[0] 480 | if input_shape[1] < 0: 481 | input_shape[1] = 0 482 | elif input_shape[1] >= shape_range[1]: 483 | input_shape[1] = shape_range[1] 484 | return input_shape 485 | 486 | # resize the center resource inside current resource range 487 | 488 | def mobility(self, delta): 489 | self.position[0] += delta[0] 490 | self.position[1] += delta[1] 491 | self.position = self.limit_center_range(self.position, self.mobility_range) 492 | 493 | 494 | if __name__ == "__correlation_test__": 495 | import time 496 | 497 | start_time = time.time() 498 | size = 10000 499 | N = np.random.rand(size, size) 500 | # for i in range (0, 100): 501 | # for j in range (0, 100): 502 | # if abs(i - j) > 20: 503 | # N[i][j] = 0 504 | for i in range(0, size): 505 | N[i][i] = 0 506 | # for i in range (0, size): 507 | # for j in range (0, size): 508 | # if abs(i - j) > 200: 509 | # N[i][j] = 0 510 | clst = Clustering(N, "PrivotingBK_greedy", 0.5) 511 | print(clst.get_cluster_result()) 512 | print("--- %s seconds ---" % (time.time() - start_time)) 513 | --------------------------------------------------------------------------------