├── .gitignore
├── LICENSE
├── README.md
├── adversarial_comms
    ├── __init__.py
    ├── config
    │   ├── coverage.yaml
    │   ├── coverage_split.yaml
    │   └── path_planning.yaml
    ├── environments
    │   ├── __init__.py
    │   ├── coverage.py
    │   └── path_planning.py
    ├── evaluate.py
    ├── generate_dataset.py
    ├── models
    │   ├── __init__.py
    │   ├── adversarial.py
    │   └── gnn
    │   │   ├── __init__.py
    │   │   ├── adversarialGraphML.py
    │   │   ├── graphML.py
    │   │   └── graphTools.py
    ├── train_interpreter.py
    ├── train_policy.py
    └── trainers
    │   ├── __init__.py
    │   ├── hom_multi_action_dist.py
    │   ├── multiagent_ppo.py
    │   └── random_heuristic.py
├── requirements.txt
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price.
Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. 
Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. 
Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. 
A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 
163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 
196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 
229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 
287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 
317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 
386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Comms 2 | Code accompanying the paper 3 | > [The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning](https://arxiv.org/abs/2008.02616)\ 4 | > Jan Blumenkamp, Amanda Prorok\ 5 | > (University of Cambridge)\ 6 | > _arXiv: 2008.02616_. 
7 | 8 | The five-minute video presentation for CoRL 2020: 9 | 10 | [![Video preview](https://img.youtube.com/vi/SUDkRaj4FAI/0.jpg)](https://www.youtube.com/watch?v=SUDkRaj4FAI) 11 | 12 | Supplementary video material: 13 | 14 | [![Video preview](https://img.youtube.com/vi/o1Nq9XoSU6U/0.jpg)](https://www.youtube.com/watch?v=o1Nq9XoSU6U) 15 | 16 | ## Installation 17 | Clone the repository, change directory into its root and run: 18 | ``` 19 | pip install -e . 20 | ``` 21 | This will install the package and all requirements. It will also set up the entry points referred to later in these instructions. 22 | 23 | ## Training 24 | Training is performed for the policies and for the interpreters. We first explain the three policy training steps (cooperative, self-interested, and re-adaptation) for all three experiments (coverage, split coverage, and path planning) and then the interpreter training. 25 | 26 | The policy training follows this scheme: 27 | ``` 28 | train_policy [experiment] -t [total time steps in millions] 29 | continue_policy [cooperative checkpoint path] -t [total time steps] -e [experiment] -o self_interested 30 | continue_policy [self-interested checkpoint path] -t [total time steps] -e [experiment] -o re_adapt 31 | ``` 32 | where `experiment` is one of `{coverage, coverage_split, path_planning}`, `-t` is the total number of time steps at which the experiment is terminated, and `-o` is a config option (one of `{self_interested, re_adapt}`, as found under the `alternative_config` key in each of the config files in `config`). Note that `-t` counts total time steps, not time steps per call: if a policy is trained with `train_policy -t 20` and then continued with `continue_policy -t 20`, the second call terminates immediately. 33 | 34 | When running each experiment, Ray will print the trial name to the terminal, which looks something like `MultiPPO_coverage_f4dc4_00000`.
By default, Ray will create the directory `~/ray_results/MultiPPO`, which contains a directory for the trial with the given name and its checkpoints. `continue_policy` expects the path to one such checkpoint, for example `~/ray_results/MultiPPO/MultiPPO_coverage_f4dc4_00000/checkpoint_440`. The first `continue_policy` call expects the checkpoint generated by the `train_policy` call, and the second `continue_policy` call expects the checkpoint generated by the first `continue_policy` call. Take note of each experiment's checkpoint path. 35 | 36 | ### Standard Coverage 37 | ``` 38 | train_policy coverage -t 20 39 | continue_policy [cooperative checkpoint path] -t 60 -e coverage -o self_interested 40 | continue_policy [self-interested checkpoint path] -t 80 -e coverage -o re_adapt 41 | ``` 42 | 43 | ### Split Coverage 44 | ``` 45 | train_policy coverage_split -t 3 46 | continue_policy [cooperative checkpoint path] -t 20 -e coverage_split -o self_interested 47 | continue_policy [self-interested checkpoint path] -t 30 -e coverage_split -o re_adapt 48 | ``` 49 | 50 | ### Path Planning 51 | ``` 52 | train_policy path_planning -t 20 53 | continue_policy [cooperative checkpoint path] -t 60 -e path_planning -o self_interested 54 | continue_policy [self-interested checkpoint path] -t 80 -e path_planning -o re_adapt 55 | ``` 56 | 57 | ## Evaluation 58 | We provide three methods for evaluation: 59 | 60 | 1) `evaluate_coop`: Evaluate cooperative-only performance with the self-interested agents disabled, with and without communication among cooperative agents. 61 | 2) `evaluate_adv`: Evaluate cooperative and self-interested agents with and without communication between cooperative and self-interested agents (cooperative agents can always communicate with each other). 62 | 3) `evaluate_random`: Run a random policy that visits random neighboring (preferably uncovered) cells.
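Both `continue_policy` and the evaluation commands take a checkpoint path as input. The following is a minimal sketch, not part of this package, of selecting the highest-numbered checkpoint in a trial directory. It assumes only the `checkpoint_<N>` naming seen in the example path above (`.../checkpoint_440`); the helper name `latest_checkpoint` is hypothetical.

```python
import re
from pathlib import Path
from typing import Optional


def latest_checkpoint(trial_dir: str) -> Optional[Path]:
    """Return the highest-numbered ``checkpoint_<N>`` entry in ``trial_dir``.

    Hypothetical helper: assumes checkpoints are stored as entries named
    ``checkpoint_<N>`` directly inside the trial directory.
    """
    candidates = []
    for entry in Path(trial_dir).expanduser().iterdir():
        match = re.fullmatch(r"checkpoint_(\d+)", entry.name)
        if match:
            candidates.append((int(match.group(1)), entry))
    if not candidates:
        return None
    # Compare numerically so that checkpoint_440 ranks above checkpoint_99.
    return max(candidates, key=lambda item: item[0])[1]
```

For example, `latest_checkpoint("~/ray_results/MultiPPO/MultiPPO_coverage_f4dc4_00000")` would return the `checkpoint_440` entry if that is the highest-numbered checkpoint in the trial directory, and `None` if the directory contains no checkpoints.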
63 | 64 | The evaluation is run as 65 | ``` 66 | evaluate_{coop, adv} [checkpoint path] [result path] --trials 100 67 | evaluate_random [result path] --trials 100 68 | ``` 69 | for 100 evaluation runs with different seeds. The resulting file is a Pandas dataframe containing the rewards for all agents at every time step. It can be processed and visualized by running `evaluate_plot [pickled data path]`. 70 | 71 | Additionally, a checkpoint can be rolled out and rendered for a randomly generated environment with `evaluate_serve [checkpoint_path] --seed 0`. 72 | 73 | ## Citation 74 | If you use any part of this code in your research, please cite our paper: 75 | ``` 76 | @article{blumenkamp2020adversarial, 77 | title={The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning}, 78 | author={Blumenkamp, Jan and Prorok, Amanda}, 79 | journal={Conference on Robot Learning (CoRL)}, 80 | year={2020} 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /adversarial_comms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proroklab/adversarial_comms/889a41e239958bd519365c6b0c143469cd27e49f/adversarial_comms/__init__.py -------------------------------------------------------------------------------- /adversarial_comms/config/coverage.yaml: -------------------------------------------------------------------------------- 1 | framework: torch 2 | env: coverage 3 | lambda: 0.95 4 | kl_coeff: 0.5 5 | kl_target: 0.01 6 | clip_rewards: True 7 | clip_param: 0.2 8 | vf_clip_param: 250.0 9 | vf_share_layers: False 10 | vf_loss_coeff: 1.0e-4 11 | entropy_coeff: 0.01 12 | train_batch_size: 5000 13 | rollout_fragment_length: 100 14 | sgd_minibatch_size: 1000 15 | num_sgd_iter: 5 16 | num_workers: 7 17 | num_envs_per_worker: 16 18 | lr: 5.0e-4 19 | gamma: 0.9 20 | batch_mode: truncate_episodes 21 | observation_filter: NoFilter 22 | num_gpus: 0.5 
23 | num_gpus_per_worker: 0.0625 24 | model: 25 | custom_model: adversarial 26 | custom_action_dist: hom_multi_action 27 | custom_model_config: 28 | graph_layers: 1 29 | graph_tabs: 2 30 | graph_edge_features: 1 31 | graph_features: 128 32 | cnn_filters: [[8, [4, 4], 2], [16, [4, 4], 2], [32, [3, 3], 2]] 33 | value_cnn_filters: [[8, [4, 4], 2], [16, [4, 4], 2], [32, [4, 4], 2]] 34 | value_cnn_compression: 128 35 | cnn_compression: 32 36 | pre_gnn_mlp: [64, 128, 32] 37 | gp_kernel_size: 16 38 | graph_aggregation: sum 39 | relative: true 40 | activation: relu 41 | freeze_coop: False 42 | freeze_greedy: False 43 | freeze_coop_value: False 44 | freeze_greedy_value: False 45 | cnn_residual: False 46 | agent_split: 1 47 | greedy_mse_fac: 0.0 48 | env_config: 49 | world_shape: [24, 24] 50 | state_size: 16 51 | collapse_state: False 52 | termination_no_new_coverage: 10 53 | max_episode_len: 345 # 24*24*0.6 54 | n_agents: [1, 5] 55 | disabled_teams_step: [True, False] 56 | disabled_teams_comms: [True, False] 57 | min_coverable_area_fraction: 0.6 58 | map_mode: random 59 | reward_annealing: 0.0 60 | communication_range: 16.0 61 | ensure_connectivity: True 62 | reward_type: semi_cooperative #semi_cooperative/cooperative 63 | episode_termination: early # early/fixed/default 64 | operation_mode: coop_only 65 | evaluation_num_workers: 1 66 | evaluation_interval: 1 67 | evaluation_num_episodes: 10 68 | evaluation_config: 69 | env_config: 70 | termination_no_new_coverage: -1 71 | max_episode_len: 345 # 24*24*0.6 72 | episode_termination: default 73 | operation_mode: all 74 | ensure_connectivity: False 75 | logger_config: 76 | wandb: 77 | project: adv_paper 78 | #project: vaegp_0920 79 | group: revised_gp 80 | api_key_file: "./wandb_api_key_file" 81 | alternative_config: 82 | self_interested: 83 | # adversarial case in co-training 84 | evaluation_num_workers: 1 85 | num_workers: 7 86 | num_envs_per_worker: 64 87 | rollout_fragment_length: 100 88 | num_gpus_per_worker: 0.0625 89 | 
num_gpus: 0.5 90 | env_config: 91 | operation_mode: greedy_only 92 | disabled_teams_step: [False, False] 93 | disabled_teams_comms: [False, False] 94 | n_agents: [1, 5] 95 | model: 96 | custom_model_config: 97 | freeze_coop: True 98 | freeze_greedy: False 99 | adversarial: 100 | evaluation_num_workers: 1 101 | num_workers: 7 102 | num_envs_per_worker: 64 103 | rollout_fragment_length: 100 104 | num_gpus_per_worker: 0.0625 105 | num_gpus: 0.5 106 | 107 | env_config: 108 | operation_mode: adversary_only 109 | disabled_teams_step: [False, False] 110 | disabled_teams_comms: [False, False] 111 | termination_no_new_coverage: -1 112 | max_episode_len: 173 # 24*24*0.6 113 | episode_termination: default 114 | model: 115 | custom_model_config: 116 | freeze_coop: True 117 | freeze_greedy: False 118 | re_adapt: 119 | env_config: 120 | operation_mode: coop_only 121 | disabled_teams_step: [False, False] 122 | disabled_teams_comms: [False, False] 123 | model: 124 | custom_model_config: 125 | freeze_coop: False 126 | freeze_greedy: True 127 | adversarial_abundance: 128 | # adversarial case in co-training 129 | env_config: 130 | #map_mode: random_teams_far 131 | map_mode: split_half_fixed_block 132 | #map_mode: split_half_fixed_block_same_side 133 | communication_range: 8.0 134 | model: 135 | custom_model_config: 136 | graph_tabs: 3 137 | logger_config: 138 | wandb: 139 | project: vaegp_0920 140 | 141 | -------------------------------------------------------------------------------- /adversarial_comms/config/coverage_split.yaml: -------------------------------------------------------------------------------- 1 | framework: torch 2 | env: coverage 3 | lambda: 0.95 4 | kl_coeff: 0.5 5 | kl_target: 0.01 6 | clip_rewards: True 7 | clip_param: 0.2 8 | vf_clip_param: 250.0 9 | vf_share_layers: False 10 | vf_loss_coeff: 1.0e-4 11 | entropy_coeff: 0.01 12 | train_batch_size: 5000 13 | rollout_fragment_length: 100 14 | sgd_minibatch_size: 1000 15 | num_sgd_iter: 5 16 | num_workers: 16 17 | 
num_envs_per_worker: 8 18 | lr: 5.0e-4 19 | gamma: 0.9 20 | batch_mode: truncate_episodes 21 | observation_filter: NoFilter 22 | num_gpus: 1 23 | model: 24 | custom_model: adversarial 25 | custom_action_dist: hom_multi_action 26 | custom_model_config: 27 | graph_layers: 1 28 | graph_tabs: 2 29 | graph_edge_features: 1 30 | 31 | # 16 32 | graph_features: 32 33 | cnn_filters: [[8, [4, 4], 2], [16, [4, 4], 2], [32, [3, 3], 2]] 34 | value_cnn_filters: [[8, [4, 4], 2], [16, [4, 4], 2], [32, [4, 4], 2]] 35 | value_cnn_compression: 128 36 | cnn_compression: 32 37 | 38 | relative: true 39 | activation: relu 40 | freeze_coop: False 41 | freeze_greedy: False 42 | freeze_coop_value: False 43 | freeze_greedy_value: False 44 | cnn_residual: False 45 | agent_split: 1 46 | env_config: 47 | world_shape: [24, 24] 48 | state_size: 16 49 | collapse_state: False 50 | termination_no_new_coverage: 10 51 | max_episode_len: 288 # (24*24)/2 52 | n_agents: [1, 5] 53 | disabled_teams_step: [True, False] 54 | disabled_teams_comms: [True, False] 55 | map_mode: split_half_fixed 56 | reward_annealing: 0.0 57 | communication_range: 16.0 58 | ensure_connectivity: True 59 | reward_type: split_right 60 | episode_termination: early_right # early/fixed/default 61 | operation_mode: coop_only 62 | agents: 63 | coverage_radius: 1 64 | visibility_distance: 0 65 | map_update_radius: 100 66 | relative_coord_frame: True 67 | evaluation_num_workers: 2 68 | evaluation_interval: 1 69 | evaluation_num_episodes: 10 70 | evaluation_config: 71 | env_config: 72 | termination_no_new_coverage: -1 73 | max_episode_len: 288 # (24*24)/2 74 | episode_termination: default 75 | operation_mode: all 76 | ensure_connectivity: False 77 | alternative_config: 78 | self_interested: 79 | env_config: 80 | operation_mode: greedy_only 81 | disabled_teams_step: [False, False] 82 | disabled_teams_comms: [False, False] 83 | model: 84 | custom_model_config: 85 | freeze_coop: True 86 | freeze_greedy: False 87 | re_adapt: 88 | env_config: 
89 | operation_mode: coop_only 90 | disabled_teams_step: [False, False] 91 | disabled_teams_comms: [False, False] 92 | model: 93 | custom_model_config: 94 | freeze_coop: False 95 | freeze_greedy: True 96 | -------------------------------------------------------------------------------- /adversarial_comms/config/path_planning.yaml: -------------------------------------------------------------------------------- 1 | framework: torch 2 | env: path_planning 3 | lambda: 0.95 4 | kl_coeff: 0.5 5 | kl_target: 0.01 6 | clip_rewards: True 7 | clip_param: 0.2 8 | vf_clip_param: 250.0 9 | vf_share_layers: False 10 | vf_loss_coeff: 1.0e-4 11 | entropy_coeff: 0.01 12 | train_batch_size: 5000 13 | rollout_fragment_length: 100 14 | sgd_minibatch_size: 1000 15 | num_sgd_iter: 5 16 | num_workers: 7 17 | num_envs_per_worker: 16 18 | lr: 4.0e-4 19 | gamma: 0.99 20 | batch_mode: complete_episodes 21 | observation_filter: NoFilter 22 | num_gpus: 1.0 23 | model: 24 | custom_model: adversarial 25 | custom_action_dist: hom_multi_action 26 | custom_model_config: 27 | graph_layers: 1 28 | graph_tabs: 2 29 | graph_edge_features: 1 30 | graph_features: 128 31 | cnn_filters: [[8, [4, 4], 2], [16, [4, 4], 2], [32, [3, 3], 2]] 32 | value_cnn_filters: [[16, [4, 4], 2], [32, [4, 4], 2]] 33 | value_cnn_compression: 128 34 | cnn_compression: 32 35 | gp_kernel_size: 16 36 | graph_aggregation: sum 37 | activation: relu 38 | freeze_coop: False 39 | freeze_greedy: False 40 | freeze_coop_value: False 41 | freeze_greedy_value: False 42 | cnn_residual: False 43 | agent_split: 1 44 | env_config: 45 | world_shape: [12, 12] 46 | state_size: 16 47 | max_episode_len: 50 48 | n_agents: [1, 15] 49 | disabled_teams_step: [True, False] 50 | disabled_teams_comms: [True, False] 51 | communication_range: 5.0 52 | ensure_connectivity: True 53 | reward_type: coop_only 54 | world_mode: warehouse 55 | agents: 56 | visibility_distance: 0 57 | relative_coord_frame: True 58 | evaluation_num_workers: 1 59 | 
evaluation_interval: 1 60 | evaluation_num_episodes: 10 61 | evaluation_config: 62 | env_config: 63 | reward_type: local 64 | alternative_config: 65 | self_interested: 66 | env_config: 67 | reward_type: greedy_only 68 | disabled_teams_step: [False, False] 69 | disabled_teams_comms: [False, False] 70 | model: 71 | custom_model_config: 72 | freeze_coop: True 73 | freeze_greedy: False 74 | evaluation_config: 75 | env_config: 76 | reward_type: local 77 | re_adapt: 78 | env_config: 79 | reward_type: coop_only 80 | disabled_teams_step: [False, False] 81 | disabled_teams_comms: [False, False] 82 | model: 83 | custom_model_config: 84 | freeze_coop: False 85 | freeze_greedy: True 86 | evaluation_config: 87 | env_config: 88 | reward_type: local 89 | -------------------------------------------------------------------------------- /adversarial_comms/environments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proroklab/adversarial_comms/889a41e239958bd519365c6b0c143469cd27e49f/adversarial_comms/environments/__init__.py -------------------------------------------------------------------------------- /adversarial_comms/environments/coverage.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | import matplotlib.patches as patches 5 | import gym 6 | from gym import spaces 7 | from gym.utils import seeding, EzPickle 8 | from matplotlib import colors 9 | from functools import partial 10 | from enum import Enum 11 | import copy 12 | 13 | from ray.rllib.env.multi_agent_env import MultiAgentEnv 14 | 15 | # https://bair.berkeley.edu/blog/2018/12/12/rllib/ 16 | 17 | DEFAULT_OPTIONS = { 18 | 'world_shape': [24, 24], 19 | 'state_size': 48, 20 | 'collapse_state': False, 21 | 'termination_no_new_coverage': 10, 22 | 'max_episode_len': -1, 23 | "min_coverable_area_fraction": 0.6, 24 | "map_mode": 
"random", 25 | "n_agents": [5], 26 | "disabled_teams_step": [False], 27 | "disabled_teams_comms": [False], 28 | 'communication_range': 8.0, 29 | 'one_agent_per_cell': False, 30 | 'ensure_connectivity': True, 31 | 'reward_type': 'semi_cooperative', 32 | #"operation_mode": 'all', # greedy_only, coop_only, don't default for now 33 | 'episode_termination': 'early', 34 | 'agent_observability_radius': None, 35 | } 36 | 37 | X = 1 38 | Y = 0 39 | 40 | class Dir(Enum): 41 | RIGHT = 0 42 | LEFT = 1 43 | UP = 2 44 | DOWN = 3 45 | 46 | class WorldMap(): 47 | def __init__(self, random_state, shape, min_coverable_area_fraction): 48 | self.shape = tuple(shape) 49 | self.min_coverable_area_fraction = min_coverable_area_fraction 50 | self.reset(random_state) 51 | 52 | def reset(self, random_state, mode="random"): 53 | self.coverage = np.zeros(self.shape, dtype=np.int) 54 | if mode == "random": 55 | if self.min_coverable_area_fraction == 1.0: 56 | self.map = np.zeros(self.shape, dtype=np.uint8) 57 | else: 58 | self.map = np.ones(self.shape, dtype=np.uint8) 59 | p = np.array([random_state.randint(0, self.shape[c]) for c in [Y, X]]) 60 | while self.get_coverable_area_faction() < self.min_coverable_area_fraction: 61 | d_p = np.array([[0, 1], [0, -1], [-1, 0], [1, 0]][random_state.randint(0, 4)])#*random_state.randint(1, 5) 62 | p_new = np.clip(p + d_p, [0,0], np.array(self.shape)-1) 63 | self.map[min(p[Y],p_new[Y]):max(p[Y],p_new[Y])+1, min(p[X],p_new[X]):max(p[X],p_new[X])+1] = 0 64 | #print(min(p[Y],p_new[Y]),max(p[Y],p_new[Y])+1, min(p[X],p_new[X]),max(p[X],p_new[X])+1, np.sum(self.map)) 65 | p = p_new 66 | elif mode == "split_half_fixed" or mode == "split_half_fixed_block" or mode == "split_half_fixed_block_same_side": 67 | self.map = np.zeros(self.shape, dtype=np.uint8) 68 | self.map[:, int(self.shape[X]/2)] = 1 69 | if mode == "split_half_fixed": 70 | self.map[int(self.shape[Y]/2), int(self.shape[X]/2)] = 0 71 | 72 | def get_coverable_area_faction(self): 73 | coverable_area = 
~(self.map > 0) 74 | return np.sum(coverable_area)/(self.map.shape[X]*self.map.shape[Y]) 75 | 76 | def get_coverable_area(self): 77 | coverable_area = ~(self.map>0) 78 | return np.sum(coverable_area) 79 | 80 | def get_covered_area(self): 81 | coverable_area = ~(self.map>0) 82 | return np.sum((self.coverage > 0) & coverable_area) 83 | 84 | def get_coverage_fraction(self): 85 | coverable_area = ~(self.map>0) 86 | covered_area = (self.coverage > 0) & coverable_area 87 | return np.sum(covered_area)/np.sum(coverable_area) 88 | 89 | class Action(Enum): 90 | NOP = 0 91 | MOVE_RIGHT = 1 92 | MOVE_LEFT = 2 93 | MOVE_UP = 3 94 | MOVE_DOWN = 4 95 | 96 | class Robot(): 97 | def __init__(self, 98 | index, 99 | random_state, 100 | world, 101 | state_size, 102 | collapse_state, 103 | termination_no_new_coverage, 104 | agent_observability_radius, 105 | one_agent_per_cell): 106 | self.index = index 107 | self.world = world 108 | self.termination_no_new_coverage = termination_no_new_coverage 109 | self.state_size = state_size 110 | self.collapse_state = collapse_state 111 | self.initialized_rendering = False 112 | self.agent_observability_radius = agent_observability_radius 113 | self.one_agent_per_cell = one_agent_per_cell 114 | self.pose = np.array([-1, -1]) # assign negative pose so that during reset agents are not placed at same initial position 115 | self.reset(random_state) 116 | 117 | def reset(self, random_state, pose_mean=np.array([0, 0]), pose_var=1): 118 | def random_pos(var): 119 | return np.array([ 120 | int(np.clip(random_state.normal(loc=pose_mean[c], scale=var), 0, self.world.map.shape[c]-1)) 121 | for c in [Y, X]]) 122 | 123 | current_pose_var = pose_var 124 | self.pose = random_pos(current_pose_var) 125 | self.prev_pose = self.pose.copy() 126 | while self.world.map.map[self.pose[Y], self.pose[X]] == 1 or (self.world.is_occupied(self.pose, self) and self.one_agent_per_cell): 127 | self.pose = random_pos(current_pose_var) 128 | current_pose_var += 0.1 129 | 130 | 
self.coverage = np.zeros(self.world.map.shape, dtype=bool) 131 | self.state = None 132 | self.no_new_coverage_steps = 0 133 | self.reward = 0 134 | 135 | def step(self, action): 136 | action = Action(action) 137 | 138 | delta_pose = { 139 | Action.MOVE_RIGHT: [ 0, 1], 140 | Action.MOVE_LEFT: [ 0, -1], 141 | Action.MOVE_UP: [-1, 0], 142 | Action.MOVE_DOWN: [ 1, 0], 143 | Action.NOP: [ 0, 0] 144 | }[action] 145 | 146 | is_valid_pose = lambda p: all([p[c] >= 0 and p[c] < self.world.map.shape[c] for c in [Y, X]]) 147 | is_obstacle = lambda p: self.world.map.map[p[Y]][p[X]] == 1 148 | 149 | self.prev_pose = self.pose.copy() 150 | desired_pos = self.pose + delta_pose 151 | if is_valid_pose(desired_pos) and (not self.world.is_occupied(desired_pos, self) or not self.one_agent_per_cell) and not is_obstacle(desired_pos): 152 | self.pose = desired_pos 153 | 154 | if self.world.map.coverage[self.pose[Y], self.pose[X]] == 0: 155 | self.world.map.coverage[self.pose[Y], self.pose[X]] = self.index 156 | self.reward = 1 157 | self.no_new_coverage_steps = 0 158 | else: 159 | self.reward = 0 160 | self.no_new_coverage_steps += 1 161 | 162 | self.coverage[self.pose[Y], self.pose[X]] = True 163 | #self.reward -= 1 # subtract each time step 164 | 165 | def update_state(self): 166 | coverage = self.coverage.copy().astype(int) 167 | if self.collapse_state: 168 | yy, xx = np.mgrid[:self.coverage.shape[Y], :self.coverage.shape[X]] 169 | for (cx, cy) in zip(xx[self.coverage]-self.pose[X], yy[self.coverage]-self.pose[Y]): 170 | if abs(cx) < self.state_size/2 and abs(cy) < self.state_size/2: 171 | continue 172 | u = max(abs(cx), abs(cy)) 173 | p_sq = np.round(self.pose + int(self.state_size/2)*np.array([cy/u, cx/u])).astype(int) 174 | coverage[p_sq[Y], p_sq[X]] += 1 175 | 176 | state_output_shape = np.array([self.state_size]*2, dtype=int) 177 | state_data = [ 178 | self.to_coordinate_frame(self.world.map.map, state_output_shape, fill=1), 179 | self.to_coordinate_frame(coverage,
state_output_shape, fill=0) 180 | ] 181 | if self.agent_observability_radius is not None: 182 | pose_map = np.zeros(self.world.map.shape, dtype=np.uint8) 183 | 184 | for team in self.world.teams.values(): 185 | for r in team: 186 | if r is not self and np.sum((r.pose - self.pose)**2) < self.agent_observability_radius**2: 187 | pose_map[r.pose[Y], r.pose[X]] = 2 188 | pose_map[self.pose[Y], self.pose[X]] = 1 189 | state_data.append(self.to_coordinate_frame(pose_map, state_output_shape, fill=0)) 190 | self.state = np.stack(state_data, axis=-1).astype(np.uint8) 191 | 192 | done = self.no_new_coverage_steps == self.termination_no_new_coverage 193 | return self.state, self.reward, done, {} 194 | 195 | def to_abs_frame(self, data): 196 | half_state_size = int(self.state_size / 2) 197 | return np.roll(data, self.pose, axis=(0, 1))[half_state_size:, half_state_size:] 198 | 199 | def to_coordinate_frame(self, m, output_shape, fill=0): 200 | half_out_shape = np.array(output_shape/2, dtype=int) 201 | padded = np.pad(m,([half_out_shape[Y]]*2,[half_out_shape[X]]*2), mode='constant', constant_values=fill) 202 | return padded[self.pose[Y]:self.pose[Y] + output_shape[Y], self.pose[X]:self.pose[X] + output_shape[X]] 203 | 204 | class CoverageEnv(gym.Env, EzPickle): 205 | def __init__(self, env_config): 206 | EzPickle.__init__(self) 207 | self.seed() 208 | 209 | self.cfg = copy.deepcopy(DEFAULT_OPTIONS) 210 | self.cfg.update(env_config) 211 | 212 | self.fig = None 213 | self.map_colormap = colors.ListedColormap(['white', 'black', 'gray']) # free, obstacle, unknown 214 | 215 | hsv = np.ones((self.cfg['n_agents'][1], 3)) 216 | hsv[..., 0] = np.linspace(160/360, 250/360, self.cfg['n_agents'][1] + 1)[:-1] 217 | self.teams_agents_color = { 218 | 0: [(1, 0, 0)], 219 | 1: colors.hsv_to_rgb(hsv) 220 | } 221 | 222 | ''' 223 | hsv = np.ones((sum(self.cfg['n_agents']), 3)) 224 | hsv[..., 0] = np.linspace(0, 1, sum(self.cfg['n_agents']) + 1)[:-1] 225 | self.teams_agents_color = {} 226 |
current_index = 0 227 | for i, n_agents in enumerate(self.cfg['n_agents']): 228 | self.teams_agents_color[i] = colors.hsv_to_rgb(hsv[current_index:current_index+n_agents]) 229 | current_index += n_agents 230 | ''' 231 | 232 | hsv = np.ones((len(self.cfg['n_agents']), 3)) 233 | hsv[..., 0] = np.linspace(0, 1, len(self.cfg['n_agents']) + 1)[:-1] 234 | self.teams_colors = ['r', 'b'] #colors.hsv_to_rgb(hsv) 235 | 236 | n_all_agents = sum(self.cfg['n_agents']) 237 | self.observation_space = spaces.Dict({ 238 | 'agents': spaces.Tuple(( 239 | spaces.Dict({ 240 | 'map': spaces.Box(0, np.inf, shape=(self.cfg['state_size'], self.cfg['state_size'], 2 if self.cfg['agent_observability_radius'] is None else 3)), 241 | 'pos': spaces.Box(low=np.array([0,0]), high=np.array([self.cfg['world_shape'][Y], self.cfg['world_shape'][X]]), dtype=int), 242 | }), 243 | )*n_all_agents), # Do not add this as additional dimension of map and pos since this way it is easier to handle in the model 244 | 'gso': spaces.Box(-np.inf, np.inf, shape=(n_all_agents, n_all_agents)), 245 | 'state': spaces.Box(low=0, high=2, shape=self.cfg['world_shape']+[2+len(self.cfg['n_agents'])]), 246 | }) 247 | self.action_space = spaces.Tuple((spaces.Discrete(5),)*sum(self.cfg['n_agents'])) 248 | 249 | self.map = WorldMap(self.world_random_state, self.cfg['world_shape'], self.cfg['min_coverable_area_fraction']) 250 | self.teams = {} 251 | agent_index = 1 252 | for i, n_agents in enumerate(self.cfg['n_agents']): 253 | self.teams[i] = [] 254 | for j in range(n_agents): 255 | self.teams[i].append( 256 | Robot( 257 | agent_index, 258 | self.agent_random_state, 259 | self, 260 | self.cfg['state_size'], 261 | self.cfg['collapse_state'], 262 | self.cfg['termination_no_new_coverage'], 263 | self.cfg['agent_observability_radius'], 264 | self.cfg['one_agent_per_cell'] 265 | ) 266 | ) 267 | agent_index += 1 268 | 269 | self.reset() 270 | 271 | def is_occupied(self, p, agent_ignore=None): 272 | for team_key, team in
self.teams.items(): 273 | if self.cfg['disabled_teams_step'][team_key]: 274 | continue 275 | for o in team: 276 | if o is agent_ignore: 277 | continue 278 | if p[X] == o.pose[X] and p[Y] == o.pose[Y]: 279 | return True 280 | return False 281 | 282 | def seed(self, seed=None): 283 | self.agent_random_state, seed_agents = seeding.np_random(seed) 284 | self.world_random_state, seed_world = seeding.np_random(seed) 285 | return [seed_agents, seed_world] 286 | 287 | def reset(self): 288 | self.dones = {key: [False for _ in team] for key, team in self.teams.items()} 289 | self.timestep = 0 290 | self.map.reset(self.world_random_state, self.cfg['map_mode']) 291 | 292 | def random_pos_seed(team_key): 293 | rnd = self.agent_random_state 294 | if self.cfg['map_mode'] == "random": 295 | return np.array([rnd.randint(0, self.map.shape[c]) for c in [Y, X]]) 296 | if self.cfg['map_mode'] == "split_half_fixed": 297 | return np.array([ 298 | rnd.randint(0, self.map.shape[Y]), 299 | rnd.randint(0, int(self.map.shape[X]/3)) 300 | ]) 301 | elif self.cfg['map_mode'] == "split_half_fixed_block": 302 | if team_key == 0: 303 | return np.array([ 304 | rnd.randint(0, self.map.shape[Y]), 305 | rnd.randint(0, int(self.map.shape[X] / 3)) 306 | ]) 307 | else: 308 | return np.array([ 309 | rnd.randint(0, self.map.shape[Y]), 310 | rnd.randint(2*int(self.map.shape[X] / 3), self.map.shape[X]) 311 | ]) 312 | elif self.cfg['map_mode'] == "split_half_fixed_block_same_side": 313 | return np.array([ 314 | rnd.randint(0, self.map.shape[Y]), 315 | rnd.randint(2*int(self.map.shape[X] / 3), self.map.shape[X]) 316 | ]) 317 | 318 | pose_seed = None 319 | for team_key, team in self.teams.items(): 320 | if not self.cfg['map_mode'] == "random" or pose_seed is None: 321 | # shared pose_seed if random map mode 322 | pose_seed = random_pos_seed(team_key) 323 | while self.map.map[pose_seed[Y], pose_seed[X]] == 1: 324 | pose_seed = random_pos_seed(team_key) 325 | for r in team: 326 | r.reset(self.agent_random_state, 
pose_mean=pose_seed, pose_var=1) 327 | return self.step([Action.NOP]*sum(self.cfg['n_agents']))[0] 328 | 329 | def compute_gso(self, team_id=0): 330 | own_team_agents = [(agent, self.cfg['disabled_teams_comms'][team_id]) for agent in self.teams[team_id]] 331 | other_agents = [(agent, self.cfg['disabled_teams_comms'][other_team_id]) for other_team_id, team in self.teams.items() for agent in team if not team_id == other_team_id] 332 | 333 | all_agents = own_team_agents + other_agents # order is important since in model the data is concatenated in this order as well 334 | dists = np.zeros((len(all_agents), len(all_agents))) 335 | done_matrix = np.zeros((len(all_agents), len(all_agents)), dtype=bool) 336 | for agent_y in range(len(all_agents)): 337 | for agent_x in range(agent_y): 338 | dst = np.sum(np.array(all_agents[agent_x][0].pose - all_agents[agent_y][0].pose)**2) 339 | dists[agent_y, agent_x] = dst 340 | dists[agent_x, agent_y] = dst 341 | 342 | d = all_agents[agent_x][1] or all_agents[agent_y][1] 343 | done_matrix[agent_y, agent_x] = d 344 | done_matrix[agent_x, agent_y] = d 345 | 346 | current_dist = self.cfg['communication_range'] 347 | A = dists < (current_dist**2) 348 | active_row = ~np.array([a[1] for a in all_agents]) 349 | if self.cfg['ensure_connectivity']: 350 | def is_connected(m): 351 | def walk_dfs(m, index): 352 | for i in range(len(m)): 353 | if m[index][i]: 354 | m[index][i] = False 355 | walk_dfs(m, i) 356 | 357 | m_c = m.copy() 358 | walk_dfs(m_c, 0) 359 | return not np.any(m_c.flatten()) 360 | 361 | # set done teams as generally connected since they should not be included by increasing connectivity 362 | while not is_connected(A[active_row][:, active_row]): 363 | current_dist *= 1.1 364 | A = (dists < current_dist**2) 365 | 366 | # Mask out done agents 367 | A = (A & ~done_matrix).astype(int) 368 | 369 | # normalization: refer https://github.com/QingbiaoLi/GraphNets/blob/master/Flocking/Utils/dataTools.py#L601 370 | np.fill_diagonal(A,
0) 371 | deg = np.sum(A, axis = 1) # nNodes (degree vector) 372 | D = np.diag(deg) 373 | Dp = np.diag(np.nan_to_num(np.power(deg, -1/2))) 374 | L = A # D-A 375 | gso = Dp @ L @ Dp 376 | return gso 377 | 378 | def step(self, actions): 379 | self.timestep += 1 380 | action_index = 0 381 | for i, team in enumerate(self.teams.values()): 382 | for agent in team: 383 | if not self.cfg['disabled_teams_step'][i]: 384 | agent.step(actions[action_index]) 385 | action_index += 1 386 | 387 | states, rewards = {}, {} 388 | for team_key, team in self.teams.items(): 389 | states[team_key] = [] 390 | rewards[team_key] = {} 391 | for i, agent in enumerate(team): 392 | state, reward, done, _ = agent.update_state() 393 | states[team_key].append(state) 394 | rewards[team_key][i] = reward 395 | if done: 396 | self.dones[team_key][i] = True 397 | dones = {} 398 | world_done = self.timestep == self.cfg['max_episode_len'] or self.map.get_coverage_fraction() == 1.0 399 | for key in self.teams.keys(): 400 | dones[key] = world_done 401 | if self.cfg['episode_termination'] == 'early' or self.cfg['episode_termination'] == 'early_any': 402 | dones[key] = world_done or any(self.dones[key]) 403 | elif self.cfg['episode_termination'] == 'early_all': 404 | dones[key] = world_done or all(self.dones[key]) 405 | elif self.cfg['episode_termination'] == 'early_right': 406 | # early term only if at least one agent has reached right side of env 407 | # before that fixed episode length (world_done) 408 | agent_is_in_right_half = False 409 | for agent in self.teams[key]: 410 | if agent.pose[X] < self.cfg['world_shape'][X]/2: 411 | agent.no_new_coverage_steps = 0 412 | else: 413 | agent_is_in_right_half = True 414 | 415 | if agent_is_in_right_half: 416 | dones[key] = any(self.dones[key]) 417 | else: 418 | dones[key] = world_done 419 | elif self.cfg['episode_termination'] == 'default': 420 | pass 421 | else: 422 | raise NotImplementedError("Unknown termination mode", self.cfg['episode_termination']) 423 | 424 
| if self.cfg['operation_mode'] == "all": 425 | pass 426 | elif self.cfg['operation_mode'] == "greedy_only" or self.cfg['operation_mode'] == "adversary_only": 427 | dones[1] = dones[0] 428 | elif self.cfg['operation_mode'] == "coop_only": 429 | dones[0] = dones[1] 430 | else: 431 | raise NotImplementedError("Unknown operation_mode") 432 | done = any(dones.values()) # Currently we cannot run teams independently, all have to stop at the same time 433 | 434 | pose_map = np.zeros(self.map.shape + (len(self.teams),), dtype=np.uint8) 435 | for i, team in enumerate(self.teams.values()): 436 | for r in team: 437 | pose_map[r.pose[Y], r.pose[X], i] = 1 438 | global_state = np.concatenate([np.stack([self.map.map, self.map.coverage > 0], axis=-1), pose_map], axis=-1) 439 | state = { 440 | 'agents': tuple([{ 441 | 'map': states[key][agent_i], 442 | 'pos': self.teams[key][agent_i].pose 443 | } for key in self.teams.keys() for agent_i in range(self.cfg['n_agents'][key])]), 444 | 'gso': self.compute_gso(0), 445 | 'state': global_state 446 | } 447 | 448 | for key in self.teams.keys(): 449 | if self.cfg['reward_type'] == 'semi_cooperative': 450 | pass 451 | elif self.cfg['reward_type'] == 'split_right': 452 | for agent_key in rewards[key].keys(): 453 | if self.teams[key][agent_key].pose[X] < self.cfg['world_shape'][X]/2: 454 | rewards[key][agent_key] = 0 455 | else: 456 | raise NotImplementedError("Unknown reward type", self.cfg['reward_type']) 457 | if self.cfg['operation_mode'] == "all": 458 | pass 459 | elif self.cfg['operation_mode'] == "greedy_only": 460 | # copy all rewards from the greedy agent to the cooperative agents 461 | rewards[1] = {agent_key: sum(rewards[0].values()) for agent_key in rewards[1].keys()} 462 | elif self.cfg['operation_mode'] == "adversary_only": 463 | # The greedy agent's reward is the negative sum of all agent's rewards 464 | all_negative = -sum([sum(team_rewards.values()) for team_rewards in rewards.values()]) 465 | rewards[0] = {agent_key: 
all_negative for agent_key in rewards[0].keys()} 466 | 467 | # copy all rewards from the greedy agent to the cooperative agents 468 | rewards[1] = {agent_key: sum(rewards[0].values()) for agent_key in rewards[1].keys()} 469 | elif self.cfg['operation_mode'] == "coop_only": 470 | # copy all rewards from the coop agent to the greedy agents 471 | rewards[0] = {agent_key: sum(rewards[1].values()) for agent_key in rewards[0].keys()} 472 | else: 473 | raise NotImplementedError("Unknown operation_mode") 474 | 475 | flattened_rewards = {} 476 | agent_index = 0 477 | for key in self.teams.keys(): 478 | for r in rewards[key].values(): 479 | flattened_rewards[agent_index] = r 480 | agent_index += 1 481 | info = { 482 | 'current_global_coverage': self.map.get_coverage_fraction(), 483 | 'coverable_area': self.map.get_coverable_area(), 484 | 'rewards_teams': rewards, 485 | 'rewards': flattened_rewards 486 | } 487 | return state, sum([sum(t.values()) for i, t in enumerate(rewards.values()) if not self.cfg['disabled_teams_step'][i]]), done, info 488 | 489 | def clear_patches(self, ax): 490 | [p.remove() for p in reversed(ax.patches)] 491 | [t.remove() for t in reversed(ax.texts)] 492 | 493 | def render_adjacency(self, A, team_id, ax, color='b', stepsize=1.0): 494 | A = A.copy() 495 | own_team_agents = [agent for agent in self.teams[team_id]] 496 | other_agents = [agent for other_team_id, team in self.teams.items() for agent in team if not team_id == other_team_id] 497 | all_agents = own_team_agents + other_agents 498 | for agent_id, agent in enumerate(all_agents): 499 | for connected_agent_id in np.arange(len(A)): 500 | if A[agent_id][connected_agent_id] > 0: 501 | current_agent_pose = agent.prev_pose + (agent.pose - agent.prev_pose) * stepsize 502 | other_agent = all_agents[connected_agent_id] 503 | other_agent_pose = other_agent.prev_pose + (other_agent.pose - other_agent.prev_pose) * stepsize 504 | ax.add_patch(patches.ConnectionPatch( 505 | [current_agent_pose[X], 
current_agent_pose[Y]], 506 | [other_agent_pose[X], other_agent_pose[Y]], 507 | "data", edgecolor='g', facecolor='none', lw=1, ls=":", alpha=0.3 508 | )) 509 | 510 | A[connected_agent_id][agent_id] = 0 # don't draw same connection again 511 | 512 | def render_global_coverages(self, ax): 513 | if not hasattr(self, 'im_cov_global'): 514 | self.im_cov_global = ax.imshow(np.zeros(self.map.shape), vmin=0, vmax=100) 515 | all_team_colors = [(0, 0, 0, 0)] + [tuple(list(c) + [0.5]) for team_colors in self.teams_agents_color.values() for c in team_colors] 516 | coverage = self.map.coverage.copy() 517 | if self.cfg['map_mode'] == 'split_half_fixed': 518 | # mark coverage on left side as gray 519 | color_index_left_side = len(all_team_colors) 520 | all_team_colors += [(0, 0, 0, 0.5)] # gray 521 | xx, _ = np.meshgrid( 522 | np.arange(0, coverage.shape[X], 1), 523 | np.arange(0, coverage.shape[Y], 1) 524 | ) 525 | coverage[(xx < coverage.shape[X]/2) & (coverage > 0)] = color_index_left_side 526 | 527 | self.im_cov_global.set_data(colors.ListedColormap(all_team_colors)(coverage)) 528 | 529 | def render_local_coverages(self, ax): 530 | if not hasattr(self, 'im_robots'): 531 | self.im_robots = {} 532 | for team_key, team in self.teams.items(): 533 | if self.cfg['disabled_teams_step'][team_key]: 534 | continue 535 | self.im_robots[team_key] = [] 536 | for _ in team: 537 | self.im_robots[team_key].append(ax.imshow(np.zeros(self.map.shape), vmin=0, vmax=1, alpha=0.5)) 538 | 539 | self.im_map.set_data(self.map_colormap(self.map.map)) 540 | for (team_key, team), team_colors in zip(self.teams.items(), self.teams_agents_color.values()): 541 | if self.cfg['disabled_teams_step'][team_key]: 542 | continue 543 | team_im = self.im_robots[team_key] 544 | for (agent_i, agent), color, im in zip(enumerate(team), team_colors, team_im): 545 | im.set_data(colors.ListedColormap([(0, 0, 0, 0), color])(agent.coverage)) 546 | 547 | def render_overview(self, ax, stepsize=1.0): 548 | if not hasattr(self, 
'im_map'): 549 | ax.set_xticks([]) 550 | ax.set_yticks([]) 551 | self.im_map = ax.imshow(np.zeros(self.map.shape), vmin=0, vmax=3) 552 | 553 | self.im_map.set_data(self.map_colormap(self.map.map)) 554 | for (team_key, team), team_colors in zip(self.teams.items(), self.teams_agents_color.values()): 555 | if self.cfg['disabled_teams_step'][team_key]: 556 | continue 557 | for (agent_i, agent), color in zip(enumerate(team), team_colors): 558 | rect_size = 1 559 | pose_microstep = agent.prev_pose + (agent.pose - agent.prev_pose)*stepsize 560 | rect = patches.Rectangle((pose_microstep[1] - rect_size / 2, pose_microstep[0] - rect_size / 2), rect_size, rect_size, 561 | linewidth=1, edgecolor=self.teams_colors[team_key], facecolor='none') 562 | ax.add_patch(rect) 563 | #ax.text(agent.pose[1]+1, agent.pose[0], f"{agent_i}", color=self.teams_colors[team_key], clip_on=True) 564 | 565 | #last_reward = sum([r.reward for r in self.robots.values()]) 566 | #ax.set_title( 567 | # f'Global coverage: {int(self.map.get_coverage_fraction()*100)}%\n' 568 | # #f'Last reward (r): {last_reward:.2f}' 569 | #) 570 | 571 | def render_connectivity(self, ax, agent_id, K): 572 | if K <= 1: 573 | return 574 | 575 | for connected_agent_id in np.arange(self.cfg['n_agents'])[self.A[agent_id] == 1]: 576 | current_agent_pose = self.robots[agent_id].pose 577 | connected_agent_d_pose = self.robots[connected_agent_id].pose - current_agent_pose 578 | ax.add_patch(patches.Arrow( 579 | current_agent_pose[X], 580 | current_agent_pose[Y], 581 | connected_agent_d_pose[X], 582 | connected_agent_d_pose[Y], 583 | edgecolor='b', 584 | facecolor='none' 585 | )) 586 | self.render_connectivity(ax, connected_agent_id, K-1) 587 | 588 | def render(self, mode='human', stepsize=1.0): 589 | if self.fig is None: 590 | plt.ion() 591 | self.fig = plt.figure(figsize=(3, 3)) 592 | self.ax_overview = self.fig.add_subplot(1, 1, 1, aspect='equal') 593 | 594 | self.clear_patches(self.ax_overview) 595 | 
self.render_overview(self.ax_overview, stepsize) 596 | #self.render_local_coverages(self.ax_overview) 597 | self.render_global_coverages(self.ax_overview) 598 | A = self.compute_gso(0) 599 | self.render_adjacency(A, 0, self.ax_overview, stepsize=stepsize) 600 | 601 | self.fig.canvas.draw() 602 | self.fig.subplots_adjust(bottom=0, top=1, left=0, right=1) 603 | return self.fig 604 | 605 | class CoverageEnvExplAdv(CoverageEnv): 606 | def __init__(self, cfg): 607 | super().__init__(cfg) 608 | 609 | def render(self, interpreter_obs, mode='human', stepsize=1.0): 610 | if self.fig is None: 611 | plt.ion() 612 | self.fig = plt.figure(figsize=(6, 3)) 613 | gs = self.fig.add_gridspec(ncols=2, nrows=1) 614 | gs.update(wspace=0, hspace=0) 615 | self.ax_overview = self.fig.add_subplot(gs[0]) 616 | ax_expl = self.fig.add_subplot(gs[1]) 617 | ax_expl.set_xticks([]) 618 | ax_expl.set_yticks([]) 619 | self.im_expl_cov = ax_expl.imshow(np.zeros((1, 1)), vmin=0, vmax=1) 620 | self.im_expl_map = ax_expl.imshow(np.zeros((1, 1)), vmin=0, vmax=1) 621 | 622 | self.clear_patches(self.ax_overview) 623 | self.render_overview(self.ax_overview, stepsize) 624 | self.render_global_coverages(self.ax_overview) 625 | A = self.compute_gso(0) 626 | self.render_adjacency(A, 0, self.ax_overview, stepsize=stepsize) 627 | 628 | adv_coverage = interpreter_obs[0][0] 629 | cmap_own_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), list(self.teams_agents_color[0][0])+[0.5]]) 630 | cmap_map = colors.ListedColormap([(0,0,0,0), (0,0,0,1)]) # free, obstacle, unknown 631 | self.im_expl_cov.set_data(cmap_own_cov(self.teams[0][0].to_abs_frame(adv_coverage))) 632 | self.im_expl_map.set_data(cmap_map(self.map.map)) 633 | 634 | self.fig.canvas.draw() 635 | self.fig.subplots_adjust(bottom=0, top=1, left=0, right=1) 636 | return self.fig 637 | 638 | class CoverageEnvAdvDec(CoverageEnv): 639 | def __init__(self, cfg): 640 | super().__init__(cfg) 641 | 642 | def render(self, interpreter_obs, 
mode='human', stepsize=1.0): 643 | if self.fig is None: 644 | plt.ion() 645 | self.fig = plt.figure(figsize=(6, 3)) 646 | gs = self.fig.add_gridspec(ncols=2, nrows=1) 647 | gs.update(wspace=0, hspace=0) 648 | self.ax_overview = self.fig.add_subplot(gs[0]) 649 | ax_expl = self.fig.add_subplot(gs[1]) 650 | ax_expl.set_xticks([]) 651 | ax_expl.set_yticks([]) 652 | self.im_expl_cov = ax_expl.imshow(np.zeros((1, 1)), vmin=0, vmax=1) 653 | self.im_expl_map = ax_expl.imshow(np.zeros((1, 1)), vmin=0, vmax=1) 654 | 655 | self.clear_patches(self.ax_overview) 656 | self.render_overview(self.ax_overview, stepsize) 657 | self.render_global_coverages(self.ax_overview) 658 | A = self.compute_gso(0) 659 | self.render_adjacency(A, 0, self.ax_overview, stepsize=stepsize) 660 | 661 | cmap_own_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), list(self.teams_agents_color[0][0])+[0.5]]) 662 | cmap_map = colors.ListedColormap([(0,0,0,0), (0,0,0,1)]) # free, obstacle, unknown 663 | self.im_expl_cov.set_data(cmap_own_cov(interpreter_obs[0][1])) 664 | #self.im_expl_map.set_data(cmap_map(self.map.map)) 665 | self.im_expl_map.set_data(cmap_map(interpreter_obs[0][0])) 666 | 667 | self.fig.canvas.draw() 668 | self.fig.subplots_adjust(bottom=0, top=1, left=0, right=1) 669 | return self.fig 670 | 671 | class CoverageEnvSingleSaliency(CoverageEnv): 672 | def __init__(self, cfg): 673 | super().__init__(cfg) 674 | 675 | def render(self, interpreter_obs, interpr_index=0, mode='human', stepsize=1.0): 676 | if self.fig is None: 677 | plt.ion() 678 | self.fig = plt.figure(figsize=(6, 3)) 679 | gs = self.fig.add_gridspec(ncols=3, nrows=1) 680 | gs.update(wspace=0, hspace=0) 681 | self.ax_overview = self.fig.add_subplot(gs[0]) 682 | ax_expl_map = self.fig.add_subplot(gs[1]) 683 | ax_expl_cov = self.fig.add_subplot(gs[2]) 684 | ax_expl_map.set_xticks([]) 685 | ax_expl_map.set_yticks([]) 686 | ax_expl_cov.set_xticks([]) 687 | ax_expl_cov.set_yticks([]) 688 | self.im_expl_cov = 
ax_expl_cov.imshow(np.zeros((1, 1)), vmin=0, vmax=1) 689 | self.im_expl_map = ax_expl_map.imshow(np.zeros((1, 1)), vmin=0, vmax=1) 690 | 691 | self.clear_patches(self.ax_overview) 692 | self.render_overview(self.ax_overview, stepsize) 693 | self.render_global_coverages(self.ax_overview) 694 | A = self.compute_gso(0) 695 | self.render_adjacency(A, 0, self.ax_overview, stepsize=stepsize) 696 | 697 | #cmap_own_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), list(self.teams_agents_color[0][0])+[0.5]]) 698 | #cmap_map = colors.ListedColormap([(0,0,0,0), (0,0,0,1)]) # free, obstacle, unknown 699 | 700 | saliency_limits = (np.min(interpreter_obs), np.max(interpreter_obs)) 701 | self.im_expl_cov.set_clim(saliency_limits[0], saliency_limits[1]) 702 | self.im_expl_map.set_clim(saliency_limits[0], saliency_limits[1]) 703 | self.im_expl_cov.set_data(interpreter_obs[interpr_index][:, :, 1]) 704 | self.im_expl_map.set_data(interpreter_obs[interpr_index][:, :, 0]) 705 | 706 | #self.im_expl_cov.set_data(cmap_own_cov(interpreter_obs[interpr_index][1])) 707 | #self.im_expl_map.set_data(cmap_map(self.map.map)) 708 | #self.im_expl_map.set_data(cmap_map(interpreter_obs[interpr_index][0])) 709 | 710 | self.fig.canvas.draw() 711 | self.fig.subplots_adjust(bottom=0, top=1, left=0, right=1) 712 | return self.fig 713 | 714 | class CoverageEnvSaliency(CoverageEnv): 715 | def __init__(self, cfg): 716 | super().__init__(cfg) 717 | 718 | def render(self, mode='human', saliency_obs=None, interpreter_obs=None): 719 | if self.fig is None: 720 | plt.ion() 721 | self.fig = plt.figure(constrained_layout=True, figsize=(16, 10)) 722 | grid_spec = self.fig.add_gridspec(ncols=max(self.cfg['n_agents']) * 3, 723 | nrows=1 + 2 * len(self.cfg['n_agents']), 724 | height_ratios=[1] + [1, 1] * len(self.cfg['n_agents'])) 725 | 726 | self.ax_overview = self.fig.add_subplot(grid_spec[0, :]) 727 | 728 | self.ax_im_agent = {} 729 | for team_key, team in self.teams.items(): 730 | 
self.ax_im_agent[team_key] = [] 731 | for i in range(self.cfg['n_agents'][team_key]): 732 | self.ax_im_agent[team_key].append({}) 733 | for j, col_id in enumerate(['map', 'coverage']): 734 | self.ax_im_agent[team_key][i][col_id] = {} 735 | for k, row_id in enumerate(['obs', 'sal']): 736 | ax = self.fig.add_subplot(grid_spec[j + 1 + team_key * 2, i * 3 + k]) 737 | ax.set_xticks([]) 738 | ax.set_yticks([]) 739 | self.ax_im_agent[team_key][i][col_id][row_id] = {'ax': ax, 'im': None} 740 | self.ax_im_agent[team_key][i][col_id]['sal']['im'] = self.ax_im_agent[team_key][i][col_id]['sal']['ax'].imshow(np.zeros((1, 1)), vmin=-5, vmax=5) 741 | #self.ax_im_agent[team_key][i][col_id]['int']['im'] = self.ax_im_agent[team_key][i][col_id]['int']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=1) 742 | self.ax_im_agent[team_key][i]['map']['obs']['im'] = self.ax_im_agent[team_key][i]['map']['obs']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=3) 743 | self.ax_im_agent[team_key][i]['coverage']['obs']['im'] = self.ax_im_agent[team_key][i]['coverage']['obs']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=1, alpha=0.3) 744 | 745 | self.clear_patches(self.ax_overview) 746 | self.render_overview(self.ax_overview) 747 | A = self.compute_gso(0) 748 | self.render_adjacency(A, 0, self.ax_overview) 749 | 750 | if saliency_obs is not None: 751 | saliency_limits = (np.min(saliency_obs), np.max(saliency_obs)) 752 | img_map_id = 0 753 | for team_key, team in self.teams.items(): 754 | for i, robot in enumerate(team): 755 | 756 | self.ax_im_agent[team_key][i]['map']['obs']['im'].set_data( 757 | self.map_colormap(robot.to_abs_frame(robot.state[..., 0]))) 758 | this_coverage_colormap = colors.ListedColormap([(0, 0, 0, 0), self.teams_agents_color[team_key][i]]) 759 | self.ax_im_agent[team_key][i]['coverage']['obs']['im'].set_data( 760 | this_coverage_colormap(robot.to_abs_frame(robot.state[..., 1]))) 761 | 762 | if saliency_obs is not None: 763 | self.ax_im_agent[team_key][i]['map']['sal']['im'].set_data( 764 
| robot.to_abs_frame(saliency_obs[img_map_id][..., 0])) 765 | self.ax_im_agent[team_key][i]['map']['sal']['im'].set_clim(saliency_limits[0], saliency_limits[1]) 766 | self.ax_im_agent[team_key][i]['coverage']['sal']['im'].set_data( 767 | robot.to_abs_frame(saliency_obs[img_map_id][..., 1])) 768 | self.ax_im_agent[team_key][i]['coverage']['sal']['im'].set_clim(saliency_limits[0], saliency_limits[1]) 769 | 770 | if interpreter_obs is not None: 771 | self.ax_im_agent[team_key][i]['map']['int']['im'].set_data( 772 | robot.to_abs_frame(interpreter_obs[img_map_id][1])) 773 | self.ax_im_agent[team_key][i]['coverage']['int']['im'].set_data( 774 | robot.to_abs_frame(interpreter_obs[img_map_id][0])) 775 | 776 | img_map_id += 1 777 | 778 | self.ax_im_agent[team_key][i]['map']['obs']['ax'].set_title( 779 | f'{i}') #\nc: {0:.2f}\nr: {robot.reward:.2f}') 780 | 781 | # self.render_connectivity(self.ax_overview, 0, 3) 782 | self.fig.canvas.draw() 783 | return self.fig 784 | -------------------------------------------------------------------------------- /adversarial_comms/environments/path_planning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | import matplotlib.patches as patches 5 | import gym 6 | from gym import spaces 7 | from gym.utils import seeding, EzPickle 8 | from matplotlib import colors 9 | from functools import partial 10 | from enum import Enum 11 | import copy 12 | 13 | from ray.rllib.env.multi_agent_env import MultiAgentEnv 14 | 15 | # https://bair.berkeley.edu/blog/2018/12/12/rllib/ 16 | 17 | DEFAULT_OPTIONS = { 18 | 'world_shape': [12, 12], 19 | 'state_size': 24, 20 | 'max_episode_len': 50, 21 | "n_agents": [8], 22 | "disabled_teams_step": [False], 23 | "disabled_teams_comms": [False], 24 | 'communication_range': 5.0, 25 | 'ensure_connectivity': True, 26 | 'position_mode': 'random', # random or fixed 27 | 'agents': { 28 | 
'visibility_distance': 3, 29 | 'relative_coord_frame': True 30 | } 31 | } 32 | 33 | X = 1 34 | Y = 0 35 | 36 | class Dir(Enum): 37 | RIGHT = 0 38 | LEFT = 1 39 | UP = 2 40 | DOWN = 3 41 | 42 | class WorldMap(): 43 | def __init__(self, shape, mode): 44 | self.shape = shape 45 | self.mode = mode 46 | self.reset() 47 | 48 | def reset(self): 49 | if self.mode == "traffic": 50 | self.map = np.array([ 51 | [1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1], 52 | [1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1], 53 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 54 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1], 55 | [1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1], 56 | [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0], 57 | [1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1], 58 | [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], 59 | [1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1], 60 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 61 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 62 | [1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1], 63 | ]) 64 | elif self.mode == "warehouse": 65 | self.map = np.zeros(self.shape, dtype=np.uint8) 66 | for y in range(1, self.shape[Y]-1, 2): 67 | for x in range(1, self.shape[X]-1, 6): 68 | self.map[y:y+1,x:x+4] = True 69 | else: 70 | raise NotImplementedError 71 | 72 | class Action(Enum): 73 | NOP = 0 74 | MOVE_RIGHT = 1 75 | MOVE_LEFT = 2 76 | MOVE_UP = 3 77 | MOVE_DOWN = 4 78 | 79 | class Robot(): 80 | def __init__(self, 81 | world, 82 | agent_observability_radius, 83 | state_size, 84 | coordinate_frame_is_local): 85 | self.world = world 86 | self.state_size = state_size 87 | self.coordinate_frame_is_local = coordinate_frame_is_local 88 | self.agent_observability_radius = agent_observability_radius 89 | self.reset([0, 0], [0, 0]) 90 | 91 | def reset(self, pose, goal): 92 | self.pose = np.array(pose, dtype=int) 93 | self.prev_pose = self.pose.copy() 94 | self.goal = np.array(goal, dtype=int) 95 | 96 | def step(self, action): 97 | action = Action(action) 98 | 99 | delta_pose = { 100 | Action.MOVE_RIGHT: [ 0, 1], 101 | Action.MOVE_LEFT: [ 0, -1], 102 |
Action.MOVE_UP: [-1, 0], 103 | Action.MOVE_DOWN: [ 1, 0], 104 | Action.NOP: [ 0, 0] 105 | }[action] 106 | 107 | def is_occupied(p): 108 | for team_key, team in self.world.teams.items(): 109 | if self.world.cfg['disabled_teams_step'][team_key]: 110 | continue 111 | for o in team: 112 | if p[X] == o.pose[X] and p[Y] == o.pose[Y] and o is not self: 113 | return True 114 | return False 115 | 116 | is_valid_pose = lambda p: all([p[c] >= 0 and p[c] < self.world.map.shape[c] for c in [Y, X]]) 117 | is_obstacle = lambda p: self.world.map.map[p[Y]][p[X]] 118 | 119 | self.prev_pose = self.pose.copy() 120 | desired_pos = self.pose + delta_pose 121 | if is_valid_pose(desired_pos) and not is_occupied(desired_pos) and not is_obstacle(desired_pos): 122 | self.pose = desired_pos 123 | 124 | def update_state(self): 125 | pose_map = np.zeros(self.world.map.shape, dtype=np.uint8) 126 | for team in self.world.teams.values(): 127 | for r in team: 128 | if r is not self and np.sum((r.pose - self.pose)**2) <= self.agent_observability_radius**2: 129 | pose_map[r.pose[Y], r.pose[X]] = 2 130 | pose_map[self.pose[Y], self.pose[X]] = 1 131 | 132 | goal_map = np.zeros(self.world.map.shape, dtype=bool) 133 | cy, cx = self.goal - self.pose 134 | if abs(cx) < self.state_size/2 and abs(cy) < self.state_size/2: 135 | goal_map[self.goal[Y], self.goal[X]] = True 136 | else: 137 | u = max(abs(cx), abs(cy)) 138 | p_sq = np.round(self.pose + int(self.state_size / 2) * np.array([cy / u, cx / u])).astype(int) 139 | goal_map[p_sq[Y], p_sq[X]] = True 140 | 141 | self.state = np.stack([self.to_coordinate_frame(self.world.map.map, 1), self.to_coordinate_frame(goal_map, 0), self.to_coordinate_frame(pose_map, 0)], axis=-1).astype(np.uint8) 142 | done = all(self.pose == self.goal) 143 | return self.state, done 144 | 145 | def to_coordinate_frame(self, m, fill=0): 146 | if self.coordinate_frame_is_local: 147 | half_state_shape = np.array([self.state_size/2]*2, dtype=int) 148 | padded =
np.pad(m,([half_state_shape[Y]]*2,[half_state_shape[X]]*2), mode='constant', constant_values=fill) 149 | return padded[self.pose[Y]:self.pose[Y] + self.state_size, self.pose[X]:self.pose[X] + self.state_size] 150 | else: 151 | return m 152 | 153 | class PathPlanningEnv(gym.Env, EzPickle): 154 | def __init__(self, env_config): 155 | EzPickle.__init__(self) 156 | self.seed() 157 | 158 | self.cfg = copy.deepcopy(DEFAULT_OPTIONS) 159 | self.cfg.update(env_config) 160 | 161 | self.fig = None 162 | self.map_colormap = colors.ListedColormap(['white', 'black', 'gray']) # free, obstacle, unknown 163 | 164 | hsv = np.ones((sum(self.cfg['n_agents']), 3)) 165 | hsv[..., 0] = np.linspace(0, 1, sum(self.cfg['n_agents']) + 1)[:-1] 166 | self.teams_agents_color = {} 167 | current_index = 0 168 | for i, n_agents in enumerate(self.cfg['n_agents']): 169 | self.teams_agents_color[i] = colors.hsv_to_rgb(hsv[current_index:current_index+n_agents]) 170 | current_index += n_agents 171 | 172 | hsv = np.ones((len(self.cfg['n_agents']), 3)) 173 | hsv[..., 0] = np.linspace(0, 1, len(self.cfg['n_agents']) + 1)[:-1] 174 | self.teams_colors = ['r', 'b'] #colors.hsv_to_rgb(hsv) 175 | 176 | n_all_agents = sum(self.cfg['n_agents']) 177 | self.observation_space = spaces.Dict({ 178 | 'agents': spaces.Tuple(( 179 | spaces.Dict({ 180 | 'map': spaces.Box(0, np.inf, shape=(self.cfg['state_size'], self.cfg['state_size'], 3)), 181 | 'pos': spaces.Box(low=np.array([0,0]), high=np.array([self.cfg['world_shape'][Y], self.cfg['world_shape'][X]]), dtype=int), 182 | }), 183 | )*n_all_agents), # Do not add this as additional dimension of map and pos since this way it is easier to handle in the model 184 | 'gso': spaces.Box(-np.inf, np.inf, shape=(n_all_agents, n_all_agents)), 185 | 'state': spaces.Box(low=0, high=3, shape=self.cfg['world_shape']+[sum(self.cfg['n_agents'])]), 186 | }) 187 | self.action_space = spaces.Tuple((spaces.Discrete(5),)*sum(self.cfg['n_agents'])) 188 | 189 | self.map =
WorldMap(self.cfg['world_shape'], self.cfg['world_mode']) 190 | 191 | self.teams = { 192 | i: [ 193 | Robot( 194 | self, 195 | self.cfg['agents']['visibility_distance'], 196 | self.cfg['state_size'], 197 | self.cfg['agents']['relative_coord_frame'] 198 | ) for _ in range(n_agents) 199 | ] for i, n_agents in enumerate(self.cfg['n_agents']) 200 | } 201 | 202 | self.reset() 203 | 204 | def seed(self, seed=None): 205 | self.random_state, seed_agents = seeding.np_random(seed) 206 | return [seed_agents] 207 | 208 | def reset(self): 209 | self.timestep = 0 210 | self.dones = {key: [False for _ in team] for key, team in self.teams.items()} 211 | self.map.reset() 212 | 213 | def sample_random_pos(): 214 | x = self.random_state.randint(0, self.map.shape[X]) 215 | y = self.random_state.randint(0, self.map.shape[Y]) 216 | return np.array([y, x]) 217 | 218 | def sample_valid_random_pos(up_to=None): 219 | def get_agents(): 220 | return [o for team in self.teams.values() for o in team][:up_to] 221 | def is_occupied(p): 222 | return any([all(p == o.pose) for o in get_agents()]) 223 | def is_other_goal(p): 224 | return any([all(p == o.goal) for o in get_agents()]) 225 | is_obstacle = lambda p: self.map.map[p[Y]][p[X]] 226 | 227 | pose_seed = sample_random_pos() 228 | while is_obstacle(pose_seed) or is_occupied(pose_seed) or is_other_goal(pose_seed): 229 | pose_seed = sample_random_pos() 230 | return pose_seed 231 | 232 | agent_index = 0 233 | for team_key, team in self.teams.items(): 234 | if self.cfg['disabled_teams_step'][team_key]: 235 | continue 236 | for agent in team: 237 | agent.reset(sample_valid_random_pos(agent_index), sample_valid_random_pos(agent_index)) 238 | agent_index += 1 239 | 240 | return self.step([Action.NOP]*sum(self.cfg['n_agents']))[0] 241 | 242 | def compute_gso(self, team_id=0): 243 | own_team_agents = [(agent, self.cfg['disabled_teams_comms'][team_id]) for agent in self.teams[team_id]] 244 | other_agents = [(agent, 
self.cfg['disabled_teams_comms'][other_team_id]) for other_team_id, team in self.teams.items() for agent in team if not team_id == other_team_id] 245 | 246 | all_agents = own_team_agents + other_agents # order is important since in model the data is concatenated in this order as well 247 | dists = np.zeros((len(all_agents), len(all_agents))) 248 | done_matrix = np.zeros((len(all_agents), len(all_agents)), dtype=bool) 249 | for agent_y in range(len(all_agents)): 250 | for agent_x in range(agent_y): 251 | dst = np.sum(np.array(all_agents[agent_x][0].pose - all_agents[agent_y][0].pose)**2) 252 | dists[agent_y, agent_x] = dst 253 | dists[agent_x, agent_y] = dst 254 | 255 | d = all_agents[agent_x][1] or all_agents[agent_y][1] 256 | done_matrix[agent_y, agent_x] = d 257 | done_matrix[agent_x, agent_y] = d 258 | 259 | current_dist = self.cfg['communication_range'] 260 | A = dists < (current_dist**2) 261 | active_row = ~np.array([a[1] for a in all_agents]) 262 | if self.cfg['ensure_connectivity']: 263 | def is_connected(m): 264 | def walk_dfs(m, index): 265 | for i in range(len(m)): 266 | if m[index][i]: 267 | m[index][i] = False 268 | walk_dfs(m, i) 269 | 270 | m_c = m.copy() 271 | walk_dfs(m_c, 0) 272 | return not np.any(m_c.flatten()) 273 | 274 | # set done teams as generally connected since they should not be included by increasing connectivity 275 | while not is_connected(A[active_row][:, active_row]): 276 | current_dist *= 1.1 277 | A = (dists < current_dist**2) 278 | 279 | # Mask out done agents 280 | A = (A & ~done_matrix).astype(int) 281 | 282 | # normalization: refer https://github.com/QingbiaoLi/GraphNets/blob/master/Flocking/Utils/dataTools.py#L601 283 | np.fill_diagonal(A, 0) 284 | deg = np.sum(A, axis = 1) # nNodes (degree vector) 285 | D = np.diag(deg) 286 | Dp = np.diag(np.nan_to_num(np.power(deg, -1/2))) 287 | L = A # D-A 288 | gso = Dp @ L @ Dp 289 | return gso 290 | 291 | def step(self, actions): 292 | self.timestep += 1 293 | action_index = 0 294
| for i, team in enumerate(self.teams.values()): 295 | for j, agent in enumerate(team): 296 | if not self.cfg['disabled_teams_step'][i]: # and not self.dones[i][j]: 297 | agent.step(actions[action_index]) 298 | action_index += 1 299 | 300 | states, rewards = {}, {} 301 | for team_key, team in self.teams.items(): 302 | states[team_key] = [] 303 | rewards[team_key] = {} 304 | for i, agent in enumerate(team): 305 | state, done = agent.update_state() 306 | states[team_key].append(state) 307 | rewards[team_key][i] = 1 if done else 0 # reward while at goal, incentives moving as quickly as possible 308 | if done: 309 | self.dones[team_key][i] = True 310 | 311 | if self.cfg['reward_type'] == 'local': 312 | pass 313 | elif self.cfg['reward_type'] == 'greedy_only': 314 | rewards[1] = {agent_key: sum(rewards[0].values()) for agent_key in rewards[1].keys()} 315 | elif self.cfg['reward_type'] == 'coop_only': 316 | rewards[0] = {agent_key: sum(rewards[1].values()) for agent_key in rewards[0].keys()} 317 | else: 318 | raise NotImplementedError("Unknown reward type", self.cfg['reward_type']) 319 | 320 | done = self.timestep == self.cfg['max_episode_len'] # or all(self.dones[1]) 321 | 322 | global_state = np.stack([self.map.map.copy() for _ in range(sum(self.cfg['n_agents']))], axis=-1).astype(np.uint8) 323 | global_state_layer = 0 324 | for team in self.teams.values(): 325 | for r in team: 326 | global_state[r.pose[Y], r.pose[X], global_state_layer] = 2 327 | global_state[r.goal[Y], r.goal[X], global_state_layer] = 3 328 | global_state_layer += 1 329 | 330 | state = { 331 | 'agents': tuple([{ 332 | 'map': states[key][agent_i], 333 | 'pos': self.teams[key][agent_i].pose 334 | } for key in self.teams.keys() for agent_i in range(self.cfg['n_agents'][key])]), 335 | 'gso': self.compute_gso(0), 336 | 'state': global_state 337 | } 338 | 339 | flattened_rewards = {} 340 | agent_index = 0 341 | for key in self.teams.keys(): 342 | for r in rewards[key].values(): 343 | 
flattened_rewards[agent_index] = r 344 | agent_index += 1 345 | info = { 346 | 'rewards_teams': rewards, 347 | 'rewards': flattened_rewards 348 | } 349 | return state, sum([sum(t.values()) for i, t in enumerate(rewards.values()) if not self.cfg['disabled_teams_step'][i]]), done, info 350 | 351 | def clear_patches(self, ax): 352 | [p.remove() for p in reversed(ax.patches)] 353 | [t.remove() for t in reversed(ax.texts)] 354 | 355 | def render_adjacency(self, A, team_id, ax, color='b', stepsize=1.0): 356 | A = A.copy() 357 | own_team_agents = [agent for agent in self.teams[team_id]] 358 | other_agents = [agent for other_team_id, team in self.teams.items() for agent in team if not team_id == other_team_id] 359 | all_agents = own_team_agents + other_agents 360 | for agent_id, agent in enumerate(all_agents): 361 | for connected_agent_id in np.arange(len(A)): 362 | if A[agent_id][connected_agent_id] > 0: 363 | current_agent_pose = agent.prev_pose + (agent.pose - agent.prev_pose) * stepsize 364 | other_agent = all_agents[connected_agent_id] 365 | other_agent_pose = other_agent.prev_pose + (other_agent.pose - other_agent.prev_pose) * stepsize 366 | ax.add_patch(patches.ConnectionPatch( 367 | [current_agent_pose[X], current_agent_pose[Y]], 368 | [other_agent_pose[X], other_agent_pose[Y]], 369 | "data", edgecolor='g', facecolor='none', lw=1, ls=":" 370 | )) 371 | 372 | A[connected_agent_id][agent_id] = 0 # don't draw same connection again 373 | 374 | def render_overview(self, ax, stepsize=1.0): 375 | if not hasattr(self, 'im_map'): 376 | ax.set_xticks([]) 377 | ax.set_yticks([]) 378 | self.im_map = ax.imshow(np.zeros(self.map.shape), vmin=0, vmax=1) 379 | 380 | self.im_map.set_data(self.map_colormap(self.map.map)) 381 | agent_i = 0 382 | for (team_key, team) in self.teams.items(): 383 | if self.cfg['disabled_teams_step'][team_key]: 384 | continue 385 | for agent in team: 386 | rect_size = 1 387 | pose_microstep = agent.prev_pose + (agent.pose - agent.prev_pose)*stepsize 388 | 
rect = patches.Rectangle((pose_microstep[1] - rect_size / 2, pose_microstep[0] - rect_size / 2), rect_size, rect_size, 389 | linewidth=1, edgecolor=self.teams_colors[team_key], facecolor='none') 390 | ax.add_patch(rect) 391 | ax.text(pose_microstep[1]-0.45, pose_microstep[0], f"{agent_i}", color=self.teams_colors[team_key]) 392 | agent_i += 1 393 | 394 | #ax.set_title( 395 | # f'Global coverage: {int(self.map.get_coverage_fraction()*100)}%\n' 396 | #) 397 | 398 | def render_goals(self, ax): 399 | agent_i = 0 400 | for team_key, team in self.teams.items(): 401 | if self.cfg['disabled_teams_step'][team_key]: 402 | continue 403 | for agent in team: 404 | rect = patches.Circle((agent.goal[1], agent.goal[0]), 0.1, 405 | linewidth=1, facecolor=self.teams_colors[team_key]) 406 | ax.add_patch(rect) 407 | ax.text(agent.goal[1] - 0.45, agent.goal[0] + 0.5, f"{agent_i}", color=self.teams_colors[team_key]) 408 | agent_i += 1 409 | 410 | def render_connectivity(self, ax, agent_id, K): 411 | if K <= 1: 412 | return 413 | 414 | for connected_agent_id in np.arange(self.cfg['n_agents'])[self.A[agent_id] == 1]: 415 | current_agent_pose = self.robots[agent_id].pose 416 | connected_agent_d_pose = self.robots[connected_agent_id].pose - current_agent_pose 417 | ax.add_patch(patches.Arrow( 418 | current_agent_pose[X], 419 | current_agent_pose[Y], 420 | connected_agent_d_pose[X], 421 | connected_agent_d_pose[Y], 422 | edgecolor='b', 423 | facecolor='none' 424 | )) 425 | self.render_connectivity(ax, connected_agent_id, K-1) 426 | 427 | def render_future_steps(self, future_steps, ax, stepsize=1.0): 428 | if future_steps is None: 429 | return 430 | 431 | current_agent_i = 0 432 | for team_key, team in self.teams.items(): 433 | for agent in team: 434 | if current_agent_i not in future_steps: 435 | current_agent_i += 1 436 | continue 437 | 438 | previous_agent_pose = future_steps[current_agent_i][0].copy() 439 | for i, current_pos in enumerate(future_steps[current_agent_i][1:]): 440 | if i == 
len(future_steps[current_agent_i][1:])-1: 441 | current_pos = previous_agent_pose + ( 442 | future_steps[current_agent_i][-1] - previous_agent_pose) * stepsize 443 | ax.add_patch( 444 | patches.Rectangle((current_pos[1] - 1/2, current_pos[0] - 1/2), 1, 1, 445 | linewidth=1, edgecolor=self.teams_colors[team_key], 446 | facecolor='none', ls=":") 447 | ) 448 | ax.add_patch(patches.ConnectionPatch( 449 | [previous_agent_pose[X], previous_agent_pose[Y]], 450 | [current_pos[X], current_pos[Y]], 451 | "data", edgecolor=self.teams_colors[team_key], facecolor='none', lw=2 452 | )) 453 | 454 | previous_agent_pose = current_pos.copy() 455 | 456 | current_agent_i += 1 457 | 458 | def render(self, mode='human', future_steps=None, stepsize=1.0): 459 | if self.fig is None: 460 | self.fig = plt.figure(figsize=(3, 3)) 461 | self.ax_overview = self.fig.add_subplot(1, 1, 1, aspect='equal') 462 | 463 | self.clear_patches(self.ax_overview) 464 | self.render_future_steps(future_steps, self.ax_overview, stepsize) 465 | self.render_overview(self.ax_overview, stepsize) 466 | self.render_goals(self.ax_overview) 467 | A = self.compute_gso(0) 468 | self.render_adjacency(A, 0, self.ax_overview, stepsize=stepsize) 469 | 470 | self.fig.canvas.draw() 471 | self.fig.subplots_adjust(bottom=0, top=1, left=0, right=1) 472 | return self.fig 473 | 474 | class PathPlanningEnvSaliency(PathPlanningEnv): 475 | def __init__(self, cfg): 476 | super().__init__(cfg) 477 | 478 | def render(self, mode='human', saliency_obs=None, saliency_pos=None): 479 | if self.fig is None: 480 | plt.ion() 481 | self.fig = plt.figure(constrained_layout=True, figsize=(16, 10)) 482 | grid_spec = self.fig.add_gridspec(ncols=max(self.cfg['n_agents']) * 2, 483 | nrows=1 + 3 * len(self.cfg['n_agents']), 484 | height_ratios=[1] + [1, 1, 1] * len(self.cfg['n_agents'])) 485 | 486 | self.ax_overview = self.fig.add_subplot(grid_spec[0, :]) 487 | 488 | self.ax_im_agent = {} 489 | for team_key, team in self.teams.items(): 490 | 
self.ax_im_agent[team_key] = [] 491 | for i in range(self.cfg['n_agents'][team_key]): 492 | self.ax_im_agent[team_key].append({}) 493 | for j, col_id in enumerate(['map', 'goal', 'pos']): 494 | self.ax_im_agent[team_key][i][col_id] = {} 495 | for k, row_id in enumerate(['obs', 'sal']): 496 | ax = self.fig.add_subplot(grid_spec[j + 1 + team_key * 2, i * 2 + k]) 497 | ax.set_xticks([]) 498 | ax.set_yticks([]) 499 | self.ax_im_agent[team_key][i][col_id][row_id] = {'ax': ax, 'im': None} 500 | self.ax_im_agent[team_key][i][col_id]['sal']['im'] = \ 501 | self.ax_im_agent[team_key][i][col_id]['sal']['ax'].imshow( 502 | np.zeros((1, 1)), vmin=-5, vmax=5) 503 | self.ax_im_agent[team_key][i]['map']['obs']['im'] = self.ax_im_agent[team_key][i]['map']['obs']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=3) 504 | self.ax_im_agent[team_key][i]['goal']['obs']['im'] = self.ax_im_agent[team_key][i]['goal']['obs']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=1) 505 | self.ax_im_agent[team_key][i]['pos']['obs']['im'] = self.ax_im_agent[team_key][i]['pos']['obs']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=1) 506 | 507 | 508 | self.clear_patches(self.ax_overview) 509 | self.render_overview(self.ax_overview) 510 | A = self.compute_gso(0) 511 | self.render_adjacency(A, 0, self.ax_overview) 512 | 513 | if saliency_obs is not None: 514 | saliency_limits = (np.min(saliency_obs), np.max(saliency_obs)) 515 | saliency_map_id = 0 516 | for team_key, team in self.teams.items(): 517 | for i, robot in enumerate(team): 518 | self.ax_im_agent[team_key][i]['map']['obs']['im'].set_data(self.map_colormap(robot.state[..., 0])) 519 | self.ax_im_agent[team_key][i]['goal']['obs']['im'].set_data(self.map_colormap(robot.state[..., 1])) 520 | self.ax_im_agent[team_key][i]['pos']['obs']['im'].set_data(self.map_colormap(robot.state[..., 2])) 521 | 522 | if saliency_obs is not None: 523 | self.ax_im_agent[team_key][i]['map']['sal']['im'].set_data(saliency_obs[saliency_map_id][..., 0]) 524 | 
self.ax_im_agent[team_key][i]['map']['sal']['im'].set_clim(saliency_limits[0], saliency_limits[1]) 525 | self.ax_im_agent[team_key][i]['goal']['sal']['im'].set_data(saliency_obs[saliency_map_id][..., 1]) 526 | self.ax_im_agent[team_key][i]['goal']['sal']['im'].set_clim(saliency_limits[0], saliency_limits[1]) 527 | 528 | saliency_map_id += 1 529 | 530 | if False: # saliency_pos is not None: 531 | print("T", saliency_pos[i][2:].numpy()) 532 | self.ax_im_agent[i]['map']['sal']['ax'].set_title( 533 | f'{saliency_pos[i][0]:.2f}\n{saliency_pos[i][1]:.2f}\n{np.mean(saliency_pos[i][2:].numpy()):.2f}') 534 | 535 | #self.ax_im_agent[team_key][i]['map']['obs']['ax'].set_title( 536 | # f'{i}\nc: {0:.2f}\nr: {robot.reward:.2f}') 537 | 538 | # self.render_connectivity(self.ax_overview, 0, 3) 539 | self.fig.canvas.draw() 540 | return self.fig 541 | 542 | class PathPlanningEnvOverview(PathPlanningEnv): 543 | def __init__(self, cfg): 544 | super().__init__(cfg) 545 | self.map_colormap = colors.ListedColormap(['white', 'black', 'blue', 'red', 'green']) # free, obstacle, pos, goal 546 | 547 | def render(self, mode='human'): 548 | if self.fig is None: 549 | plt.ion() 550 | self.fig = plt.figure(constrained_layout=True, figsize=(16, 10)) 551 | grid_spec = self.fig.add_gridspec(ncols=max(self.cfg['n_agents']), 552 | nrows=1 + len(self.cfg['n_agents']), 553 | height_ratios=[1] + [1] * len(self.cfg['n_agents'])) 554 | 555 | self.ax_overview = self.fig.add_subplot(grid_spec[0, :]) 556 | 557 | self.ax_im_agent = {} 558 | for team_key, team in self.teams.items(): 559 | self.ax_im_agent[team_key] = [] 560 | for i in range(self.cfg['n_agents'][team_key]): 561 | self.ax_im_agent[team_key].append({}) 562 | for j, col_id in enumerate(['overview']): 563 | self.ax_im_agent[team_key][i][col_id] = {} 564 | for k, row_id in enumerate(['obs']): 565 | ax = self.fig.add_subplot(grid_spec[j + 1 + team_key , i + k]) 566 | ax.set_xticks([]) 567 | ax.set_yticks([]) 568 |
self.ax_im_agent[team_key][i][col_id][row_id] = {'ax': ax, 'im': None} 569 | self.ax_im_agent[team_key][i]['overview']['obs']['im'] = self.ax_im_agent[team_key][i]['overview']['obs']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=4) 570 | 571 | 572 | self.clear_patches(self.ax_overview) 573 | self.render_overview(self.ax_overview) 574 | self.render_goals(self.ax_overview) 575 | A = self.compute_gso(0) 576 | self.render_adjacency(A, 0, self.ax_overview) 577 | 578 | saliency_map_id = 0 579 | for team_key, team in self.teams.items(): 580 | for i, robot in enumerate(team): 581 | state = robot.state[..., 0].copy().astype(np.uint8) # map 582 | state[robot.state[..., 1]==1] = 2 583 | state[robot.state[..., 2]==1] = 3 584 | state[robot.state[..., 2]==2] = 4 585 | #state[robot.state[..., 2]] = 3 586 | self.ax_im_agent[team_key][i]['overview']['obs']['im'].set_data(self.map_colormap(state)) 587 | 588 | self.ax_im_agent[team_key][i]['overview']['obs']['ax'].set_title( 589 | f'{i}') #\nc: {0:.2f}\nr: {robot.reward:.2f}') 590 | 591 | # self.render_connectivity(self.ax_overview, 0, 3) 592 | self.fig.canvas.draw() 593 | return self.fig 594 | -------------------------------------------------------------------------------- /adversarial_comms/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections.abc 3 | import json 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import os 7 | import pandas as pd 8 | import ray 9 | import time 10 | import traceback 11 | 12 | from pathlib import Path 13 | from ray.rllib.models import ModelCatalog 14 | from ray.tune.logger import NoopLogger 15 | from ray.tune.registry import register_env 16 | from ray.util.multiprocessing import Pool 17 | 18 | from .environments.coverage import CoverageEnv 19 | from .environments.path_planning import PathPlanningEnv 20 | from .models.adversarial import AdversarialModel 21 | from .trainers.multiagent_ppo import MultiPPOTrainer 
22 | from .trainers.random_heuristic import RandomHeuristicTrainer 23 | 24 | def update_dict(d, u): 25 | for k, v in u.items(): 26 | if isinstance(v, collections.abc.Mapping): 27 | d[k] = update_dict(d.get(k, {}), v) 28 | else: 29 | d[k] = v 30 | return d 31 | 32 | def run_trial(trainer_class=MultiPPOTrainer, checkpoint_path=None, trial=0, cfg_update={}, render=False): 33 | try: 34 | t0 = time.time() 35 | cfg = {'env_config': {}, 'model': {}} 36 | if checkpoint_path is not None: 37 | # We might want to run policies that are not loaded from a checkpoint 38 | # (e.g. the random policy) and therefore need this to be optional 39 | with open(Path(checkpoint_path).parent/"params.json") as json_file: 40 | cfg = json.load(json_file) 41 | 42 | if 'evaluation_config' in cfg: 43 | # overwrite the environment config with evaluation one if it exists 44 | cfg = update_dict(cfg, cfg['evaluation_config']) 45 | 46 | cfg = update_dict(cfg, cfg_update) 47 | 48 | trainer = trainer_class( 49 | env=cfg['env'], 50 | logger_creator=lambda config: NoopLogger(config, ""), 51 | config={ 52 | "framework": "torch", 53 | "seed": trial, 54 | "num_workers": 0, 55 | "env_config": cfg['env_config'], 56 | "model": cfg['model'] 57 | } 58 | ) 59 | if checkpoint_path is not None: 60 | checkpoint_file = Path(checkpoint_path)/('checkpoint-'+os.path.basename(checkpoint_path).split('_')[-1]) 61 | trainer.restore(str(checkpoint_file)) 62 | 63 | envs = {'coverage': CoverageEnv, 'path_planning': PathPlanningEnv} 64 | env = envs[cfg['env']](cfg['env_config']) 65 | env.seed(trial) 66 | obs = env.reset() 67 | 68 | results = [] 69 | for i in range(cfg['env_config']['max_episode_len']): 70 | actions = trainer.compute_action(obs) 71 | obs, reward, done, info = env.step(actions) 72 | if render: 73 | env.render() 74 | for j, reward in enumerate(list(info['rewards'].values())): 75 | results.append({ 76 | 'step': i, 77 | 'agent': j, 78 | 'trial': trial, 79 | 'reward': reward 80 | }) 81 | 82 | print("Done", time.time() 
- t0) 83 | except Exception as e: 84 | print(e, traceback.format_exc()) 85 | raise 86 | df = pd.DataFrame(results) 87 | return df 88 | 89 | def path_to_hash(path): 90 | path_split = path.split('/') 91 | checkpoint_number_string = path_split[-1].split('_')[-1] 92 | path_hash = path_split[-2].split('_')[-2] 93 | return path_hash + '-' + checkpoint_number_string 94 | 95 | def serve_config(checkpoint_path, trials, cfg_change={}, trainer=MultiPPOTrainer): 96 | with Pool() as p: 97 | results = pd.concat(p.starmap(run_trial, [(trainer, checkpoint_path, t, cfg_change) for t in range(trials)])) 98 | return results 99 | 100 | def initialize(): 101 | ray.init() 102 | register_env("coverage", lambda config: CoverageEnv(config)) 103 | register_env("path_planning", lambda config: PathPlanningEnv(config)) 104 | ModelCatalog.register_custom_model("adversarial", AdversarialModel) 105 | 106 | def eval_nocomm(env_config_func, prefix): 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument("checkpoint") 109 | parser.add_argument("out_path") 110 | parser.add_argument("-t", "--trials", type=int, default=100) 111 | args = parser.parse_args() 112 | 113 | initialize() 114 | results = [] 115 | for comm in [False, True]: 116 | cfg_change={'env_config': env_config_func(comm)} 117 | df = serve_config(args.checkpoint, args.trials, cfg_change=cfg_change, trainer=MultiPPOTrainer) 118 | df['comm'] = comm 119 | results.append(df) 120 | 121 | with open(Path(args.checkpoint).parent/"params.json") as json_file: 122 | cfg = json.load(json_file) 123 | if 'evaluation_config' in cfg: 124 | update_dict(cfg, cfg['evaluation_config']) 125 | 126 | df = pd.concat(results) 127 | df.attrs = cfg 128 | filename = prefix + "-" + path_to_hash(args.checkpoint) + ".pkl" 129 | df.to_pickle(Path(args.out_path)/filename) 130 | 131 | def eval_nocomm_coop(): 132 | # Cooperative agents can communicate or not (without comm interference from adversarial agent) 133 | eval_nocomm(lambda comm: { 134 | 
'disabled_teams_comms': [True, not comm], 135 | 'disabled_teams_step': [True, False] 136 | }, "eval_coop") 137 | 138 | def eval_nocomm_adv(): 139 | # all cooperative agents can still communicate, but adversarial communication is switched 140 | eval_nocomm(lambda comm: { 141 | 'disabled_teams_comms': [not comm, False], # en/disable comms for adv and always enabled for coop 142 | 'disabled_teams_step': [False, False] # both teams operating 143 | }, "eval_adv") 144 | 145 | def plot_agent(ax, df, color, step_aggregation='sum', linestyle='-'): 146 | world_shape = df.attrs['env_config']['world_shape'] 147 | max_cov = world_shape[0]*world_shape[1]*df.attrs['env_config']['min_coverable_area_fraction'] 148 | d = (df.sort_values(['trial', 'step']).groupby(['trial', 'step'])['reward'].apply(step_aggregation, 'step').groupby('trial').cumsum()/max_cov*100).groupby('step') 149 | ax.plot(d.mean(), color=color, ls=linestyle) 150 | ax.fill_between(np.arange(len(d.mean())), np.clip(d.mean()-d.std(), 0, None), d.mean()+d.std(), alpha=0.1, color=color) 151 | 152 | def plot(): 153 | parser = argparse.ArgumentParser() 154 | parser.add_argument("data") 155 | parser.add_argument("-o", "--out_file", default=None) 156 | args = parser.parse_args() 157 | 158 | fig_overview = plt.figure(figsize=[4, 4]) 159 | ax = fig_overview.subplots(1, 1) 160 | 161 | df = pd.read_pickle(args.data) 162 | if Path(args.data).name.startswith('eval_adv'): 163 | plot_agent(ax, df[(df['comm'] == False) & (df['agent'] == 0)], 'r', step_aggregation='mean', linestyle=':') 164 | plot_agent(ax, df[(df['comm'] == False) & (df['agent'] > 0)], 'b', step_aggregation='mean', linestyle=':') 165 | plot_agent(ax, df[(df['comm'] == True) & (df['agent'] == 0)], 'r', step_aggregation='mean', linestyle='-') 166 | plot_agent(ax, df[(df['comm'] == True) & (df['agent'] > 0)], 'b', step_aggregation='mean', linestyle='-') 167 | elif Path(args.data).name.startswith('eval_coop'): 168 | plot_agent(ax, df[(df['comm'] == False) & 
(df['agent'] > 0)], 'b', step_aggregation='sum', linestyle=':') 169 | plot_agent(ax, df[(df['comm'] == True) & (df['agent'] > 0)], 'b', step_aggregation='sum', linestyle='-') 170 | elif Path(args.data).name.startswith('eval_rand'): 171 | plot_agent(ax, df[df['agent'] > 0], 'b', step_aggregation='sum', linestyle='-') 172 | 173 | ax.set_ylabel("Coverage %") 174 | ax.set_ylim(0, 100) 175 | ax.set_xlabel("Episode time steps") 176 | ax.margins(x=0, y=0) 177 | ax.grid() 178 | 179 | fig_overview.tight_layout() 180 | if args.out_file is not None: 181 | fig_overview.savefig(args.out_file, dpi=300) 182 | 183 | plt.show() 184 | 185 | def serve(): 186 | parser = argparse.ArgumentParser() 187 | parser.add_argument("checkpoint") 188 | parser.add_argument("-s", "--seed", type=int, default=0) 189 | args = parser.parse_args() 190 | 191 | initialize() 192 | run_trial(checkpoint_path=args.checkpoint, trial=args.seed, render=True) 193 | 194 | -------------------------------------------------------------------------------- /adversarial_comms/generate_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import ray 3 | from ray.util.multiprocessing import Pool 4 | import json 5 | import os 6 | from ray.tune.registry import register_env 7 | from ray.rllib.models import ModelCatalog 8 | from ray.tune.logger import NoopLogger 9 | #from model_team_adversarial import AdaptedVisionNetwork as AdversarialTeamModel 10 | #from model_team_adversarial_2 import AdaptedVisionNetwork as AdversarialTeamModel2 11 | from model_team_adversarial_2_vaegp import AdaptedVisionNetwork as AdversarialTeamModel2VAEGP 12 | from multiagent_ppo_trainer_2 import MultiPPOTrainer as MultiPPOTrainer2 13 | import matplotlib.style as mplstyle 14 | mplstyle.use('fast') 15 | 16 | from world_teams_2 import World as TeamWorld2 17 | from world_flow import WorldOverview as FlowWorld 18 | import pickle 19 | import torch 20 | 21 | import copy 22 | 23 | def 
generate(seed, checkpoint_path, sample_iterations, termination_mode, frame_take_prob=0.1, disable_adv_comm=False, ensure_conn=False, t_fac=1.5): 24 | with open(checkpoint_path + '/../params.json') as json_file: 25 | checkpoint_config = json.load(json_file) 26 | 27 | checkpoint_config['env_config']['ensure_connectivity'] = ensure_conn 28 | 29 | checkpoint_config['env_config']['disabled_teams_comms'] = [disable_adv_comm, False] 30 | checkpoint_config['env_config']['disabled_teams_step'] = [False, False] 31 | 32 | trainer_cfg = { 33 | "framework": "torch", 34 | "num_workers": 1, 35 | "num_gpus": 1, 36 | "env_config": checkpoint_config['env_config'], 37 | "model": checkpoint_config['model'], 38 | "seed": seed 39 | } 40 | 41 | trainer = MultiPPOTrainer2( 42 | logger_creator=lambda config: NoopLogger(config, ""), 43 | env=checkpoint_config['env'], 44 | config=trainer_cfg 45 | ) 46 | checkpoint_file = checkpoint_path + '/checkpoint-' + os.path.basename(checkpoint_path).split('_')[-1] 47 | trainer.restore(checkpoint_file) 48 | 49 | envs = { 50 | 'flowworld': FlowWorld, 51 | 'teamworld2': TeamWorld2 52 | } 53 | env = envs[checkpoint_config['env']](checkpoint_config['env_config']) 54 | env.seed(seed) 55 | obs = env.reset() 56 | 57 | samples = [] 58 | model = trainer.get_policy().model 59 | 60 | cnn_outputs = [] 61 | def record_cnn_output(module, input_, output): 62 | cnn_outputs.append(output[0].detach().cpu().numpy()) 63 | gnn_outputs = [] 64 | def record_gnn_output(module, input_, output): 65 | gnn_outputs.append(output[0].detach().cpu().numpy()) 66 | #model.coop_convs[-1].register_forward_hook(record_cnn_output) 67 | #model.greedy_convs[-1].register_forward_hook(record_cnn_output) 68 | model.GFL.register_forward_hook(record_gnn_output) 69 | 70 | while len(samples) < sample_iterations: 71 | actions = trainer.compute_action(obs) 72 | for j in range(1, sum(checkpoint_config['env_config']['n_agents'])): 73 | #obs['agents'][j]['cnn_out'] = cnn_outputs[j] 74 | z, mu, log = 
model.coop_vaegp.vae.encode(torch.from_numpy(np.array([obs['agents'][j]['map']])).float().permute(0,3,1,2)) 75 | obs['agents'][j]['cnn_out'] = z[0].detach() 76 | obs['agents'][j]['gnn_out'] = gnn_outputs[0][..., j] 77 | cnn_outputs = [] 78 | gnn_outputs = [] 79 | 80 | if np.random.rand() <= frame_take_prob: 81 | samples.append(copy.deepcopy({'obs': obs, 'actions': actions})) 82 | print(len(samples)) 83 | 84 | obs, reward, done, info = env.step(actions) 85 | if (termination_mode == 'path' and done) or (termination_mode == 'cov' and env.timestep == int(t_fac*(info['coverable_area']/checkpoint_config['env_config']['n_agents'][1]))): 86 | obs = env.reset() 87 | return samples 88 | 89 | def run(seed, checkpoint_path, samples, workers, generated_path, termination_mode, frame_take_prob=0.2, disable_adv_comm=False, t_fac=1.5): 90 | results = [] 91 | with Pool(workers) as p: 92 | for res in p.starmap(generate, [(seed+i, checkpoint_path, int(samples/workers), termination_mode, frame_take_prob, disable_adv_comm, False, t_fac) for i in range(workers)]): # positional order must match generate(); False is ensure_conn 93 | results += res 94 | print("DONE", len(results)) 95 | pickle.dump(results, open(generated_path, "wb")) 96 | 97 | if __name__ == "__main__": 98 | ray.init() 99 | #ModelCatalog.register_custom_model("vis_torch_adv_team", AdversarialTeamModel) 100 | #ModelCatalog.register_custom_model("vis_torch_adv_team_2", AdversarialTeamModel2) 101 | ModelCatalog.register_custom_model("vis_torch_adv_team_2_vaegp", AdversarialTeamModel2VAEGP) 102 | 103 | register_env("teamworld2", lambda config: TeamWorld2(config)) 104 | register_env("flowworld", lambda config: FlowWorld(config)) 105 | 106 | # cooperative trainings 107 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_teamworld2_0_2020-07-19_01-15-57_hu8xcpq/checkpoint_1560" # coverage 108 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_teamworld2_0_2020-07-18_23-44-12k2_enqa8/checkpoint_150" # split 109 | #checkpoint_path =
"/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_flowworld_0_2020-07-16_00-47-53k6vmhzpl/checkpoint_1300" # flow 7x7 110 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_flowworld_0_2020-07-16_10-53-29pe06c7bw/checkpoint_3100" # flow 24x24 111 | 112 | # adversarial 113 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_teamworld2_0_2020-07-20_23-31-52zj__fmp3/checkpoint_4600" # cov 114 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_teamworld2_0_2020-07-19_09-34-02u_h77o5y/checkpoint_1400" # split 115 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_flowworld_0_2020-07-16_10-59-27c8iboc7_/checkpoint_3800" # flow 7x7 116 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_flowworld_0_2020-07-24_00-56-5896e2idut/checkpoint_8100" # flow 24x24 117 | 118 | #re-adapt 119 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_teamworld2_0_2020-07-27_00-38-42vpz3xf0k/checkpoint_5690" # cov 120 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_teamworld2_0_2020-07-27_00-48-12zecm5uk7/checkpoint_2190" # split 121 | 122 | #checkpoint_path = "/local/scratch/jb2270/vaegp_eval/MultiPPO/MultiPPO_teamworld2_0_2020-08-23_11-27-05u14jlcjb/checkpoint_1560" # simple 123 | 124 | #checkpoint_path = "/local/scratch/jb2270/vaegp_eval/MultiPPO/MultiPPO_teamworld2_0_2020-08-24_20-43-40mnea2uga/checkpoint_1560" # train with frozen VAE 125 | checkpoint_path = "/local/scratch/jb2270/vaegp_eval/MultiPPO/MultiPPO_teamworld2_c8d29_00000/checkpoint_410" 126 | 127 | termination_mode = "cov" # cov/path 128 | 129 | checkpoint_num = checkpoint_path.split("_")[-1] 130 | checkpoint_id = checkpoint_path.split("/")[-2].split("-")[-1] 131 | #generate(0, checkpoint_path, 1000, 0.1, 1.5) 132 | #exit() 133 | run(0, checkpoint_path, 50000, 32, 
f"/local/scratch/jb2270/datasets_corl/explainability_data_{checkpoint_id}_{checkpoint_num}_train.pkl",termination_mode, disable_adv_comm=True) 134 | run(1, checkpoint_path, 10000, 32, f"/local/scratch/jb2270/datasets_corl/explainability_data_{checkpoint_id}_{checkpoint_num}_valid.pkl", termination_mode, disable_adv_comm=True) 135 | run(2, checkpoint_path, 1000, 1, f"/local/scratch/jb2270/datasets_corl/explainability_data_{checkpoint_id}_{checkpoint_num}_test.pkl", termination_mode, disable_adv_comm=True, frame_take_prob=1.0, t_fac=4) 136 | #run(2, checkpoint_path, 1000, 1, f"/local/scratch/jb2270/datasets_corl/explainability_data_{checkpoint_id}_{checkpoint_num}_nocomm_test.pkl", termination_mode, disable_adv_comm=True, frame_take_prob=1.0) 137 | 138 | -------------------------------------------------------------------------------- /adversarial_comms/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proroklab/adversarial_comms/889a41e239958bd519365c6b0c143469cd27e49f/adversarial_comms/models/__init__.py -------------------------------------------------------------------------------- /adversarial_comms/models/adversarial.py: -------------------------------------------------------------------------------- 1 | from ray.rllib.models.modelv2 import ModelV2 2 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 3 | from ray.rllib.policy.rnn_sequencing import add_time_dimension 4 | from ray.rllib.models.torch.misc import normc_initializer, same_padding, SlimConv2d, SlimFC 5 | from ray.rllib.utils.annotations import override 6 | from ray.rllib.utils import try_import_torch 7 | 8 | from .gnn import adversarialGraphML as gml_adv 9 | from .gnn import graphML as gml 10 | from .gnn import graphTools 11 | import numpy as np 12 | import copy 13 | 14 | torch, nn = try_import_torch() 15 | from torchsummary import summary 16 | 17 | # 
https://ray.readthedocs.io/en/latest/using-ray-with-pytorch.html 18 | 19 | DEFAULT_OPTIONS = { 20 | "activation": "relu", 21 | "agent_split": 1, 22 | "cnn_compression": 512, 23 | "cnn_filters": [[32, [8, 8], 4], [64, [4, 4], 2], [128, [4, 4], 2]], 24 | "cnn_residual": False, 25 | "freeze_coop": True, 26 | "freeze_coop_value": False, 27 | "freeze_greedy": False, 28 | "freeze_greedy_value": False, 29 | "graph_edge_features": 1, 30 | "graph_features": 512, 31 | "graph_layers": 1, 32 | "graph_tabs": 3, 33 | "relative": True, 34 | "value_cnn_compression": 512, 35 | "value_cnn_filters": [[32, [8, 8], 2], [64, [4, 4], 2], [128, [4, 4], 2]], 36 | "forward_values": True 37 | } 38 | 39 | class AdversarialModel(TorchModelV2, nn.Module): 40 | def __init__(self, obs_space, action_space, num_outputs, model_config, name):#, 41 | #graph_layers, graph_features, graph_tabs, graph_edge_features, cnn_filters, value_cnn_filters, value_cnn_compression, cnn_compression, relative, activation): 42 | TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name) 43 | nn.Module.__init__(self) 44 | 45 | self.cfg = copy.deepcopy(DEFAULT_OPTIONS) 46 | self.cfg.update(model_config['custom_model_config']) 47 | 48 | #self.cfg = model_config['custom_options'] 49 | self.n_agents = len(obs_space.original_space['agents']) 50 | self.graph_features = self.cfg['graph_features'] 51 | self.cnn_compression = self.cfg['cnn_compression'] 52 | self.activation = { 53 | 'relu': nn.ReLU, 54 | 'leakyrelu': nn.LeakyReLU 55 | }[self.cfg['activation']] 56 | 57 | layers = [] 58 | input_shape = obs_space.original_space['agents'][0]['map'].shape 59 | (w, h, in_channels) = input_shape 60 | 61 | in_size = [w, h] 62 | for out_channels, kernel, stride in self.cfg['cnn_filters'][:-1]: 63 | padding, out_size = same_padding(in_size, kernel, [stride, stride]) 64 | layers.append(SlimConv2d(in_channels, out_channels, kernel, stride, padding, activation_fn=self.activation)) 65 | in_channels = out_channels 66 
| in_size = out_size 67 | 68 | out_channels, kernel, stride = self.cfg['cnn_filters'][-1] 69 | layers.append( 70 | SlimConv2d(in_channels, out_channels, kernel, stride, None)) 71 | layers.append(nn.Flatten(1, -1)) 72 | #if isinstance(cnn_compression, int): 73 | # layers.append(nn.Linear(cnn_compression, self.cfg['graph_features']-2)) # reserve 2 for pos 74 | # layers.append(self.activation{)) 75 | self.coop_convs = nn.Sequential(*layers) 76 | self.greedy_convs = copy.deepcopy(self.coop_convs) 77 | 78 | self.coop_value_obs_convs = copy.deepcopy(self.coop_convs) 79 | self.greedy_value_obs_convs = copy.deepcopy(self.coop_convs) 80 | 81 | summary(self.coop_convs, device="cpu", input_size=(input_shape[2], input_shape[0], input_shape[1])) 82 | 83 | gfl = [] 84 | for i in range(self.cfg['graph_layers']): 85 | gfl.append(gml_adv.GraphFilterBatchGSOA(self.graph_features, self.graph_features, self.cfg['graph_tabs'], self.cfg['agent_split'], self.cfg['graph_edge_features'], False)) 86 | #gfl.append(gml.GraphFilterBatchGSO(self.graph_features, self.graph_features, self.cfg['graph_tabs'], self.cfg['graph_edge_features'], False)) 87 | gfl.append(self.activation()) 88 | 89 | self.GFL = nn.Sequential(*gfl) 90 | 91 | #gso_sum = torch.zeros(2, 1, 8, 8) 92 | #self.GFL[0].addGSO(gso_sum) 93 | #summary(self.GFL, device="cuda" if torch.cuda.is_available() else "cpu", input_size=(self.graph_features, 8)) 94 | 95 | logits_inp_features = self.graph_features 96 | if self.cfg['cnn_residual']: 97 | logits_inp_features += self.cnn_compression 98 | 99 | post_logits = [ 100 | nn.Linear(logits_inp_features, 64), 101 | self.activation(), 102 | nn.Linear(64, 32), 103 | self.activation() 104 | ] 105 | logit_linear = nn.Linear(32, 5) 106 | nn.init.xavier_uniform_(logit_linear.weight) 107 | nn.init.constant_(logit_linear.bias, 0) 108 | post_logits.append(logit_linear) 109 | self.coop_logits = nn.Sequential(*post_logits) 110 | self.greedy_logits = copy.deepcopy(self.coop_logits) 111 | 
summary(self.coop_logits, device="cpu", input_size=(logits_inp_features,)) 112 | 113 | ############################## 114 | 115 | layers = [] 116 | input_shape = np.array(obs_space.original_space['state'].shape) 117 | (w, h, in_channels) = input_shape 118 | 119 | in_size = [w, h] 120 | for out_channels, kernel, stride in self.cfg['value_cnn_filters'][:-1]: 121 | padding, out_size = same_padding(in_size, kernel, [stride, stride]) 122 | layers.append(SlimConv2d(in_channels, out_channels, kernel, stride, padding, activation_fn=self.activation)) 123 | in_channels = out_channels 124 | in_size = out_size 125 | 126 | out_channels, kernel, stride = self.cfg['value_cnn_filters'][-1] 127 | layers.append( 128 | SlimConv2d(in_channels, out_channels, kernel, stride, None)) 129 | layers.append(nn.Flatten(1, -1)) 130 | 131 | self.coop_value_cnn = nn.Sequential(*layers) 132 | self.greedy_value_cnn = copy.deepcopy(self.coop_value_cnn) 133 | summary(self.greedy_value_cnn, device="cpu", input_size=(input_shape[2], input_shape[0], input_shape[1])) 134 | 135 | layers = [ 136 | nn.Linear(self.cnn_compression + self.cfg['value_cnn_compression'], 64), 137 | self.activation(), 138 | nn.Linear(64, 32), 139 | self.activation() 140 | ] 141 | values_linear = nn.Linear(32, 1) 142 | normc_initializer()(values_linear.weight) 143 | nn.init.constant_(values_linear.bias, 0) 144 | layers.append(values_linear) 145 | 146 | self.coop_value_branch = nn.Sequential(*layers) 147 | self.greedy_value_branch = copy.deepcopy(self.coop_value_branch) 148 | summary(self.coop_value_branch, device="cpu", input_size=(self.cnn_compression + self.cfg['value_cnn_compression'],)) 149 | 150 | self._cur_value = None 151 | 152 | self.freeze_coop_value(self.cfg['freeze_coop_value']) 153 | self.freeze_greedy_value(self.cfg['freeze_greedy_value']) 154 | self.freeze_coop(self.cfg['freeze_coop']) 155 | self.freeze_greedy(self.cfg['freeze_greedy']) 156 | 157 | def freeze_coop(self, freeze): 158 | all_params = \ 159 | 
list(self.coop_convs.parameters()) + \ 160 | [self.GFL[0].weight1] + \ 161 | list(self.coop_logits.parameters()) 162 | 163 | for param in all_params: 164 | param.requires_grad = not freeze 165 | 166 | def freeze_greedy(self, freeze): 167 | all_params = \ 168 | list(self.greedy_logits.parameters()) + \ 169 | list(self.greedy_convs.parameters()) + \ 170 | [self.GFL[0].weight0] 171 | 172 | for param in all_params: 173 | param.requires_grad = not freeze 174 | 175 | def freeze_greedy_value(self, freeze): 176 | all_params = \ 177 | list(self.greedy_value_branch.parameters()) + \ 178 | list(self.greedy_value_cnn.parameters()) + \ 179 | list(self.greedy_value_obs_convs.parameters()) 180 | 181 | for param in all_params: 182 | param.requires_grad = not freeze 183 | 184 | def freeze_coop_value(self, freeze): 185 | all_params = \ 186 | list(self.coop_value_cnn.parameters()) + \ 187 | list(self.coop_value_branch.parameters()) + \ 188 | list(self.coop_value_obs_convs.parameters()) 189 | 190 | for param in all_params: 191 | param.requires_grad = not freeze 192 | 193 | @override(ModelV2) 194 | def forward(self, input_dict, state, seq_lens): 195 | batch_size = input_dict["obs"]['gso'].shape[0] 196 | o_as = input_dict["obs"]['agents'] 197 | 198 | gso = input_dict["obs"]['gso'].unsqueeze(1) 199 | device = gso.device 200 | 201 | for i in range(len(self.GFL)//2): 202 | self.GFL[i*2].addGSO(gso) 203 | 204 | greedy_cnn = self.greedy_convs(o_as[0]['map'].permute(0, 3, 1, 2)) 205 | coop_agents_cnn = {id_agent: self.coop_convs(o_as[id_agent]['map'].permute(0, 3, 1, 2)) for id_agent in range(1, len(o_as))} 206 | 207 | greedy_value_obs_cnn = self.greedy_value_obs_convs(o_as[0]['map'].permute(0, 3, 1, 2)) 208 | coop_value_obs_cnn = {id_agent: self.coop_value_obs_convs(o_as[id_agent]['map'].permute(0, 3, 1, 2)) for id_agent in range(1, len(o_as))} 209 | 210 | extract_feature_map = torch.zeros(batch_size, self.graph_features, self.n_agents).to(device) 211 | extract_feature_map[:, :self.cnn_compression, 0] = greedy_cnn 
212 | for id_agent in range(1, len(o_as)): 213 | extract_feature_map[:, :self.cnn_compression, id_agent] = coop_agents_cnn[id_agent] 214 | 215 | shared_feature = self.GFL(extract_feature_map) 216 | 217 | logits = torch.empty(batch_size, self.n_agents, 5).to(device) 218 | values = torch.empty(batch_size, self.n_agents).to(device) 219 | 220 | logits_inp = shared_feature[..., 0] 221 | if self.cfg['cnn_residual']: 222 | logits_inp = torch.cat([logits_inp, greedy_cnn], dim=1) 223 | logits[:, 0] = self.greedy_logits(logits_inp) 224 | if self.cfg['forward_values']: 225 | greedy_value_cnn = self.greedy_value_cnn(input_dict["obs"]["state"].permute(0, 3, 1, 2)) 226 | coop_value_cnn = self.coop_value_cnn(input_dict["obs"]["state"].permute(0, 3, 1, 2)) 227 | 228 | values[:, 0] = self.greedy_value_branch(torch.cat([greedy_value_obs_cnn, greedy_value_cnn], dim=1)).squeeze(1) 229 | 230 | for id_agent in range(1, len(o_as)): 231 | this_entity = shared_feature[..., id_agent] 232 | if self.cfg['cnn_residual']: 233 | this_entity = torch.cat([this_entity, coop_agents_cnn[id_agent]], dim=1) 234 | logits[:, id_agent] = self.coop_logits(this_entity) 235 | 236 | if self.cfg['forward_values']: 237 | value_cat = torch.cat([coop_value_cnn, coop_value_obs_cnn[id_agent]], dim=1) 238 | values[:, id_agent] = self.coop_value_branch(value_cat).squeeze(1) 239 | 240 | self._cur_value = values 241 | return logits.view(batch_size, self.n_agents*5), state 242 | 243 | @override(ModelV2) 244 | def value_function(self): 245 | assert self._cur_value is not None, "must call forward() first" 246 | return self._cur_value 247 | 248 | -------------------------------------------------------------------------------- /adversarial_comms/models/gnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proroklab/adversarial_comms/889a41e239958bd519365c6b0c143469cd27e49f/adversarial_comms/models/gnn/__init__.py 
-------------------------------------------------------------------------------- /adversarial_comms/models/gnn/adversarialGraphML.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .graphML import GraphFilterBatchGSO 7 | 8 | def batchLSIGFA(h0, h1, N0, SK, x, bias=None, aggregation=lambda y, dim: torch.sum(y, dim=dim)): 9 | """ 10 | batchLSIGF(filter_taps, GSO_K, input, bias=None) Computes the output of a 11 | linear shift-invariant graph filter on input and then adds bias. 12 | 13 | In this case, we consider that there is a separate GSO to be used for each 14 | of the signals in the batch. In other words, SK[b] is applied when filtering 15 | x[b] as opposed to applying the same SK to all the graph signals in the 16 | batch. 17 | 18 | Inputs: 19 | filter_taps: vector of filter taps; size: 20 | output_features x edge_features x filter_taps x input_features 21 | GSO_K: collection of matrices; size: 22 | batch_size x edge_features x filter_taps x number_nodes x number_nodes 23 | input: input signal; size: 24 | batch_size x input_features x number_nodes 25 | bias: size: output_features x number_nodes 26 | if the same bias is to be applied to all nodes, set number_nodes = 1 27 | so that b_{f} vector becomes b_{f} \mathbf{1}_{N} 28 | 29 | Outputs: 30 | output: filtered signals; size: 31 | batch_size x output_features x number_nodes 32 | """ 33 | # Get the parameter numbers: 34 | assert h0.shape == h1.shape 35 | F = h0.shape[0] 36 | E = h0.shape[1] 37 | K = h0.shape[2] 38 | G = h0.shape[3] 39 | B = SK.shape[0] 40 | assert SK.shape[1] == E 41 | assert SK.shape[2] == K 42 | N = SK.shape[3] 43 | assert SK.shape[4] == N 44 | assert x.shape[0] == B 45 | assert x.shape[1] == G 46 | assert x.shape[2] == N 47 | # Or, in the notation I've been using: 48 | # h in F x E x K x G 49 | # SK in B x E x K x N x N 50 | # x in B x G x N 51 | # b in F x N 52 | # y 
in B x F x N 53 | SK = SK.permute(1, 2, 0, 3, 4) 54 | # Now, SK is of shape E x K x B x N x N so that we can multiply by x of 55 | # size B x G x N to get 56 | z = torch.matmul(x, SK) 57 | # which is of size E x K x B x G x N 58 | # Now, we have already carried out the multiplication across the dimension 59 | # of the nodes. Now we need to focus on the K, F, G. 60 | # Let's start by putting K, B and N in the front 61 | z = z.permute(1, 2, 4, 0, 3).reshape([K, B, N, E * G]) 62 | # so that we get z in K x B x N x EG. 63 | # Now adjust the filter taps so they are of the form K x EG x F 64 | h0 = h0.permute(2, 1, 3, 0).reshape([K, G * E, F]) 65 | h1 = h1.permute(2, 1, 3, 0).reshape([K, G * E, F]) 66 | #h1 = h1.reshape([F, G * E * K]).permute(1, 0) 67 | # Multiply 68 | if N0 == 0: 69 | y = torch.empty(K, B, N, F).to(z.device) 70 | for k in range(K): 71 | y[k] = torch.matmul(z[k], h1[k]) 72 | y = aggregation(y, 0) 73 | # to get a result of size B x N x F. And permute 74 | y = y.permute(0, 2, 1) 75 | else: 76 | z0 = z[:, :, :N0] 77 | z1 = z[:, :, N0:] 78 | y0 = torch.empty(K, B, N0, F).to(z.device) 79 | y1 = torch.empty(K, B, N-N0, F).to(z.device) 80 | for k in range(K): 81 | y0[k] = torch.matmul(z0[k], h0[k]) 82 | y1[k] = torch.matmul(z1[k], h1[k]) 83 | y0 = aggregation(y0, 0) 84 | y1 = aggregation(y1, 0) 85 | # to get a result of size B x N x F. And permute 86 | y0 = y0.permute(0, 2, 1) 87 | y1 = y1.permute(0, 2, 1) 88 | y = torch.cat([y0, y1], dim = 2) # concat along N 89 | # to get it back in the right order: B x F x N. 
90 | # Now, in this case, each element x[b,:,:] has adequately been filtered by 91 | # the GSO S[b,:,:,:] 92 | if bias is not None: 93 | y = y + bias 94 | return y 95 | 96 | class GraphFilterBatchGSOA(GraphFilterBatchGSO): 97 | def __init__(self, G, F, K, N0, E = 1, bias = True, aggregation='sum'): 98 | super().__init__(G, F, K, E, bias) 99 | self.weight0 = self.weight 100 | self.weight1 = nn.parameter.Parameter(torch.Tensor(self.F, self.E, self.K, self.G)) 101 | self.N0 = N0 102 | self.reset_parameters() 103 | self.aggregation = { 104 | "sum": lambda y, dim: torch.sum(y, dim=dim), 105 | "median": lambda y, dim: torch.median(y, dim=dim)[0], 106 | "min": lambda y, dim: torch.min(y, dim=dim)[0] 107 | }[aggregation] 108 | 109 | def reset_parameters(self): 110 | super().reset_parameters() 111 | if hasattr(self, 'weight1'): 112 | stdv = 1. / math.sqrt(self.G * self.K) 113 | self.weight1.data.uniform_(-stdv, stdv) 114 | 115 | def forward(self, x): 116 | return self.forward_gpvae(x) if self.K == 2 else batchLSIGFA(self.weight0, self.weight1, self.N0, self.SK, x, self.bias, aggregation=self.aggregation) 117 | 118 | def forward_gpvae(self, x): 119 | # K=1 120 | hx_0_0 = torch.matmul(self.weight0[:, 0, 0, :], x[:, :, :self.N0]) 121 | hx_0_1 = torch.matmul(self.weight1[:, 0, 0, :], x[:, :, self.N0:]) 122 | hx_0 = torch.cat([hx_0_0, hx_0_1], dim=2) 123 | 124 | # K=2 125 | neighbors = self.aggregation(x[:, :, :, None] * self.S, dim=2) 126 | hx_1_0 = torch.matmul(self.weight0[:, 0, 1, :], neighbors[:, :, :self.N0]) 127 | hx_1_1 = torch.matmul(self.weight1[:, 0, 1, :], neighbors[:, :, self.N0:]) 128 | hx_1 = torch.cat([hx_1_0, hx_1_1], dim=2) 129 | 130 | output = hx_0 + hx_1 131 | return output 132 | 133 | def forward_naive(self, x): 134 | bs, features, n_agents = x.shape 135 | output = torch.zeros(bs, features, n_agents) 136 | for b in range(bs): 137 | sxas = torch.zeros(self.K, features, n_agents) 138 | sk = torch.eye(n_agents).expand(n_agents, n_agents) 139 | for k in 
range(self.K): 140 | sx = torch.matmul(x[b], sk) 141 | h0 = self.weight0[:, 0, k, :] 142 | h1 = self.weight1[:, 0, k, :] 143 | if self.N0 == 0: 144 | sxas[k] = torch.matmul(h1, sx) 145 | else: 146 | sxa0 = torch.matmul(h0, sx[:, :self.N0]) 147 | sxa1 = torch.matmul(h1, sx[:, self.N0:]) 148 | sxas[k] = torch.cat([sxa0, sxa1], dim=1) # concat along N 149 | sk = torch.matmul(self.S[b, 0], sk) 150 | 151 | output[b] = self.aggregation(sxas, 0) 152 | return output 153 | -------------------------------------------------------------------------------- /adversarial_comms/train_interpreter.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | 4 | from ray.rllib.models.modelv2 import ModelV2 5 | from ray.rllib.utils.annotations import override 6 | from ray.rllib.utils import try_import_torch 7 | 8 | import utils.graphML as gml 9 | import utils.graphTools 10 | import numpy as np 11 | 12 | torch, nn = try_import_torch() 13 | from torch.utils.data import Dataset 14 | 15 | from torch.optim import SGD, Adam 16 | import torch.nn.functional as F 17 | 18 | from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator 19 | from ignite.metrics import Precision, Recall, Fbeta, Loss, RunningAverage 20 | #from ignite.contrib.metrics import ROC_AUC, AveragePrecision 21 | from ignite.handlers import ModelCheckpoint, global_step_from_engine, EarlyStopping, TerminateOnNan 22 | from ignite.contrib.handlers import ProgressBar 23 | from ignite.contrib.handlers.tensorboard_logger import * 24 | 25 | from torchsummary import summary 26 | 27 | from collections import OrderedDict 28 | import matplotlib.pyplot as plt 29 | import matplotlib.patches as patches 30 | from matplotlib import colors 31 | import json 32 | import random 33 | from pathlib import Path 34 | import time 35 | import os 36 | import copy 37 | 38 | # https://ray.readthedocs.io/en/latest/using-ray-with-pytorch.html 39 | 40 | X = 1 41 | Y = 
0 42 | 43 | def get_transpose_cnn(inp_features, out_shape, out_classes): 44 | return [ 45 | nn.ConvTranspose2d(in_channels=inp_features, out_channels=64, kernel_size=3, stride=1), 46 | nn.LeakyReLU(inplace=True), 47 | nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2), 48 | nn.LeakyReLU(inplace=True), 49 | nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=4, stride=2), 50 | nn.LeakyReLU(inplace=True), 51 | nn.ZeroPad2d([1,1,1,1]), 52 | nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=4, stride=2), 53 | nn.LeakyReLU(inplace=True), 54 | nn.ZeroPad2d([1,1,1,1]), 55 | nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3, stride=1), 56 | nn.LeakyReLU(inplace=True), 57 | nn.Conv2d(8, out_classes, 3, 1), 58 | nn.Sigmoid(), 59 | ] 60 | 61 | def get_upsampling_cnn(inp_features, out_shape, out_classes): 62 | if out_shape == 7: 63 | return [ 64 | nn.ZeroPad2d([2]*4), 65 | nn.Conv2d(in_channels=inp_features, out_channels=16, kernel_size=3), 66 | nn.LeakyReLU(inplace=True), 67 | nn.Upsample(scale_factor=2), 68 | nn.ZeroPad2d([1]*4), 69 | nn.Conv2d(in_channels=16, out_channels=8, kernel_size=4), 70 | nn.LeakyReLU(inplace=True), 71 | nn.Upsample(scale_factor=2), 72 | nn.Conv2d(in_channels=8, out_channels=out_classes, kernel_size=4), 73 | nn.Sigmoid(), 74 | ] 75 | elif out_shape == 12: 76 | return [ 77 | nn.ZeroPad2d([2]*4), 78 | nn.Conv2d(in_channels=inp_features, out_channels=16, kernel_size=3), 79 | nn.LeakyReLU(inplace=True), 80 | nn.Upsample(scale_factor=2), 81 | nn.ZeroPad2d([1]*4), 82 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=4), 83 | nn.LeakyReLU(inplace=True), 84 | nn.Upsample(scale_factor=2), 85 | nn.Conv2d(in_channels=16, out_channels=8, kernel_size=4), 86 | nn.LeakyReLU(inplace=True), 87 | nn.Upsample(scale_factor=2), 88 | nn.Conv2d(in_channels=8, out_channels=out_classes, kernel_size=3), 89 | nn.Sigmoid(), 90 | ] 91 | elif out_shape == 24: 92 | return [ 93 | nn.ZeroPad2d([2]*4), 94 | 
nn.Conv2d(in_channels=inp_features, out_channels=32, kernel_size=3), 95 | nn.LeakyReLU(inplace=True), 96 | nn.Upsample(scale_factor=2), 97 | nn.ZeroPad2d([1]*4), 98 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3), 99 | nn.LeakyReLU(inplace=True), 100 | nn.Upsample(scale_factor=2), 101 | nn.ZeroPad2d([1]*4), 102 | nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3), 103 | nn.LeakyReLU(inplace=True), 104 | nn.Upsample(scale_factor=2), 105 | nn.ZeroPad2d([1]*4), 106 | nn.Conv2d(in_channels=8, out_channels=out_classes, kernel_size=3), 107 | nn.Sigmoid(), 108 | ] 109 | elif out_shape == 48: 110 | return [ 111 | nn.ZeroPad2d([2]*4), 112 | nn.Conv2d(in_channels=inp_features, out_channels=64, kernel_size=3), 113 | nn.LeakyReLU(inplace=True), 114 | nn.Upsample(scale_factor=2), 115 | nn.ZeroPad2d([1]*4), 116 | nn.Conv2d(in_channels=64, out_channels=48, kernel_size=3), 117 | nn.LeakyReLU(inplace=True), 118 | nn.Upsample(scale_factor=2), 119 | nn.ZeroPad2d([1]*4), 120 | nn.Conv2d(in_channels=48, out_channels=32, kernel_size=3), 121 | nn.LeakyReLU(inplace=True), 122 | nn.Upsample(scale_factor=2), 123 | nn.ZeroPad2d([1]*4), 124 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3), 125 | nn.LeakyReLU(inplace=True), 126 | nn.Upsample(scale_factor=2), 127 | nn.ZeroPad2d([1]*4), 128 | nn.Conv2d(in_channels=16, out_channels=out_classes, kernel_size=3), 129 | nn.Sigmoid() 130 | ] 131 | assert False 132 | 133 | class Model(nn.Module): 134 | def __init__(self, dataset, config): 135 | nn.Module.__init__(self) 136 | 137 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 138 | self.config = config 139 | self.inp_features = { 140 | 'gnn': dataset.gnn_features, 141 | 'cnn': dataset.cnn_features, 142 | 'gnn_cnn': dataset.cnn_features+dataset.gnn_features 143 | }[self.config['nn_mode']] 144 | 145 | if self.config['format']=='relative': 146 | if self.config['pred_mode'] == 'global': 147 | self.out_size = dataset.world_shape[0]*2 148 | elif 
self.config['pred_mode'] == 'local': 149 | self.out_size = dataset.obs_size 150 | else: 151 | raise NotImplementedError 152 | elif self.config['format']=='absolute': 153 | self.out_size = dataset.world_shape[0] 154 | else: 155 | raise NotImplementedError 156 | 157 | if self.config['type'] == 'cov': 158 | self.classes = 1 if self.config['prediction']=='cov_only' else 2 159 | elif self.config['type'] == 'path': 160 | self.classes = 3 if self.config['prediction']=='all' else 1 161 | else: 162 | raise NotImplementedError("Invalid type") 163 | 164 | layers = get_upsampling_cnn(self.inp_features, self.out_size, self.classes) 165 | cnn = nn.Sequential(*layers) 166 | #summary(cnn, device="cpu", input_size=(self.inp_features, 1, 1)) 167 | self._post_cnn = cnn.to(self.device) 168 | 169 | @override(ModelV2) 170 | def forward(self, input_dict): 171 | agent_observations = input_dict["obs"]['agents'] 172 | batch_size = input_dict["obs"]['gso'].shape[0] 173 | 174 | prediction = torch.empty(batch_size, len(agent_observations), self.classes, self.out_size, self.out_size).to( 175 | self.device) 176 | for this_id, this_state in enumerate(agent_observations): 177 | gnn_out = this_state['gnn_out'] 178 | cnn_out = this_state['cnn_out'] 179 | this_entity = { 180 | 'gnn': gnn_out, 181 | 'cnn': cnn_out, 182 | 'gnn_cnn': torch.cat([gnn_out, cnn_out], dim=1) 183 | }[self.config['nn_mode']] 184 | prediction[:, this_id] = self._post_cnn(this_entity.view(batch_size, self.inp_features, 1, 1)) 185 | 186 | return prediction.double() 187 | 188 | class BaseDataset(Dataset): 189 | def __init__(self, path): 190 | try: 191 | with open(Path(path), "rb") as f: 192 | self.data = pickle.load(f) 193 | assert (len(self.data) > 0) 194 | except TypeError: 195 | self.data = [{'obs': path}] 196 | self.world_shape = self.data[0]['obs']['state'].shape[:2] 197 | self.obs_size = self.data[0]['obs']['agents'][0]['map'].shape[0] 198 | self.cnn_features = self.data[0]['obs']['agents'][-1]['cnn_out'].shape[0] 199 | 
self.gnn_features = self.data[0]['obs']['agents'][-1]['gnn_out'].shape[0] 200 | 201 | def __len__(self): 202 | return len(self.data) 203 | 204 | def get_coverable_area(self, idx): 205 | coverable_area = ~(self.data[idx]['obs']['state'][...,0] > 0) 206 | return np.sum(coverable_area) 207 | 208 | def get_coverage_fraction(self, idx): 209 | coverable_area = ~(self.data[idx]['obs']['state'][...,0] > 0) 210 | covered_area = self.data[idx]['obs']['state'][...,1] & coverable_area 211 | return np.sum(covered_area) / np.sum(coverable_area) 212 | 213 | def to_agent_coord_frame(self, m, state_size, pose, fill=0): 214 | half_state_shape = np.array([state_size / 2] * 2, dtype=int) 215 | padded = np.pad(m, ([half_state_shape[Y]] * 2, [half_state_shape[X]] * 2), mode='constant', 216 | constant_values=fill) 217 | return padded[pose[Y]:pose[Y] + state_size, pose[X]:pose[X] + state_size] 218 | 219 | class CoverageDataset(BaseDataset): 220 | def __init__(self, path, config): 221 | # is_relative: Agent relative or world absolute prediction 222 | # cov_only: Predict only coverage or predict both coverage and map 223 | # is_global: Predict local coverage and map or global coverage and map 224 | 225 | super().__init__(path) 226 | self.is_relative=config['format']=='relative' 227 | self.cov_only=config['prediction']=='cov_only' 228 | self.is_global=config['pred_mode']=='global' 229 | self.skip_agents=config['skip_agents'] if 'skip_agents' in config else 0 230 | self.stop_agents=config['stop_agents'] if 'stop_agents' in config else None 231 | 232 | def __getitem__(self, idx): 233 | if torch.is_tensor(idx): 234 | idx = idx.tolist() 235 | 236 | y = [] 237 | weights = [] 238 | for agent_obs in self.data[idx]['obs']['agents'][self.skip_agents:self.stop_agents]: 239 | if self.is_global: 240 | obs_cov = self.data[idx]['obs']['state'][...,1] 241 | obs_map = self.data[idx]['obs']['state'][...,0] 242 | if self.is_relative: 243 | obs_cov = self.to_agent_coord_frame(obs_cov, self.obs_size, 
agent_obs['pos'], fill=0) 244 | obs_map = self.to_agent_coord_frame(obs_map, self.obs_size, agent_obs['pos'], fill=1) 245 | else: 246 | if self.is_relative: 247 | # directly use agent's relative view coverage 248 | obs_cov = agent_obs['map'][..., 1] 249 | obs_map = agent_obs['map'][..., 0] 250 | else: 251 | # shift the agent's local coverage to an absolute view 252 | m = np.roll(agent_obs['map'], agent_obs['pos'], axis=(0,1))[int(self.obs_size/2):,int(self.obs_size/2):] 253 | obs_cov = m[...,1] 254 | obs_map = m[...,0] 255 | 256 | if self.cov_only: 257 | # only predict local coverage and use world map as mask 258 | y.append([obs_cov]) 259 | weights.append([(~obs_map.astype(bool)).astype(int)]) 260 | else: 261 | # predict both local coverage and world map, but mask out everything outside the world shifted to the agent's position 262 | d = np.stack([obs_cov, obs_map], axis=0) 263 | y.append(d) 264 | weight = np.ones(obs_cov.shape) 265 | if self.is_relative: 266 | weight = self.to_agent_coord_frame(weight, self.obs_size, agent_obs['pos'], fill=0) 267 | weight = np.stack([weight]*2, axis=0) 268 | #print(d.shape, weight.shape) 269 | weights.append(weight) 270 | 271 | y = np.array(y, dtype=np.double) 272 | w = np.array(weights, dtype=np.double) 273 | obs = self.data[idx]['obs'] 274 | obs['agents'] = obs['agents'][self.skip_agents:self.stop_agents] 275 | return {'obs': self.data[idx]['obs']}, {'y': y, 'w': w} 276 | 277 | class PathplanningDataset(BaseDataset): 278 | def __init__(self, path, config): 279 | # cov_only: Predict only coverage or predict both coverage and map 280 | # is_global: Predict local coverage and map or global coverage and map 281 | 282 | super().__init__(path) 283 | self.is_relative=config['format']=='relative' 284 | self.pred_mode=config['prediction'] 285 | self.is_global=config['pred_mode']=='global' 286 | self.skip_agents=config['skip_agents'] if 'skip_agents' in config else 0 287 | self.stop_agents=config['stop_agents'] if 'stop_agents' in 
config else None 288 | 289 | def __getitem__(self, idx): 290 | if torch.is_tensor(idx): 291 | idx = idx.tolist() 292 | 293 | y = [] 294 | weights = [] 295 | for agent_obs in self.data[idx]['obs']['agents'][self.skip_agents:self.stop_agents]: 296 | if self.is_global: 297 | obs_map = np.zeros(self.data[idx]['obs']['state'].shape[:2], dtype=float) 298 | obs_pos = np.zeros(self.data[idx]['obs']['state'].shape[:2], dtype=float) 299 | obs_goal = np.zeros(self.data[idx]['obs']['state'].shape[:2], dtype=float) 300 | obs_map[self.data[idx]['obs']['state'][..., 0] == 1] = 1 301 | for i in range(self.data[idx]['obs']['state'].shape[-1]): 302 | obs_pos[self.data[idx]['obs']['state'][..., i] == 2] = 1 303 | obs_goal[self.data[idx]['obs']['state'][..., i] == 3] = 1 304 | 305 | if self.is_relative: 306 | obs_goal = self.to_agent_coord_frame(obs_goal, self.world_shape[0]*2, agent_obs['pos'], fill=0) 307 | obs_pos = self.to_agent_coord_frame(obs_pos, self.world_shape[0]*2, agent_obs['pos'], fill=0) 308 | obs_map = self.to_agent_coord_frame(obs_map, self.world_shape[0]*2, agent_obs['pos'], fill=1) 309 | else: 310 | # directly use the channels of the agent's local map 311 | obs_map = agent_obs['map'][..., 0] 312 | obs_goal = agent_obs['map'][..., 1] 313 | obs_pos = agent_obs['map'][..., 2] 314 | 315 | if self.pred_mode == "goal": 316 | # only predict the goal layer and mask out obstacle cells 317 | y.append(np.stack([obs_goal], axis=0)) 318 | weight = (~obs_map.astype(bool)).astype(int) 319 | # goal can generally be on the margin if it is projected! 
320 | for row in [0, -1]: 321 | weight[row] = 1 322 | weight[:, row] = 1 323 | weights.append(copy.deepcopy([weight])) 324 | elif self.pred_mode == "all": 325 | # predict both local coverage and world map, but mask out everything outside the world shifted to the agents position 326 | d = np.stack([obs_map, obs_goal, obs_pos], axis=0) 327 | y.append(d) 328 | weight = np.ones(obs_map.shape) 329 | weight = np.stack([weight]*3, axis=0) 330 | weights.append(weight) 331 | 332 | y = np.array(y, dtype=np.double) 333 | w = np.array(weights, dtype=np.double) 334 | obs = self.data[idx]['obs'] 335 | obs['agents'] = obs['agents'][self.skip_agents:self.stop_agents] 336 | return {'obs': self.data[idx]['obs']}, {'y': y, 'w': w} 337 | 338 | dataset_classes = { 339 | "path": PathplanningDataset, 340 | "cov": CoverageDataset 341 | } 342 | 343 | def inference(model_checkpoint_path, 344 | data_path, 345 | seed=None, run_eval=False, save_dirname=None): 346 | if seed is None: 347 | seed = time.time() 348 | torch.manual_seed(seed) 349 | random.seed(seed) 350 | batch_size = 1 351 | 352 | checkpoint_file = Path(model_checkpoint_path) 353 | with open(checkpoint_file.parent / 'config.json', 'r') as config_file: 354 | config = json.load(config_file) 355 | config['skip_agents'] = 0 356 | dataset = dataset_classes[config['type']](data_path, config) 357 | loader = torch.utils.data.DataLoader( 358 | dataset, 359 | batch_size=batch_size, 360 | shuffle=True, 361 | num_workers=1 362 | ) 363 | model = load_model(checkpoint_file, Model(dataset, config)) 364 | 365 | cmap_map = colors.LinearSegmentedColormap.from_list("cmap_map", [(0, 0, 0, 0), (0, 0, 0, 1)]) 366 | cmap_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), (0, 1, 0, 1)]) 367 | cmap_own_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), (1, 1, 0, 1)]) 368 | 369 | def transform_rel_abs(data, pos): 370 | return np.roll(data, pos, axis=(0, 1))[int(dataset.obs_size / 2):, int(dataset.obs_size / 2):] 
371 | 372 | if run_eval: 373 | evaluator = create_evaluator(model) 374 | evaluator.run(loader) 375 | metrics = evaluator.state.metrics 376 | print(metrics) 377 | identifier = f"{config['format']}-{config['nn_mode']}-{config['pred_mode']}" 378 | batch_index = 0 379 | for x, y_true in loader: 380 | fig, axs = plt.subplots(batch_size*2, 5, figsize=[8, 3.2]) 381 | if run_eval: 382 | #axs[0][0].set_title(f"ap: {metrics['ap']:.4f}, auc: {metrics['auc']:.4f}") 383 | fig.suptitle(f"{identifier} f1: {metrics['f1']:.4f}, ap: {metrics['ap']:.4f}, auc: {metrics['auc']:.4f}") 384 | 385 | y_pred = model(x).detach().numpy() 386 | #y_pred = torch.round(model(x).detach()).numpy() 387 | for i in range(batch_size): 388 | for j in range(5): 389 | #for agent in x['obs']['agents']: 390 | # print(agent['pos']) 391 | 392 | #axs[i*2][j].imshow(y_true['y'][i][j][1, :, :], cmap=cmap_cov) 393 | agent_obs = x['obs']['agents'][j] 394 | pos = agent_obs['pos'][i] 395 | agent_map = agent_obs['map'][i] 396 | 397 | if config['format'] == 'relative': 398 | #axs[i*2][j].imshow(transform_rel_abs(agent_map[...,0], pos), cmap=cmap_map) # obstacles 399 | #axs[i*2][j].imshow(transform_rel_abs(y_true['y'][i][j][0, :, :], pos), cmap=cmap_cov) 400 | #axs[i*2][j].imshow(transform_rel_abs(agent_map[...,1], pos), cmap=cmap_own_cov) 401 | 402 | axs[i*2][j].imshow(agent_map[...,0], cmap=cmap_map) # obstacles 403 | #axs[i*2][j].imshow(y_true['y'][i][j][1, :, :], cmap=cmap_map) 404 | axs[i*2][j].imshow(y_true['y'][i][j][0, :, :], cmap=cmap_cov) 405 | 406 | #print(y_pred[i][j][1, :, :]) 407 | #axs[i*2+1][j].imshow(y_pred[i][j][1, :, :], cmap=cmap_map) 408 | #axs[i*2+1][j].imshow(y_pred[i][j][0, :, :], cmap=cmap_cov) 409 | 410 | #axs[i*2][j].imshow(agent_map[...,1], cmap=cmap_own_cov) 411 | 412 | axs[i*2+1][j].imshow(y_pred[i][j][0, :, :], cmap=cmap_cov) 413 | #axs[i*2+1][j].imshow(y_pred[i][j][1, :, :], cmap=cmap_map) 414 | axs[i * 2+1][j].imshow(agent_map[..., 0], cmap=cmap_map) # obstacles 415 | 
#axs[i*2+1][j].imshow(y_true['w'][i][j][0, :, :], cmap=cmap_map) # weighting 416 | #axs[i*2+1][j].imshow(transform_rel_abs(y_pred[i][j][0, :, :], pos), cmap=cmap_own_cov if config['pred_mode'] == 'local' else cmap_cov) 417 | 418 | #map_data = transform_rel_abs(agent_obs['map'][i][...,0], pos) if len(y_pred[i][j]) == 1 else y_pred[i][j][1, :, :] 419 | #axs[i*2+1][j].imshow(map_data, cmap=cmap_map) # obstacles 420 | 421 | else: 422 | axs[i*2][j].imshow(x['obs']['state'][i][...,0], cmap=cmap_map) # obstacles 423 | axs[i*2][j].imshow(x['obs']['state'][i][...,1], cmap=cmap_cov) 424 | m = np.roll(agent_obs['map'][i], agent_obs['pos'][i], axis=(0,1))[int(dataset.obs_size/2):,int(dataset.obs_size/2):] 425 | axs[i*2][j].imshow(m[...,1], cmap=cmap_own_cov) 426 | 427 | axs[i*2+1][j].imshow(y_pred[i][j][0, :, :], cmap=cmap_own_cov if config['pred_mode'] == 'local' else cmap_cov) 428 | axs[i*2+1][j].imshow(x['obs']['state'][i][...,0], cmap=cmap_map) # obstacles 429 | 430 | ''' 431 | for k in range(2): 432 | rect = patches.Rectangle((agent_obs['pos'][i][1] - 1 / 2, agent_obs['pos'][i][0] - 1 / 2), 1, 1, 433 | linewidth=1, edgecolor='r', facecolor='none') 434 | axs[i*2+k][j].add_patch(rect) 435 | ''' 436 | for k in range(2): 437 | axs[i * 2 + k][j].set_xticks([]) 438 | axs[i * 2 + k][j].set_yticks([]) 439 | 440 | fig.tight_layout() #rect=[0, 0.03, 1, 0.95]) 441 | if save_dirname is not None: 442 | img_path = checkpoint_file.parent/save_dirname 443 | img_path.mkdir(exist_ok=True) 444 | frame_path = img_path/f"{batch_index:05d}.png" 445 | print("Frame", frame_path) 446 | plt.savefig(frame_path, dpi=300) 447 | else: 448 | plt.show() 449 | plt.close() 450 | batch_index += 1 451 | if batch_index == 300: 452 | break 453 | 454 | def inference_gnn_cnn(cnn_model_checkpoint_path, 455 | gnn_model_checkpoint_path, 456 | data_path, 457 | seed=None, run_eval=False, save_dirname=None): 458 | if seed is None: 459 | seed = time.time() 460 | torch.manual_seed(seed) 461 | random.seed(seed) 462 | 
    batch_size = 1

    checkpoint_file = Path(gnn_model_checkpoint_path)
    with open(checkpoint_file.parent / 'config.json', 'r') as config_file:
        config = json.load(config_file)
    config['skip_agents'] = 0
    gnn_dataset = dataset_classes[config['type']](data_path, config)
    gnn_loader = torch.utils.data.DataLoader(
        gnn_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=1
    )
    gnn_model = load_model(checkpoint_file, Model(gnn_dataset, config))

    checkpoint_file = Path(cnn_model_checkpoint_path)
    with open(checkpoint_file.parent / 'config.json', 'r') as config_file:
        config = json.load(config_file)
    config['skip_agents'] = 0
    cnn_dataset = dataset_classes[config['type']](data_path, config)
    cnn_loader = torch.utils.data.DataLoader(
        cnn_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=1
    )
    cnn_model = load_model(checkpoint_file, Model(cnn_dataset, config))


    cmap_map = colors.LinearSegmentedColormap.from_list("cmap_map", [(0, 0, 0, 0), (0, 0, 0, 1)])
    cmap_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), (0, 1, 0, 1)])
    cmap_own_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), (0, 0, 1, 1)])

    def transform_rel_abs(data, pos):
        return np.roll(data, pos, axis=(0, 1))[int(gnn_dataset.obs_size / 2):, int(gnn_dataset.obs_size / 2):]

    identifier = f"{config['format']}-{config['nn_mode']}-{config['pred_mode']}"
    batch_index = 0
    for (x, y_true), (x_cnn, y_cnn_true) in zip(gnn_loader, cnn_loader):
        fig, axs = plt.subplots(batch_size*2, 5, figsize=[8, 3.2])

        y_pred = gnn_model(x).detach().numpy()
        y_pred_cnn = cnn_model(x_cnn).detach().numpy()
        for i in range(batch_size):
            for j in range(5):
                if j == 0:
                    agent_obs = x['obs']['agents'][j]
                    pos = agent_obs['pos'][i]
                    agent_map = agent_obs['map'][i]
                else:
                    agent_obs = x_cnn['obs']['agents'][j]
                    pos = agent_obs['pos'][i]
                    agent_map = agent_obs['map'][i]

                axs[i*2][j].imshow(transform_rel_abs(agent_map[...,0], pos), cmap=cmap_map) # obstacles
                axs[i*2][j].imshow(transform_rel_abs(y_true['y'][i][j][0, :, :], pos), cmap=cmap_cov)
                axs[i*2][j].imshow(transform_rel_abs(agent_map[...,1], pos), cmap=cmap_own_cov)

                if j == 0:
                    axs[i*2+1][j].imshow(transform_rel_abs(y_pred_cnn[i][j][0, :, :], pos), cmap=cmap_own_cov)
                else:
                    axs[i*2+1][j].imshow(transform_rel_abs(y_pred[i][j][0, :, :], pos), cmap=cmap_cov)

                map_data = agent_obs['map'][i][...,0] if len(y_pred[i][j]) == 1 else y_pred[i][j][1, :, :]
                axs[i*2+1][j].imshow(transform_rel_abs(map_data, pos), cmap=cmap_map) # obstacles

                for k in range(2):
                    rect = patches.Rectangle((agent_obs['pos'][i][1] - 1 / 2, agent_obs['pos'][i][0] - 1 / 2), 1, 1,
                                             linewidth=1, edgecolor='r', facecolor='none')
                    axs[i*2+k][j].add_patch(rect)

                for k in range(2):
                    axs[i * 2 + k][j].set_xticks([])
                    axs[i * 2 + k][j].set_yticks([])

        fig.tight_layout() #rect=[0, 0.03, 1, 0.95])
        if save_dirname is not None:
            img_path = checkpoint_file.parent/save_dirname
            img_path.mkdir(exist_ok=True)
            frame_path = img_path/f"{batch_index:05d}.png"
            print("Frame", frame_path)
            plt.savefig(frame_path, dpi=300)
        else:
            plt.show()
        plt.close()
        batch_index += 1
        if batch_index == 10:
            break

def inference_path(model_checkpoint_path,
                   data_path,
                   seed=None, run_eval=False, save_dirname=None):
    if seed is None:
        seed = int(time.time())  # torch.manual_seed expects an integer seed
    torch.manual_seed(seed)
    random.seed(seed)
    batch_size = 1

    checkpoint_file = Path(model_checkpoint_path)
    with open(checkpoint_file.parent / 'config.json', 'r') as config_file:
        config = json.load(config_file)
    config['skip_agents'] = 0
    dataset = dataset_classes[config['type']](data_path, config)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=1
    )
    model = load_model(checkpoint_file, Model(dataset, config))

    cmap_map = colors.LinearSegmentedColormap.from_list("cmap_map", [(0, 0, 0, 0), (0, 0, 0, 1)])
    cmap_pos = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), (0, 1, 0, 1)])
    cmap_goal = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), (1, 1, 0, 1)])

    def transform_rel_abs(data, pos):
        return data #np.roll(data, pos, axis=(0, 1))#[int(dataset.obs_size / 2):, int(dataset.obs_size / 2):]

    if run_eval:
        evaluator = create_evaluator(model)
        evaluator.run(loader)
        metrics = evaluator.state.metrics
        print(metrics)
    identifier = f"{config['format']}-{config['nn_mode']}-{config['pred_mode']}"
    batch_index = 0
    for x, y_true in loader:
        n_agents = len(x['obs']['agents'])
        fig, axs = plt.subplots(batch_size*2, 5, figsize=[8, 3.2])
        if run_eval:
            #axs[0][0].set_title(f"ap: {metrics['ap']:.4f}, auc: {metrics['auc']:.4f}")
            # create_evaluator() only registers "ap" (f1 and auc are commented out there),
            # so only that metric is available here.
            fig.suptitle(f"{identifier} ap: {metrics['ap']:.4f}")

        y_pred = model(x).detach().numpy()
        #y_pred = torch.round(model(x).detach()).numpy()
        for i in range(batch_size):
            for j in range(5):
                #for agent in x['obs']['agents']:
                #    print(agent['pos'])

                #axs[i*2][j].imshow(y_true['y'][i][j][1, :, :], cmap=cmap_cov)
                agent_obs = x['obs']['agents'][j]
                pos = agent_obs['pos'][i]
                agent_map = agent_obs['map'][i]

                axs[i*2][j].imshow(transform_rel_abs(agent_map[...,0], pos), cmap=cmap_map) # obstacles
                axs[i*2][j].imshow(transform_rel_abs(agent_map[...,2], pos), cmap=cmap_pos) # position
                axs[i*2][j].imshow(transform_rel_abs(agent_map[...,1], pos), cmap=cmap_goal) # goal

                axs[i*2+1][j].imshow(transform_rel_abs(agent_map[...,0], pos), cmap=cmap_map) # obstacles
                axs[i*2+1][j].imshow(transform_rel_abs(agent_map[...,2], pos), cmap=cmap_pos) # position
                axs[i*2+1][j].imshow(y_pred[i][j][0, :, :], cmap=cmap_goal)

                #axs[i*2+1][j].imshow(y_pred[i][j][1, :, :], cmap=cmap_goal)
                #axs[i*2+1][j].imshow(y_pred[i][j][2, :, :], cmap=cmap_pos)

                #axs[i*2+1][j].imshow(transform_rel_abs(y_pred[i][j][0, :, :], pos), cmap=cmap_own_cov if config['pred_mode'] == 'local' else cmap_cov)
                #axs[i*2+1][j].imshow(transform_rel_abs(agent_obs['map'][i][...,0], pos), cmap=cmap_map) # obstacles

                for k in range(2):
                    axs[i * 2 + k][j].set_xticks([])
                    axs[i * 2 + k][j].set_yticks([])

        fig.tight_layout() #rect=[0, 0.03, 1, 0.95])
        if save_dirname is not None:
            img_path = checkpoint_file.parent/save_dirname
            img_path.mkdir(exist_ok=True)
            frame_path = img_path/f"{batch_index:05d}.png"
            print("Frame", frame_path)
            plt.savefig(frame_path, dpi=300)
        else:
            plt.show()
        plt.close()
        batch_index += 1
        if batch_index == 300:
            break

def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

def apply_weight_output_transform(x):
    y_pred_raw, y_raw = x[0], x[1] # shape each (batch size, agent, channel, x, y)

    classes = y_pred_raw.shape[2]

    w = y_raw['w'].permute([2, 0,1,3,4]).flatten()
    y = y_raw['y'].permute([2, 0,1,3,4]).flatten()[w==1].reshape(-1, classes)
    y_pred = y_pred_raw.permute([2, 0,1,3,4]).flatten()[w==1].reshape(-1, classes)

    return (y_pred, y)

def apply_weight_threshold_output_transform(x):
    return apply_weight_output_transform(thresholded_output_transform(x))

def weighted_binary_cross_entropy(y_pred, y):
    return F.binary_cross_entropy(y_pred, y['y'], weight=y['w'])

from ignite.metrics import EpochMetric

class AveragePrecision(EpochMetric):
    def __init__(self, output_transform=lambda x: x):
        def average_precision_compute_fn(y_preds, y_targets):
            try:
                from sklearn.metrics import average_precision_score
            except ImportError:
                raise RuntimeError("This contrib module requires sklearn to be installed.")

            y_true = y_targets.numpy()
            y_pred = y_preds.numpy()
            return average_precision_score(y_true, y_pred, average='micro')

        super(AveragePrecision, self).__init__(average_precision_compute_fn, output_transform=output_transform)

def create_evaluator(model):
    return create_supervised_evaluator(
        model,
        metrics={
            #"p": Precision(apply_weight_threshold_output_transform),
            #"r": Recall(apply_weight_threshold_output_transform),
            #"f1": Fbeta(1, output_transform=apply_weight_threshold_output_transform),
            #"auc": ROC_AUC(output_transform=apply_weight_output_transform),
            "ap": AveragePrecision(output_transform=apply_weight_output_transform)
        },
        device=model.device
    )

def load_model(checkpoint_path, model):
    model_state = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    # A basic remapping is required: checkpoint keys are paired with the
    # model's state_dict keys by position, not by name.
    mapping = {k: v for k, v in zip(model_state.keys(), model.state_dict().keys())}
    mapped_model_state = OrderedDict([(mapping[k], v) for k, v in model_state.items()])
    model.load_state_dict(mapped_model_state, strict=False)
    return model

def train(train_data_path,
          valid_data_path,
          config,
          out_dir="./explainability",
          batch_size=64,
          lr=1e-4,
          epochs=100):
    train_dataset = dataset_classes[config['type']](train_data_path, config)
    valid_dataset = dataset_classes[config['type']](valid_data_path, config)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
    val_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

    model = Model(valid_dataset, config)
    path_cp = "./explainability_checkpoints/"+out_dir
    os.makedirs(path_cp, exist_ok=True)
    with open(path_cp+"/config.json", 'w') as config_file:
        json.dump(config, config_file)

    optimizer = Adam(model.parameters(), lr=lr)
    trainer = create_supervised_trainer(model, optimizer, weighted_binary_cross_entropy, device=model.device)

    validation_evaluator = create_evaluator(model)

    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names="all")

    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    best_model_handler = ModelCheckpoint(dirname="./explainability_checkpoints/"+out_dir,
                                         filename_prefix="best",
                                         n_saved=1,
                                         global_step_transform=global_step_from_engine(trainer),
                                         score_name="val_ap",
                                         score_function=lambda engine: engine.state.metrics['ap'],
                                         require_empty=False)
    validation_evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {'model': model, })

    tb_logger = TensorboardLogger(log_dir='./explainability_tensorboard/'+out_dir)
    tb_logger.attach(
        trainer,
        log_handler=OutputHandler(
            tag="training", output_transform=lambda loss: {"batchloss": loss}, metric_names="all"
        ),
        event_name=Events.ITERATION_COMPLETED(every=100),
    )

    tb_logger.attach(
        validation_evaluator,
        log_handler=OutputHandler(tag="validation", metric_names=["ap"], another_engine=trainer),
        event_name=Events.EPOCH_COMPLETED,
    )

    #tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_COMPLETED(every=100))
    #tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
    #tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))
    #tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
    #tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))

    @trainer.on(Events.EPOCH_COMPLETED(every=5))
    def log_validation_results(engine):
        validation_evaluator.run(val_loader)
        metrics = validation_evaluator.state.metrics
        pbar.log_message(
            f"Validation Results - Epoch: {engine.state.epoch} ap: {metrics['ap']}" # f1: {metrics['f1']}, p: {metrics['p']}, r: {metrics['r']}
        )

        pbar.n = pbar.last_print_n = 0

    trainer.run(train_loader, max_epochs=epochs)

def evaluate(model_checkpoint, data_path, **kwargs):
    checkpoint_file = Path(model_checkpoint)
    with open(checkpoint_file.parent/'config.json', 'r') as config_file:
        config = json.load(config_file)

    ap = []
    for start, end in [[0, 1], [1, None]]:
        config['stop_agents'] = end
        config['skip_agents'] = start

        dataset = dataset_classes[config['type']](data_path, config)
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=64,
            shuffle=True,
            num_workers=2
        )
        model = load_model(checkpoint_file, Model(dataset, config))
        evaluator = create_evaluator(model)
        t0 = time.time()
        evaluator.run(loader)
        #print("T", time.time() - t0)
        m = evaluator.state.metrics['ap']
        if not isinstance(m, list):
            m = [m]
        ap.append(m)
    print(ap)
    return ap

def analyse_dataset(path):
    with open(path, "rb") as f:
        data = pickle.load(f)
    analyse(data)

    def get_coverage_fraction(map, coverage):
        coverable_area = ~(map > 0)
        covered_area = coverage & coverable_area
        return np.sum(covered_area) / np.sum(coverable_area)

    coverage_fractions = []
    for sample in data:
        world_cov = sample['obs']['state'][..., 1]
        world_map = sample['obs']['state'][..., 0]
        coverage_fractions.append(get_coverage_fraction(world_map, world_cov))

    print(np.mean(coverage_fractions), np.std(coverage_fractions))
    # plt.hist(coverage_fractions)
    # plt.show()


if __name__ == "__main__":
    #train("explainability_data_k3_268ugnliyw_2735_train.pkl", "explainability_data_k3_268ugnliyw_2735_valid.pkl", "./explainability_k3_sgd", epochs=10000, batch_size=32, lr=0.1, sgd_momentum=0.9)
    if False:
        #dataset_id = "031g2r0u73_3070"
        #dataset_id = "096dwc6v_g_2990"
        #dataset_id = "22hfzad070_3050"
        #dataset_id = "280kl0nkhl_2960"
        #dataset_id = "44cdqovnq4_1540" # split
        #dataset_id = "27c8iboc7__2250" # flow

        #dataset_id = "101jtn1ssr_2450"
        #dataset_id = "46t6qhvrxf_5400"

        dataset_id = "300d3g1xqj_3120"
        #dataset_id = "303qg71k5o_3120"
        train(
            f"/local/scratch/jb2270/datasets_corl/explainability_data_{dataset_id}_train.pkl",
            f"/local/scratch/jb2270/datasets_corl/explainability_data_{dataset_id}_valid.pkl",
            {
                'format': 'relative', # absolute/relative
                'nn_mode': 'cnn', # cnn/gnn/gnn_cnn
                'pred_mode': 'local', # local (own coverage and map)/global (global coverage and map)
                'prediction': 'cov_map', #cov_only/cov_map
                'type': 'cov' # path/cov
            },
            f"explainability_cov_map_local_cnn_{dataset_id}",
            epochs=10000,
            batch_size=64,
            lr=5e-3,
        )

    #evaluate("explainability_data_56uhj2ync9_2650_valid.pkl", "./explainability_checkpoints/explainability_56uhj2ync9_rel_glob/best_model_636_val_auc=0.8963244891793261.pth")

    #inference("./results/0610/explainability_checkpoints/explainability_228pbizcxq_1955_glob/best_model_973_val_auc=0.8896503235407915.pth", "./explainability_data_228pbizcxq_1955_test.pkl", 11, False)

    #inference("./results/0610/explainability_checkpoints/explainability_228pbizcxq_1955_loc/best_model_1017_val_auc=0.9856628585466523.pth", "./explainability_data_228pbizcxq_1955_test.pkl", 11, False)

    # flow
    #evaluate("./results/0712/explainability_checkpoints/explainability_glob_27c8iboc7__2250/best_model_162_val_auc=0.9998915352938512.pth", "./results/0712/explainability_data_27c8iboc7__3500_comm_test.pkl", save_dirname="rendering_comm")
    #evaluate("./results/0712/explainability_checkpoints/explainability_glob_27c8iboc7__2250/best_model_162_val_auc=0.9998915352938512.pth", "./results/0712/explainability_data_27c8iboc7__3500_nocomm_test.pkl", save_dirname="rendering_nocomm")
    #inference_path("./results/0721/explainability_checkpoints/explainability_path_goal_only_local_cnn_27c8iboc7__2250/best_model_30_val_ap=1.0.pth", "./results/0712/explainability_data_27c8iboc7__3500_nocomm_test.pkl")

    # split
    #inference("./results/0712/explainability_checkpoints/explainability_cov_map_44cdqovnq4_1540/best_model_77_val_auc=0.999701756785555.pth", "./results/0712/explainability_data_44cdqovnq4_1540_nocomm_test.pkl")
    #evaluate("./results/0712/explainability_checkpoints/explainability_cov_map_44cdqovnq4_1540/best_model_77_val_auc=0.999701756785555.pth", "./results/0712/explainability_data_44cdqovnq4_1540_comm_test.pkl", save_dirname="rendering_comm")
    #evaluate("./results/0712/explainability_checkpoints/explainability_cov_map_global_gnn_44cdqovnq4_1540/best_model_67_val_auc=0.9980259594481737.pth", "./results/0712/explainability_data_44cdqovnq4_1540_comm_test.pkl", save_dirname="rendering_comm")
    #evaluate("./results/0712/explainability_checkpoints/explainability_cov_map_global_gnn_44cdqovnq4_1540/best_model_67_val_auc=0.9980259594481737.pth", "./results/0712/explainability_data_44cdqovnq4_1540_nocomm_test.pkl", save_dirname="rendering_nocomm")

    # coverage normal
    #evaluate("./results/0712/explainability_checkpoints/explainability_cov_map_local_cnn_300d3g1xqj_3120/best_model_139_val_auc=0.9885640826508932.pth", "./results/0712/explainability_data_300d3g1xqj_3120_nocomm_test.pkl", save_dirname="rendering_nocomm")
    #inference("./results/0712/explainability_checkpoints/explainability_cov_map_local_cnn_300d3g1xqj_3120/best_model_139_val_auc=0.9885640826508932.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl") #, save_dirname="rendering_comm")
    #inference("./results/0721/explainability_checkpoints/explainability_cov_cov_only_local_cnn_300d3g1xqj_3120/best_model_190_val_ap=0.8773743058356964.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl") #, save_dirname="rendering_comm")
    #evaluate("./results/0721/explainability_checkpoints/explainability_cov_cov_only_local_cnn_300d3g1xqj_3120/best_model_190_val_ap=0.8773743058356964.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl") #, save_dirname="rendering_comm")

    #inference("./results/0721/explainability_checkpoints/explainability_cov_cov_only_global_gnn_300d3g1xqj_3120/best_model_140_val_ap=0.8570488698746233.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl") #, save_dirname="rendering_comm")
    #inference_gnn_cnn(
    #    "./results/0721/explainability_checkpoints/explainability_cov_cov_only_local_cnn_300d3g1xqj_3120/best_model_190_val_ap=0.8773743058356964.pth",
    #    "./results/0721/explainability_checkpoints/explainability_cov_cov_only_global_gnn_300d3g1xqj_3120/best_model_140_val_ap=0.8570488698746233.pth",
    #    "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl",
    #    save_dirname="rendering_comm"
    #)
    #evaluate("./results/0712/explainability_checkpoints/explainability_cov_map_global_gnn_300d3g1xqj_3120/best_model_69_val_auc=0.961077006252264.pth", "./results/0712/explainability_data_300d3g1xqj_3120_nocomm_test.pkl", save_dirname="rendering_nocomm")
    #inference("./results/0712/explainability_checkpoints/explainability_cov_map_global_gnn_300d3g1xqj_3120/best_model_69_val_auc=0.961077006252264.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl") #,save_dirname="rendering_comm")
    #inference("./results/0712/explainability_checkpoints/explainability_cov_map_global_gnn_300d3g1xqj_3120/best_model_69_val_auc=0.961077006252264.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl") #,save_dirname="rendering_comm")

    #evaluate("./results/0712/explainability_checkpoints/explainability_300d3g1xqj_3120/best_model_384_val_auc=0.9779040424681612.pth", "./results/0712/explainability_data_300d3g1xqj_3120_nocomm_test.pkl", save_dirname="rendering_nocomm")
    #evaluate("./results/0712/explainability_checkpoints/explainability_300d3g1xqj_3120/best_model_384_val_auc=0.9779040424681612.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl", save_dirname="rendering_comm")

    #inference("./results/0823/expl_checkpoints/explainability_cov_cov_map_local_271e7f5bc3_1560/best_model_30_val_ap=0.9611180740842954.pth", "../../Internship/gpvae/data/explainability_data_271e7f5bc3_1560_test.pkl") #,save_dirname="rendering_comm")
    inference("./results/0823/expl_checkpoints/explainability_cov_cov_only_local_271e7f5bc3_1560/best_model_85_val_ap=0.8428215464016102.pth", "../../Internship/gpvae/data/explainability_data_271e7f5bc3_1560_test.pkl") #,save_dirname="rendering_comm")

--------------------------------------------------------------------------------
/adversarial_comms/train_policy.py:
--------------------------------------------------------------------------------
import argparse
import collections.abc
import yaml
import json
import os
import ray

import numpy as np

from pathlib import Path
from ray import tune
from ray.rllib.utils import try_import_torch
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from ray.tune.logger import pretty_print, DEFAULT_LOGGERS, TBXLogger
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.agents.callbacks import DefaultCallbacks

from .environments.coverage import CoverageEnv
from .environments.path_planning import PathPlanningEnv
from .models.adversarial import AdversarialModel
from .trainers.multiagent_ppo import MultiPPOTrainer
from .trainers.hom_multi_action_dist import TorchHomogeneousMultiActionDistribution

torch, _ = try_import_torch()

def update_dict(d, u):
    # Recursively merge mapping u into d; nested mappings are merged, other values overwritten.
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            d[k] = update_dict(d.get(k, {}), v)
        else:
            d[k] = v
    return d

def trial_dirname_creator(trial):
    return str(trial) #f"{ray.tune.trial.date_str()}_{trial}"

def dir_path(string):
    if os.path.isdir(string):
        return string
    else:
        raise NotADirectoryError(string)

def check_file(string):
    if os.path.isfile(string):
        return string
    else:
        raise FileNotFoundError(string)

def get_config_base():
    return Path(os.path.dirname(os.path.realpath(__file__))) / "config"

class EvaluationCallbacks(DefaultCallbacks):
    def on_episode_start(self, worker, base_env, policies, episode, **kwargs):
        episode.user_data["reward_greedy"] = []
        episode.user_data["reward_coop"] = []

    def on_episode_step(self, worker, base_env, episode, **kwargs):
        ep_info = episode.last_info_for()
        if ep_info is not None and ep_info:
            episode.user_data["reward_greedy"].append(sum(ep_info['rewards_teams'][0].values()))
            episode.user_data["reward_coop"].append(sum(ep_info['rewards_teams'][1].values()))

    def on_episode_end(self, worker, base_env, policies, episode, **kwargs):
        episode.custom_metrics["reward_greedy"] = np.sum(episode.user_data["reward_greedy"])
        episode.custom_metrics["reward_coop"] = np.sum(episode.user_data["reward_coop"])

    '''
    def on_train_result(self, trainer, result, **kwargs):
        greedy_mse_fac = trainer.config['model']['custom_model_config']['greedy_mse_fac']
        if isinstance(greedy_mse_fac, list):
            s = PiecewiseSchedule(greedy_mse_fac[0], "torch", outside_value=greedy_mse_fac[1])
            trainer.workers.foreach_worker(
                lambda w: w.foreach_policy(
                    lambda p, p_id: p.model.update_config({'greedy_mse_fac': s(result['timesteps_total'])})))
    '''

def initialize():
    ray.init()
    register_env("coverage", lambda config: CoverageEnv(config))
    register_env("path_planning", lambda config: PathPlanningEnv(config))
    ModelCatalog.register_custom_model("adversarial", AdversarialModel)
    ModelCatalog.register_custom_action_dist("hom_multi_action", TorchHomogeneousMultiActionDistribution)

def start_experiment():
    parser = argparse.ArgumentParser()
    parser.add_argument("experiment")
    parser.add_argument("-o", "--override", help='Key in alternative_config from which to take data to override main config', default=None)
    parser.add_argument("-t", "--timesteps", help="Number of total time steps for training stop condition in millions", type=int, default=20)
    args = parser.parse_args()

    try:
        config_path = check_file(args.experiment)
    except FileNotFoundError:
        config_path = get_config_base() / (args.experiment + ".yaml")

    with open(config_path, "rb") as config_file:
        # yaml.load() without an explicit Loader fails on PyYAML >= 6;
        # safe_load assumes the config files contain plain YAML only.
        config = yaml.safe_load(config_file)
    if args.override is not None:
        if args.override not in config['alternative_config']:
            print("Invalid alternative config key! Choose one from:")
            print(config['alternative_config'].keys())
            exit()
        update_dict(config, config['alternative_config'][args.override])
    config.pop('alternative_config', None)
    config['callbacks'] = EvaluationCallbacks

    initialize()
    tune.run(
        MultiPPOTrainer,
        checkpoint_freq=10,
        stop={"timesteps_total": args.timesteps*1e6},
        keep_checkpoints_num=1,
        config=config,
        #local_dir="/tmp",
        trial_dirname_creator=trial_dirname_creator,
    )

def continue_experiment():
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint", type=dir_path)
    parser.add_argument("-t", "--timesteps", help="Number of total time steps for training stop condition in millions", type=int, default=20)
    parser.add_argument("-e", "--experiment", help="Path/id to training config", default=None)
    parser.add_argument("-o", "--override", help='Key in alternative_config from which to take data to override main config', default=None)

    args = parser.parse_args()

    with open(Path(args.checkpoint) / '..' / 'params.json', "rb") as config_file:
        config = json.load(config_file)

    if args.experiment is not None:
        try:
            config_path = check_file(args.experiment)
        except FileNotFoundError:
            config_path = get_config_base() / (args.experiment + ".yaml")

        with open(config_path, "rb") as config_file:
            # See start_experiment() for why safe_load is used here.
            update_dict(config, yaml.safe_load(config_file)['alternative_config'][args.override])

    config['callbacks'] = EvaluationCallbacks

    checkpoint_file = Path(args.checkpoint) / ('checkpoint-' + os.path.basename(args.checkpoint).split('_')[-1])

    initialize()
    tune.run(
        MultiPPOTrainer,
        checkpoint_freq=20,
        stop={"timesteps_total": args.timesteps*1e6},
        restore=checkpoint_file,
        keep_checkpoints_num=1,
        config=config,
        #local_dir="/tmp",
        trial_dirname_creator=trial_dirname_creator,
    )

if __name__ == '__main__':
    start_experiment()
    exit()

    # NOTE: everything below is unreachable because of the exit() above.
    # run_experiment is not defined in this file, and continue_experiment()
    # takes no arguments; the calls only record the sequence of experiments.

    ### Cooperative
    run_experiment("./config/coverage.yaml", {"timesteps_total": 20e6}, None)
    run_experiment("./config/coverage_split.yaml", {"timesteps_total": 3e6}, None)
    run_experiment("./config/path_planning.yaml", {"timesteps_total": 20e6}, None)

    ### Adversarial
    continue_experiment("checkpoint_cov", {"timesteps_total": 60e6}, "./config/coverage.yaml", "adversarial")
    continue_experiment("checkpoint_split", {"timesteps_total": 20e6}, "./config/coverage_split.yaml", "adversarial")
    continue_experiment("checkpoint_flow", {"timesteps_total": 60e6}, "./config/path_planning.yaml", "adversarial")

    ### Re-adapt
    continue_experiment("checkpoint_cov_adv", {"timesteps_total": 90e6}, "./config/coverage.yaml", "cooperative")
    continue_experiment("checkpoint_split_adv", {"timesteps_total": 30e6}, "./config/coverage_split.yaml", "cooperative")
    continue_experiment("checkpoint_flow_adv", {"timesteps_total": 90e6}, "./config/path_planning.yaml", "cooperative")



--------------------------------------------------------------------------------
/adversarial_comms/trainers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proroklab/adversarial_comms/889a41e239958bd519365c6b0c143469cd27e49f/adversarial_comms/trainers/__init__.py

--------------------------------------------------------------------------------
/adversarial_comms/trainers/hom_multi_action_dist.py:
--------------------------------------------------------------------------------
import gym
import numpy as np
import tree
from ray.rllib.models.torch.torch_action_dist import TorchMultiActionDistribution
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()


class InvalidActionSpace(Exception):
    """Raised when the action space is invalid"""

    pass


class TorchHomogeneousMultiActionDistribution(TorchMultiActionDistribution):
    @override(TorchMultiActionDistribution)
    def logp(self, x):
        logps = []
        for i, (d, action_space) in enumerate(
            zip(self.flat_child_distributions, self.action_space_struct)
        ):
            if isinstance(action_space, gym.spaces.box.Box):
                assert len(action_space.shape) == 1
                a_w = action_space.shape[0]
                x_sel = x[:, a_w * i : a_w * (i + 1)]
            elif isinstance(action_space, gym.spaces.discrete.Discrete):
                x_sel = x[:, i]
            else:
                raise InvalidActionSpace(
                    "Expect gym.spaces.box or gym.spaces.discrete action space"
                )
            logps.append(d.logp(x_sel))

        return torch.stack(logps, axis=1)

    @override(TorchMultiActionDistribution)
    def entropy(self):
        return torch.stack(
            [d.entropy() for d in self.flat_child_distributions], axis=-1
        )

    @override(TorchMultiActionDistribution)
45 | def sampled_action_logp(self): 46 | return torch.stack( 47 | [d.sampled_action_logp() for d in self.flat_child_distributions], axis=-1 48 | ) 49 | 50 | @override(TorchMultiActionDistribution) 51 | def kl(self, other): 52 | return torch.stack( 53 | [ 54 | d.kl(o) 55 | for d, o in zip( 56 | self.flat_child_distributions, other.flat_child_distributions 57 | ) 58 | ], 59 | axis=-1, 60 | ) 61 | -------------------------------------------------------------------------------- /adversarial_comms/trainers/multiagent_ppo.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch policy class used for PPO. 3 | """ 4 | import gym 5 | import logging 6 | import numpy as np 7 | from typing import Dict, List, Optional, Type, Union 8 | 9 | import ray 10 | from ray.rllib.agents.ppo.ppo_tf_policy import setup_config 11 | from ray.rllib.agents.ppo.ppo_torch_policy import kl_and_loss_stats, \ 12 | vf_preds_fetches, setup_mixins, KLCoeffMixin, ValueNetworkMixin 13 | from ray.rllib.agents.trainer_template import build_trainer 14 | from ray.rllib.evaluation.episode import MultiAgentEpisode 15 | from ray.rllib.evaluation.postprocessing import compute_advantages, \ 16 | Postprocessing 17 | from ray.rllib.models.modelv2 import ModelV2 18 | from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper 19 | from ray.rllib.policy.policy import Policy 20 | from ray.rllib.policy.policy_template import build_policy_class 21 | from ray.rllib.policy.sample_batch import SampleBatch 22 | from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \ 23 | LearningRateSchedule 24 | from ray.rllib.utils.framework import try_import_torch 25 | from ray.rllib.utils.torch_ops import apply_grad_clipping, \ 26 | convert_to_torch_tensor, explained_variance, sequence_mask 27 | from ray.rllib.utils.typing import TensorType, TrainerConfigDict, AgentID 28 | 29 | torch, nn = try_import_torch() 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 
| class InvalidActionSpace(Exception): 34 | """Raised when the action space is invalid""" 35 | pass 36 | 37 | 38 | def compute_gae_for_sample_batch( 39 | policy: Policy, 40 | sample_batch: SampleBatch, 41 | other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, 42 | episode: Optional[MultiAgentEpisode] = None) -> SampleBatch: 43 | """Adds GAE (generalized advantage estimations) to a trajectory. 44 | The trajectory contains only data from one episode and from one agent. 45 | - If `config.batch_mode=truncate_episodes` (default), sample_batch may 46 | contain a truncated (at-the-end) episode, in case the 47 | `config.rollout_fragment_length` was reached by the sampler. 48 | - If `config.batch_mode=complete_episodes`, sample_batch will contain 49 | exactly one episode (no matter how long). 50 | New columns can be added to sample_batch and existing ones may be altered. 51 | Args: 52 | policy (Policy): The Policy used to generate the trajectory 53 | (`sample_batch`) 54 | sample_batch (SampleBatch): The SampleBatch to postprocess. 55 | other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional 56 | dict of AgentIDs mapping to other agents' trajectory data (from the 57 | same episode). NOTE: The other agents use the same policy. 58 | episode (Optional[MultiAgentEpisode]): Optional multi-agent episode 59 | object in which the agents operated. 60 | Returns: 61 | SampleBatch: The postprocessed, modified SampleBatch (or a new one). 62 | """ 63 | 64 | # the trajectory view API will pass populate the info dict with a np.zeros((n,)) 65 | # array in the first call, in that case the dtype will be float32 and we 66 | # have to ignore it. For regular calls, we extract the rewards from the info 67 | # dict into the samplebatch_infos_rewards dict, which now holds the rewards 68 | # for all agents as dict. 
69 | samplebatch_infos_rewards = {'0': sample_batch[SampleBatch.INFOS]} 70 | if not sample_batch[SampleBatch.INFOS].dtype == "float32": 71 | samplebatch_infos = SampleBatch.concat_samples([ 72 | SampleBatch({k: [v] for k, v in s.items()}) 73 | for s in sample_batch[SampleBatch.INFOS] 74 | ]) 75 | samplebatch_infos_rewards = SampleBatch.concat_samples([ 76 | SampleBatch({str(k): [v] for k, v in s.items()}) 77 | for s in samplebatch_infos["rewards"] 78 | ]) 79 | 80 | if not isinstance(policy.action_space, gym.spaces.tuple.Tuple): 81 | raise InvalidActionSpace("Expect tuple action space") 82 | 83 | # samplebatches for each agents 84 | batches = [] 85 | for key, action_space in zip(samplebatch_infos_rewards.keys(), policy.action_space): 86 | i = int(key) 87 | sample_batch_agent = sample_batch.copy() 88 | sample_batch_agent[SampleBatch.REWARDS] = (samplebatch_infos_rewards[key]) 89 | if isinstance(action_space, gym.spaces.box.Box): 90 | assert len(action_space.shape) == 1 91 | a_w = action_space.shape[0] 92 | elif isinstance(action_space, gym.spaces.discrete.Discrete): 93 | a_w = 1 94 | else: 95 | raise InvalidActionSpace("Expect gym.spaces.box or gym.spaces.discrete action space") 96 | 97 | sample_batch_agent[SampleBatch.ACTIONS] = sample_batch[SampleBatch.ACTIONS][:, a_w * i : a_w * (i + 1)] 98 | sample_batch_agent[SampleBatch.VF_PREDS] = sample_batch[SampleBatch.VF_PREDS][:, i] 99 | 100 | # Trajectory is actually complete -> last r=0.0. 101 | if sample_batch[SampleBatch.DONES][-1]: 102 | last_r = 0.0 103 | # Trajectory has been truncated -> last r=VF estimate of last obs. 104 | else: 105 | # Input dict is provided to us automatically via the Model's 106 | # requirements. It's a single-timestep (last one in trajectory) 107 | # input_dict. 108 | # Create an input dict according to the Model's requirements. 
109 | input_dict = policy.model.get_input_dict( 110 | sample_batch, index="last") 111 | all_values = policy._value(**input_dict, seq_lens=input_dict.seq_lens) 112 | last_r = all_values[i].item() 113 | 114 | # Adds the policy logits, VF preds, and advantages to the batch, 115 | # using GAE ("generalized advantage estimation") or not. 116 | batches.append( 117 | compute_advantages( 118 | sample_batch_agent, 119 | last_r, 120 | policy.config["gamma"], 121 | policy.config["lambda"], 122 | use_gae=policy.config["use_gae"], 123 | use_critic=policy.config.get("use_critic", True) 124 | ) 125 | ) 126 | 127 | # Overwrite these columns of the original sample batch with the per-agent results stacked along the last axis 128 | for k in [ 129 | SampleBatch.REWARDS, 130 | SampleBatch.VF_PREDS, 131 | Postprocessing.ADVANTAGES, 132 | Postprocessing.VALUE_TARGETS, 133 | ]: 134 | sample_batch[k] = np.stack([b[k] for b in batches], axis=-1) 135 | 136 | return sample_batch 137 | 138 | 139 | def ppo_surrogate_loss( 140 | policy: Policy, model: ModelV2, 141 | dist_class: Type[TorchDistributionWrapper], 142 | train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: 143 | """Constructs the loss for Proximal Policy Objective. 144 | Args: 145 | policy (Policy): The Policy to calculate the loss for. 146 | model (ModelV2): The Model to calculate the loss for. 147 | dist_class (Type[ActionDistribution]): The action distribution class. 148 | train_batch (SampleBatch): The training data. 149 | Returns: 150 | Union[TensorType, List[TensorType]]: A single loss tensor or a list 151 | of loss tensors. 152 | """ 153 | logits, state = model.from_batch(train_batch, is_training=True) 154 | curr_action_dist = dist_class(logits, model) 155 | 156 | # RNN case: Mask away 0-padded chunks at end of time axis.
157 | if state: 158 | B = len(train_batch["seq_lens"]) 159 | max_seq_len = logits.shape[0] // B 160 | mask = sequence_mask( 161 | train_batch["seq_lens"], 162 | max_seq_len, 163 | time_major=model.is_time_major()) 164 | mask = torch.reshape(mask, [-1]) 165 | num_valid = torch.sum(mask) 166 | 167 | def reduce_mean_valid(t): 168 | return torch.sum(t[mask]) / num_valid 169 | 170 | # non-RNN case: No masking. 171 | else: 172 | mask = None 173 | reduce_mean_valid = torch.mean 174 | 175 | loss_data = [] 176 | 177 | curr_action_dist = dist_class(logits, model) 178 | prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS], 179 | model) 180 | logps = curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) 181 | entropies = curr_action_dist.entropy() 182 | 183 | action_kl = prev_action_dist.kl(curr_action_dist) 184 | mean_kl = reduce_mean_valid(torch.sum(action_kl, axis=1)) 185 | 186 | for i in range(len(train_batch[SampleBatch.VF_PREDS][0])): 187 | logp_ratio = torch.exp( 188 | logps[:, i] - 189 | train_batch[SampleBatch.ACTION_LOGP][:, i]) 190 | 191 | mean_entropy = reduce_mean_valid(entropies[:, i]) 192 | 193 | surrogate_loss = torch.min( 194 | train_batch[Postprocessing.ADVANTAGES][..., i] * logp_ratio, 195 | train_batch[Postprocessing.ADVANTAGES][..., i] * torch.clamp( 196 | logp_ratio, 1 - policy.config["clip_param"], 197 | 1 + policy.config["clip_param"])) 198 | mean_policy_loss = reduce_mean_valid(-surrogate_loss) 199 | 200 | if policy.config["use_gae"]: 201 | prev_value_fn_out = train_batch[SampleBatch.VF_PREDS][..., i] 202 | value_fn_out = model.value_function()[..., i] 203 | vf_loss1 = torch.pow( 204 | value_fn_out - train_batch[Postprocessing.VALUE_TARGETS][..., i], 2.0) 205 | vf_clipped = prev_value_fn_out + torch.clamp( 206 | value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"], 207 | policy.config["vf_clip_param"]) 208 | vf_loss2 = torch.pow( 209 | vf_clipped - train_batch[Postprocessing.VALUE_TARGETS][..., i], 2.0) 210 | vf_loss 
= torch.max(vf_loss1, vf_loss2) 211 | mean_vf_loss = reduce_mean_valid(vf_loss) 212 | total_loss = reduce_mean_valid( 213 | -surrogate_loss + policy.kl_coeff * action_kl[:, i] + 214 | policy.config["vf_loss_coeff"] * vf_loss - 215 | policy.entropy_coeff * entropies[:, i]) 216 | else: 217 | mean_vf_loss = 0.0 218 | total_loss = reduce_mean_valid(-surrogate_loss + 219 | policy.kl_coeff * action_kl[:, i] - 220 | policy.entropy_coeff * entropies[:, i]) 221 | 222 | # Store stats in policy for stats_fn. 223 | loss_data.append( 224 | { 225 | "total_loss": total_loss, 226 | "mean_policy_loss": mean_policy_loss, 227 | "mean_vf_loss": mean_vf_loss, 228 | "mean_entropy": mean_entropy, 229 | } 230 | ) 231 | 232 | policy._total_loss = (torch.sum(torch.stack([o["total_loss"] for o in loss_data])),) 233 | policy._mean_policy_loss = torch.mean( 234 | torch.stack([o["mean_policy_loss"] for o in loss_data]) 235 | ) 236 | policy._mean_vf_loss = torch.mean( 237 | torch.stack([o["mean_vf_loss"] for o in loss_data]) 238 | ) 239 | policy._mean_entropy = torch.mean( 240 | torch.stack([o["mean_entropy"] for o in loss_data]) 241 | ) 242 | policy._vf_explained_var = explained_variance( 243 | train_batch[Postprocessing.VALUE_TARGETS], 244 | policy.model.value_function()) 245 | policy._mean_kl = mean_kl 246 | 247 | return policy._total_loss 248 | 249 | 250 | class ValueNetworkMixin: 251 | """This is exactly the same mixin class as in ppo_torch_policy, 252 | but that one calls .item() on self.model.value_function()[0], 253 | which will not work for us since our value function returns 254 | multiple values. Instead, we call .item() in 255 | compute_gae_for_sample_batch above. 256 | """ 257 | 258 | def __init__(self, obs_space, action_space, config): 259 | if config["use_gae"]: 260 | 261 | def value(**input_dict): 262 | input_dict = SampleBatch(input_dict) 263 | input_dict = self._lazy_tensor_dict(input_dict) 264 | model_out, _ = self.model(input_dict) 265 | # [0] = remove the batch dim. 
266 | return self.model.value_function()[0] 267 | 268 | else: 269 | 270 | def value(*args, **kwargs): 271 | return 0.0 272 | 273 | self._value = value 274 | 275 | 276 | def setup_mixins_override(policy: Policy, obs_space: gym.spaces.Space, 277 | action_space: gym.spaces.Space, 278 | config: TrainerConfigDict) -> None: 279 | """Have to initialize the custom ValueNetworkMixin 280 | """ 281 | setup_mixins(policy, obs_space, action_space, config) 282 | ValueNetworkMixin.__init__(policy, obs_space, action_space, config) 283 | 284 | 285 | # Build a child class of `TorchPolicy`, given the custom functions defined 286 | # above. 287 | MultiPPOTorchPolicy = build_policy_class( 288 | name="MultiPPOTorchPolicy", 289 | framework="torch", 290 | get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, 291 | loss_fn=ppo_surrogate_loss, 292 | stats_fn=kl_and_loss_stats, 293 | extra_action_out_fn=vf_preds_fetches, 294 | postprocess_fn=compute_gae_for_sample_batch, 295 | extra_grad_process_fn=apply_grad_clipping, 296 | before_init=setup_config, 297 | before_loss_init=setup_mixins_override, 298 | mixins=[ 299 | LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin, 300 | ValueNetworkMixin 301 | ], 302 | ) 303 | 304 | def get_policy_class(config): 305 | return MultiPPOTorchPolicy 306 | 307 | MultiPPOTrainer = build_trainer( 308 | name="MultiPPO", 309 | default_config=ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, 310 | validate_config=ray.rllib.agents.ppo.ppo.validate_config, 311 | default_policy=MultiPPOTorchPolicy, 312 | get_policy_class=get_policy_class, 313 | execution_plan=ray.rllib.agents.ppo.ppo.execution_plan 314 | ) 315 | -------------------------------------------------------------------------------- /adversarial_comms/trainers/random_heuristic.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import ray 4 | import random 5 | 6 | import numpy as np 7 | 8 | from enum import Enum 9 | from gym import 
spaces 10 | from ray.rllib import Policy 11 | from ray.rllib.agents import with_common_config 12 | from ray.rllib.agents.trainer_template import build_trainer 13 | from ray.rllib.evaluation.worker_set import WorkerSet 14 | from ray.rllib.execution.metric_ops import StandardMetricsReporting 15 | from ray.rllib.execution.rollout_ops import ParallelRollouts, SelectExperiences 16 | from ray.rllib.models.modelv2 import restore_original_dimensions 17 | from ray.rllib.utils import override 18 | from ray.rllib.utils.typing import TrainerConfigDict 19 | from ray.util.iter import LocalIterator 20 | from ray.tune.registry import register_env 21 | 22 | DEFAULT_CONFIG = with_common_config({}) 23 | 24 | class Action(Enum): 25 | NOP = 0 26 | MOVE_RIGHT = 1 27 | MOVE_LEFT = 2 28 | MOVE_UP = 3 29 | MOVE_DOWN = 4 30 | 31 | X = 1 32 | Y = 0 33 | 34 | class RandomHeuristicPolicy(Policy, ABC): 35 | """ 36 | Based on 37 | https://github.com/ray-project/ray/blob/releases/1.0.1/rllib/examples/policy/random_policy.py 38 | Visit a random uncovered neighboring cell or a random cell if all are covered 39 | """ 40 | 41 | def __init__(self, *args, **kwargs): 42 | super().__init__(*args, **kwargs) 43 | 44 | def single_random_heuristic(self, obs): 45 | state_obstacles, state_coverage = (obs[:, :, i] for i in range(2)) 46 | half_state_shape = (np.array(state_obstacles.shape)/2).astype(int) 47 | actions_deltas = { 48 | Action.MOVE_RIGHT.value: [ 0, 1], 49 | Action.MOVE_LEFT.value: [ 0, -1], 50 | Action.MOVE_UP.value: [-1, 0], 51 | Action.MOVE_DOWN.value: [ 1, 0], 52 | } 53 | 54 | options_free = [] 55 | options_uncovered = [] 56 | for a, dp in actions_deltas.items(): 57 | p = half_state_shape + dp 58 | if state_obstacles[p[Y], p[X]] > 0: 59 | continue 60 | options_free.append(a) 61 | 62 | if state_coverage[p[Y], p[X]] > 0: 63 | continue 64 | options_uncovered.append(a) 65 | 66 | if len(options_uncovered) > 0: 67 | return random.choice(options_uncovered) 68 | elif len(options_free) > 0: 69 | return 
random.choice(options_free) 70 | return Action.NOP.value 71 | 72 | @override(Policy) 73 | def compute_actions(self, 74 | obs_batch, 75 | state_batches=None, 76 | prev_action_batch=None, 77 | prev_reward_batch=None, 78 | info_batch=None, 79 | episodes=None, 80 | **kwargs): 81 | 82 | obs_batch = restore_original_dimensions( 83 | np.array(obs_batch, dtype=np.float32), 84 | self.observation_space, 85 | tensorlib=np) 86 | 87 | r = np.array([[self.single_random_heuristic(map_batch) for map_batch in agent['map']] for agent in obs_batch['agents']]) 88 | return r.transpose(), [], {} 89 | 90 | def learn_on_batch(self, samples): 91 | pass 92 | 93 | def get_weights(self): 94 | pass 95 | 96 | def set_weights(self, weights): 97 | pass 98 | 99 | 100 | def execution_plan(workers: WorkerSet, 101 | config: TrainerConfigDict) -> LocalIterator[dict]: 102 | rollouts = ParallelRollouts(workers, mode="async") 103 | 104 | # Collect batches for the trainable policies. 105 | rollouts = rollouts.for_each( 106 | SelectExperiences(workers.trainable_policies())) 107 | 108 | # Return training metrics.
109 | return StandardMetricsReporting(rollouts, workers, config) 110 | 111 | 112 | RandomHeuristicTrainer = build_trainer( 113 | name="RandomHeuristic", 114 | default_config=DEFAULT_CONFIG, 115 | default_policy=RandomHeuristicPolicy, 116 | execution_plan=execution_plan) 117 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1 2 | ray[rllib]==1.3.0 3 | matplotlib==3.4.1 4 | scikit-learn==0.24.1 5 | torchsummary==1.5.1 6 | 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import setuptools 3 | 4 | package_dir = os.path.dirname(os.path.realpath(__file__)) 5 | 6 | with open(package_dir + "/README.md", "r") as fh: 7 | long_description = fh.read() 8 | 9 | requirements_dir = package_dir + '/requirements.txt' 10 | install_requires = [] 11 | with open(requirements_dir) as f: 12 | install_requires = f.read().splitlines() 13 | 14 | setuptools.setup( 15 | name="adversarial-comms", 16 | version="1.1", 17 | author="Jan Blumenkamp", 18 | author_email="jb2270@cam.ac.uk", 19 | description="Package accompanying the paper 'The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning'", 20 | long_description=long_description, 21 | long_description_content_type="text/markdown", 22 | url="https://github.com/proroklab/adversarial_comms", 23 | packages=setuptools.find_packages(), 24 | classifiers=[ 25 | "Programming Language :: Python :: 3", 26 | "License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)", 27 | "Operating System :: OS Independent", 28 | ], 29 | install_requires=install_requires, 30 | entry_points = { 31 | 'console_scripts': [ 32 | 'train_policy=adversarial_comms.train_policy:start_experiment', 33 |
'continue_policy=adversarial_comms.train_policy:continue_experiment', 34 | 'evaluate_coop=adversarial_comms.evaluate:eval_nocomm_coop', 35 | 'evaluate_adv=adversarial_comms.evaluate:eval_nocomm_adv', 36 | 'evaluate_random=adversarial_comms.evaluate:eval_random', 37 | 'evaluate_plot=adversarial_comms.evaluate:plot', 38 | 'evaluate_serve=adversarial_comms.evaluate:serve' 39 | ], 40 | }, 41 | python_requires='>=3.7', 42 | ) 43 | --------------------------------------------------------------------------------