├── .gitignore
├── LICENSE
├── README.md
├── adversarial_comms
│   ├── __init__.py
│   ├── config
│   │   ├── coverage.yaml
│   │   ├── coverage_split.yaml
│   │   └── path_planning.yaml
│   ├── environments
│   │   ├── __init__.py
│   │   ├── coverage.py
│   │   └── path_planning.py
│   ├── evaluate.py
│   ├── generate_dataset.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── adversarial.py
│   │   └── gnn
│   │       ├── __init__.py
│   │       ├── adversarialGraphML.py
│   │       ├── graphML.py
│   │       └── graphTools.py
│   ├── train_interpreter.py
│   ├── train_policy.py
│   └── trainers
│       ├── __init__.py
│       ├── hom_multi_action_dist.py
│       ├── multiagent_ppo.py
│       └── random_heuristic.py
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adversarial Comms
2 | Code accompanying the paper
3 | > [The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning](https://arxiv.org/abs/2008.02616)\
4 | > Jan Blumenkamp, Amanda Prorok\
5 | > (University of Cambridge)\
6 | > _arXiv: 2008.02616_.
7 |
8 | The five minute video presentation for CoRL 2020:
9 |
10 | [Watch on YouTube](https://www.youtube.com/watch?v=SUDkRaj4FAI)
11 |
12 | Supplementary video material:
13 |
14 | [Watch on YouTube](https://www.youtube.com/watch?v=o1Nq9XoSU6U)
15 |
16 | ## Installation
17 | Clone the repository, change directory into its root and run:
18 | ```
19 | pip install -e .
20 | ```
21 | This will install the package and all requirements. It will also set up the entry points referred to later in these instructions.
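
A quick, optional sanity check (not part of the original instructions) is to confirm that the package can be imported from Python:

```
# Minimal check that the editable install succeeded.
import adversarial_comms
print(adversarial_comms.__file__)  # should point into the cloned repository
```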
22 |
23 | ## Training
24 | Training is performed for the policies and for the interpreter. We first explain the three policy training steps (cooperative, self-interested, and re-adaptation) for all three experiments (coverage, split coverage, and path planning), and then the interpreter training.
25 |
26 | The policy training follows this scheme:
27 | ```
28 | train_policy [experiment] -t [total time steps in millions]
29 | continue_policy [cooperative checkpoint path] -t [total time steps] -e [experiment] -o self_interested
30 | continue_policy [self-interested checkpoint path] -t [total time steps] -e [experiment] -o re_adapt
31 | ```
32 | where `experiment` is one of `{coverage, coverage_split, path_planning}`, `-t` is the total number of time steps after which the experiment terminates (note that this is a cumulative total, not per call: if a policy is trained with `train_policy -t 20` and then continued with `continue_policy -t 20`, it will terminate immediately), and `-o` is a config option (one of `{self_interested, re_adapt}`, as found under the `alternative_config` key in each of the config files in `config`).
33 |
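Conceptually, the `-o` option selects one entry under the `alternative_config` key and overlays it onto the base configuration. The exact merge logic lives in the training scripts; the following Python sketch (with a hypothetical `deep_update` helper) only illustrates the idea:

```
import yaml

def deep_update(base, override):
    # Recursively overlay `override` onto `base` (illustrative only).
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update(base[key], value)
        else:
            base[key] = value
    return base

with open("adversarial_comms/config/coverage.yaml") as f:
    config = yaml.safe_load(f)

# `-o self_interested` then roughly corresponds to:
config = deep_update(config, config.pop("alternative_config")["self_interested"])
```
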
34 | When running each experiment, Ray prints the trial name to the terminal, which looks something like `MultiPPO_coverage_f4dc4_00000`. By default, Ray creates the directory `~/ray_results/MultiPPO`, in which the trial with that name can be found along with its checkpoints. `continue_policy` expects the path to one such checkpoint, for example `~/ray_results/MultiPPO/MultiPPO_coverage_f4dc4_00000/checkpoint_440`. The first `continue_policy` call expects the checkpoint generated by the initial `train_policy` call, and the second `continue_policy` call expects the checkpoint generated by the first `continue_policy` call. Take note of each experiment's checkpoint path.
35 |
36 | ### Standard Coverage
37 | ```
38 | train_policy coverage -t 20
39 | continue_policy [cooperative checkpoint path] -t 60 -e coverage -o self_interested
40 | continue_policy [adversarial checkpoint path] -t 80 -e coverage -o re_adapt
41 | ```
42 |
43 | ### Split coverage
44 | ```
45 | train_policy coverage_split -t 3
46 | continue_policy [cooperative checkpoint path] -t 20 -e coverage_split -o self_interested
47 | continue_policy [adversarial checkpoint path] -t 30 -e coverage_split -o re_adapt
48 | ```
49 |
50 | ### Path Planning
51 | ```
52 | train_policy path_planning -t 20
53 | continue_policy [cooperative checkpoint path] -t 60 -e path_planning -o self_interested
54 | continue_policy [adversarial checkpoint path] -t 80 -e path_planning -o re_adapt
55 | ```
56 |
57 | ## Evaluation
58 | We provide three methods for evaluation:
59 |
60 | 1) `evaluate_coop`: Evaluate cooperative-only performance (self-interested agents disabled), with and without communication among the cooperative agents.
61 | 2) `evaluate_adv`: Evaluate cooperative and self-interested agents, with and without communication between cooperative and self-interested agents (cooperative agents can always communicate with each other).
62 | 3) `evaluate_random`: Run a random policy that visits random neighboring (preferably uncovered) cells.
63 |
64 | The evaluation is run as
65 | ```
66 | evaluate_{coop, adv} [checkpoint path] [result path] --trials 100
67 | evaluate_random [result path] --trials 100
68 | ```
69 | for 100 evaluation runs with different seeds. The resulting file is a pickled Pandas dataframe containing the rewards for all agents at every time step. It can be processed and visualized by running `evaluate_plot [pickled data path]`.
70 |
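As a minimal sketch for inspecting such a result file (the path is illustrative and the exact column layout is defined by the evaluation scripts):

```
import pandas as pd

# Load the pickled dataframe written by evaluate_coop / evaluate_adv / evaluate_random.
df = pd.read_pickle("results/coverage_eval.pkl")  # hypothetical result path

print(df.columns.tolist())  # discover which fields were recorded
print(df.head())            # per-step rewards for all agents across the evaluation trials
```
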
71 | Additionally, a checkpoint can be rolled out and rendered for a randomly generated environment with `evaluate_serve [checkpoint_path] --seed 0`.
72 |
73 | ## Citation
74 | If you use any part of this code in your research, please cite our paper:
75 | ```
76 | @article{blumenkamp2020adversarial,
77 | title={The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning},
78 | author={Blumenkamp, Jan and Prorok, Amanda},
79 | journal={Conference on Robot Learning (CoRL)},
80 | year={2020}
81 | }
82 | ```
83 |
--------------------------------------------------------------------------------
/adversarial_comms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proroklab/adversarial_comms/889a41e239958bd519365c6b0c143469cd27e49f/adversarial_comms/__init__.py
--------------------------------------------------------------------------------
/adversarial_comms/config/coverage.yaml:
--------------------------------------------------------------------------------
1 | framework: torch
2 | env: coverage
3 | lambda: 0.95
4 | kl_coeff: 0.5
5 | kl_target: 0.01
6 | clip_rewards: True
7 | clip_param: 0.2
8 | vf_clip_param: 250.0
9 | vf_share_layers: False
10 | vf_loss_coeff: 1.0e-4
11 | entropy_coeff: 0.01
12 | train_batch_size: 5000
13 | rollout_fragment_length: 100
14 | sgd_minibatch_size: 1000
15 | num_sgd_iter: 5
16 | num_workers: 7
17 | num_envs_per_worker: 16
18 | lr: 5.0e-4
19 | gamma: 0.9
20 | batch_mode: truncate_episodes
21 | observation_filter: NoFilter
22 | num_gpus: 0.5
23 | num_gpus_per_worker: 0.0625
24 | model:
25 | custom_model: adversarial
26 | custom_action_dist: hom_multi_action
27 | custom_model_config:
28 | graph_layers: 1
29 | graph_tabs: 2
30 | graph_edge_features: 1
31 | graph_features: 128
32 | cnn_filters: [[8, [4, 4], 2], [16, [4, 4], 2], [32, [3, 3], 2]]
33 | value_cnn_filters: [[8, [4, 4], 2], [16, [4, 4], 2], [32, [4, 4], 2]]
34 | value_cnn_compression: 128
35 | cnn_compression: 32
36 | pre_gnn_mlp: [64, 128, 32]
37 | gp_kernel_size: 16
38 | graph_aggregation: sum
39 | relative: true
40 | activation: relu
41 | freeze_coop: False
42 | freeze_greedy: False
43 | freeze_coop_value: False
44 | freeze_greedy_value: False
45 | cnn_residual: False
46 | agent_split: 1
47 | greedy_mse_fac: 0.0
48 | env_config:
49 | world_shape: [24, 24]
50 | state_size: 16
51 | collapse_state: False
52 | termination_no_new_coverage: 10
53 | max_episode_len: 345 # 24*24*0.6
54 | n_agents: [1, 5]
55 | disabled_teams_step: [True, False]
56 | disabled_teams_comms: [True, False]
57 | min_coverable_area_fraction: 0.6
58 | map_mode: random
59 | reward_annealing: 0.0
60 | communication_range: 16.0
61 | ensure_connectivity: True
62 | reward_type: semi_cooperative #semi_cooperative/cooperative
63 | episode_termination: early # early/fixed/default
64 | operation_mode: coop_only
65 | evaluation_num_workers: 1
66 | evaluation_interval: 1
67 | evaluation_num_episodes: 10
68 | evaluation_config:
69 | env_config:
70 | termination_no_new_coverage: -1
71 | max_episode_len: 345 # 24*24*0.6
72 | episode_termination: default
73 | operation_mode: all
74 | ensure_connectivity: False
75 | logger_config:
76 | wandb:
77 | project: adv_paper
78 | #project: vaegp_0920
79 | group: revised_gp
80 | api_key_file: "./wandb_api_key_file"
81 | alternative_config:
82 | self_interested:
83 | # adversarial case in co-training
84 | evaluation_num_workers: 1
85 | num_workers: 7
86 | num_envs_per_worker: 64
87 | rollout_fragment_length: 100
88 | num_gpus_per_worker: 0.0625
89 | num_gpus: 0.5
90 | env_config:
91 | operation_mode: greedy_only
92 | disabled_teams_step: [False, False]
93 | disabled_teams_comms: [False, False]
94 | n_agents: [1, 5]
95 | model:
96 | custom_model_config:
97 | freeze_coop: True
98 | freeze_greedy: False
99 | adversarial:
100 | evaluation_num_workers: 1
101 | num_workers: 7
102 | num_envs_per_worker: 64
103 | rollout_fragment_length: 100
104 | num_gpus_per_worker: 0.0625
105 | num_gpus: 0.5
106 |
107 | env_config:
108 | operation_mode: adversary_only
109 | disabled_teams_step: [False, False]
110 | disabled_teams_comms: [False, False]
111 | termination_no_new_coverage: -1
112 | max_episode_len: 173 # 24*24*0.6
113 | episode_termination: default
114 | model:
115 | custom_model_config:
116 | freeze_coop: True
117 | freeze_greedy: False
118 | re_adapt:
119 | env_config:
120 | operation_mode: coop_only
121 | disabled_teams_step: [False, False]
122 | disabled_teams_comms: [False, False]
123 | model:
124 | custom_model_config:
125 | freeze_coop: False
126 | freeze_greedy: True
127 | adversarial_abundance:
128 | # adversarial case in co-training
129 | env_config:
130 | #map_mode: random_teams_far
131 | map_mode: split_half_fixed_block
132 | #map_mode: split_half_fixed_block_same_side
133 | communication_range: 8.0
134 | model:
135 | custom_model_config:
136 | graph_tabs: 3
137 | logger_config:
138 | wandb:
139 | project: vaegp_0920
140 |
141 |
--------------------------------------------------------------------------------
/adversarial_comms/config/coverage_split.yaml:
--------------------------------------------------------------------------------
1 | framework: torch
2 | env: coverage
3 | lambda: 0.95
4 | kl_coeff: 0.5
5 | kl_target: 0.01
6 | clip_rewards: True
7 | clip_param: 0.2
8 | vf_clip_param: 250.0
9 | vf_share_layers: False
10 | vf_loss_coeff: 1.0e-4
11 | entropy_coeff: 0.01
12 | train_batch_size: 5000
13 | rollout_fragment_length: 100
14 | sgd_minibatch_size: 1000
15 | num_sgd_iter: 5
16 | num_workers: 16
17 | num_envs_per_worker: 8
18 | lr: 5.0e-4
19 | gamma: 0.9
20 | batch_mode: truncate_episodes
21 | observation_filter: NoFilter
22 | num_gpus: 1
23 | model:
24 | custom_model: adversarial
25 | custom_action_dist: hom_multi_action
26 | custom_model_config:
27 | graph_layers: 1
28 | graph_tabs: 2
29 | graph_edge_features: 1
30 |
31 | # 16
32 | graph_features: 32
33 | cnn_filters: [[8, [4, 4], 2], [16, [4, 4], 2], [32, [3, 3], 2]]
34 | value_cnn_filters: [[8, [4, 4], 2], [16, [4, 4], 2], [32, [4, 4], 2]]
35 | value_cnn_compression: 128
36 | cnn_compression: 32
37 |
38 | relative: true
39 | activation: relu
40 | freeze_coop: False
41 | freeze_greedy: False
42 | freeze_coop_value: False
43 | freeze_greedy_value: False
44 | cnn_residual: False
45 | agent_split: 1
46 | env_config:
47 | world_shape: [24, 24]
48 | state_size: 16
49 | collapse_state: False
50 | termination_no_new_coverage: 10
51 | max_episode_len: 288 # (24*24)/2
52 | n_agents: [1, 5]
53 | disabled_teams_step: [True, False]
54 | disabled_teams_comms: [True, False]
55 | map_mode: split_half_fixed
56 | reward_annealing: 0.0
57 | communication_range: 16.0
58 | ensure_connectivity: True
59 | reward_type: split_right
60 | episode_termination: early_right # early/fixed/default
61 | operation_mode: coop_only
62 | agents:
63 | coverage_radius: 1
64 | visibility_distance: 0
65 | map_update_radius: 100
66 | relative_coord_frame: True
67 | evaluation_num_workers: 2
68 | evaluation_interval: 1
69 | evaluation_num_episodes: 10
70 | evaluation_config:
71 | env_config:
72 | termination_no_new_coverage: -1
73 | max_episode_len: 288 # (24*24)/2
74 | episode_termination: default
75 | operation_mode: all
76 | ensure_connectivity: False
77 | alternative_config:
78 | self_interested:
79 | env_config:
80 | operation_mode: greedy_only
81 | disabled_teams_step: [False, False]
82 | disabled_teams_comms: [False, False]
83 | model:
84 | custom_model_config:
85 | freeze_coop: True
86 | freeze_greedy: False
87 | re_adapt:
88 | env_config:
89 | operation_mode: coop_only
90 | disabled_teams_step: [False, False]
91 | disabled_teams_comms: [False, False]
92 | model:
93 | custom_model_config:
94 | freeze_coop: False
95 | freeze_greedy: True
96 |
--------------------------------------------------------------------------------
/adversarial_comms/config/path_planning.yaml:
--------------------------------------------------------------------------------
1 | framework: torch
2 | env: path_planning
3 | lambda: 0.95
4 | kl_coeff: 0.5
5 | kl_target: 0.01
6 | clip_rewards: True
7 | clip_param: 0.2
8 | vf_clip_param: 250.0
9 | vf_share_layers: False
10 | vf_loss_coeff: 1.0e-4
11 | entropy_coeff: 0.01
12 | train_batch_size: 5000
13 | rollout_fragment_length: 100
14 | sgd_minibatch_size: 1000
15 | num_sgd_iter: 5
16 | num_workers: 7
17 | num_envs_per_worker: 16
18 | lr: 4.0e-4
19 | gamma: 0.99
20 | batch_mode: complete_episodes
21 | observation_filter: NoFilter
22 | num_gpus: 1.0
23 | model:
24 | custom_model: adversarial
25 | custom_action_dist: hom_multi_action
26 | custom_model_config:
27 | graph_layers: 1
28 | graph_tabs: 2
29 | graph_edge_features: 1
30 | graph_features: 128
31 | cnn_filters: [[8, [4, 4], 2], [16, [4, 4], 2], [32, [3, 3], 2]]
32 | value_cnn_filters: [[16, [4, 4], 2], [32, [4, 4], 2]]
33 | value_cnn_compression: 128
34 | cnn_compression: 32
35 | gp_kernel_size: 16
36 | graph_aggregation: sum
37 | activation: relu
38 | freeze_coop: False
39 | freeze_greedy: False
40 | freeze_coop_value: False
41 | freeze_greedy_value: False
42 | cnn_residual: False
43 | agent_split: 1
44 | env_config:
45 | world_shape: [12, 12]
46 | state_size: 16
47 | max_episode_len: 50
48 | n_agents: [1, 15]
49 | disabled_teams_step: [True, False]
50 | disabled_teams_comms: [True, False]
51 | communication_range: 5.0
52 | ensure_connectivity: True
53 | reward_type: coop_only
54 | world_mode: warehouse
55 | agents:
56 | visibility_distance: 0
57 | relative_coord_frame: True
58 | evaluation_num_workers: 1
59 | evaluation_interval: 1
60 | evaluation_num_episodes: 10
61 | evaluation_config:
62 | env_config:
63 | reward_type: local
64 | alternative_config:
65 | self_interested:
66 | env_config:
67 | reward_type: greedy_only
68 | disabled_teams_step: [False, False]
69 | disabled_teams_comms: [False, False]
70 | model:
71 | custom_model_config:
72 | freeze_coop: True
73 | freeze_greedy: False
74 | evaluation_config:
75 | env_config:
76 | reward_type: local
77 | re_adapt:
78 | env_config:
79 | reward_type: coop_only
80 | disabled_teams_step: [False, False]
81 | disabled_teams_comms: [False, False]
82 | model:
83 | custom_model_config:
84 | freeze_coop: False
85 | freeze_greedy: True
86 | evaluation_config:
87 | env_config:
88 | reward_type: local
89 |
--------------------------------------------------------------------------------
/adversarial_comms/environments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proroklab/adversarial_comms/889a41e239958bd519365c6b0c143469cd27e49f/adversarial_comms/environments/__init__.py
--------------------------------------------------------------------------------
/adversarial_comms/environments/coverage.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | import matplotlib.pyplot as plt
4 | import matplotlib.patches as patches
5 | import gym
6 | from gym import spaces
7 | from gym.utils import seeding, EzPickle
8 | from matplotlib import colors
9 | from functools import partial
10 | from enum import Enum
11 | import copy
12 |
13 | from ray.rllib.env.multi_agent_env import MultiAgentEnv
14 |
15 | # https://bair.berkeley.edu/blog/2018/12/12/rllib/
16 |
17 | DEFAULT_OPTIONS = {
18 | 'world_shape': [24, 24],
19 | 'state_size': 48,
20 | 'collapse_state': False,
21 | 'termination_no_new_coverage': 10,
22 | 'max_episode_len': -1,
23 | "min_coverable_area_fraction": 0.6,
24 | "map_mode": "random",
25 | "n_agents": [5],
26 | "disabled_teams_step": [False],
27 | "disabled_teams_comms": [False],
28 | 'communication_range': 8.0,
29 | 'one_agent_per_cell': False,
30 | 'ensure_connectivity': True,
31 | 'reward_type': 'semi_cooperative',
32 | #"operation_mode": 'all', # greedy_only, coop_only, don't default for now
33 | 'episode_termination': 'early',
34 | 'agent_observability_radius': None,
35 | }
36 |
37 | X = 1
38 | Y = 0
39 |
40 | class Dir(Enum):
41 | RIGHT = 0
42 | LEFT = 1
43 | UP = 2
44 | DOWN = 3
45 |
46 | class WorldMap():
47 | def __init__(self, random_state, shape, min_coverable_area_fraction):
48 | self.shape = tuple(shape)
49 | self.min_coverable_area_fraction = min_coverable_area_fraction
50 | self.reset(random_state)
51 |
52 | def reset(self, random_state, mode="random"):
53 | self.coverage = np.zeros(self.shape, dtype=np.int)
54 | if mode == "random":
55 | if self.min_coverable_area_fraction == 1.0:
56 | self.map = np.zeros(self.shape, dtype=np.uint8)
57 | else:
58 | self.map = np.ones(self.shape, dtype=np.uint8)
59 | p = np.array([random_state.randint(0, self.shape[c]) for c in [Y, X]])
60 | while self.get_coverable_area_faction() < self.min_coverable_area_fraction:
61 | d_p = np.array([[0, 1], [0, -1], [-1, 0], [1, 0]][random_state.randint(0, 4)])#*random_state.randint(1, 5)
62 | p_new = np.clip(p + d_p, [0,0], np.array(self.shape)-1)
63 | self.map[min(p[Y],p_new[Y]):max(p[Y],p_new[Y])+1, min(p[X],p_new[X]):max(p[X],p_new[X])+1] = 0
64 | #print(min(p[Y],p_new[Y]),max(p[Y],p_new[Y])+1, min(p[X],p_new[X]),max(p[X],p_new[X])+1, np.sum(self.map))
65 | p = p_new
66 | elif mode == "split_half_fixed" or mode == "split_half_fixed_block" or mode == "split_half_fixed_block_same_side":
67 | self.map = np.zeros(self.shape, dtype=np.uint8)
68 | self.map[:, int(self.shape[X]/2)] = 1
69 | if mode == "split_half_fixed":
70 | self.map[int(self.shape[Y]/2), int(self.shape[X]/2)] = 0
71 |
72 | def get_coverable_area_faction(self):
73 | coverable_area = ~(self.map > 0)
74 | return np.sum(coverable_area)/(self.map.shape[X]*self.map.shape[Y])
75 |
76 | def get_coverable_area(self):
77 | coverable_area = ~(self.map>0)
78 | return np.sum(coverable_area)
79 |
80 | def get_covered_area(self):
81 | coverable_area = ~(self.map>0)
82 | return np.sum((self.coverage > 0) & coverable_area)
83 |
84 | def get_coverage_fraction(self):
85 | coverable_area = ~(self.map>0)
86 | covered_area = (self.coverage > 0) & coverable_area
87 | return np.sum(covered_area)/np.sum(coverable_area)
88 |
89 | class Action(Enum):
90 | NOP = 0
91 | MOVE_RIGHT = 1
92 | MOVE_LEFT = 2
93 | MOVE_UP = 3
94 | MOVE_DOWN = 4
95 |
96 | class Robot():
97 | def __init__(self,
98 | index,
99 | random_state,
100 | world,
101 | state_size,
102 | collapse_state,
103 | termination_no_new_coverage,
104 | agent_observability_radius,
105 | one_agent_per_cell):
106 | self.index = index
107 | self.world = world
108 | self.termination_no_new_coverage = termination_no_new_coverage
109 | self.state_size = state_size
110 | self.collapse_state = collapse_state
111 | self.initialized_rendering = False
112 | self.agent_observability_radius = agent_observability_radius
113 | self.one_agent_per_cell = one_agent_per_cell
114 |         self.pose = np.array([-1, -1]) # assign a negative pose so that agents are not placed at the same initial position during reset
115 | self.reset(random_state)
116 |
117 | def reset(self, random_state, pose_mean=np.array([0, 0]), pose_var=1):
118 | def random_pos(var):
119 | return np.array([
120 | int(np.clip(random_state.normal(loc=pose_mean[c], scale=var), 0, self.world.map.shape[c]-1))
121 | for c in [Y, X]])
122 |
123 | current_pose_var = pose_var
124 | self.pose = random_pos(current_pose_var)
125 | self.prev_pose = self.pose.copy()
126 | while self.world.map.map[self.pose[Y], self.pose[X]] == 1 or (self.world.is_occupied(self.pose, self) and self.one_agent_per_cell):
127 | self.pose = random_pos(current_pose_var)
128 | current_pose_var += 0.1
129 |
130 | self.coverage = np.zeros(self.world.map.shape, dtype=np.bool)
131 | self.state = None
132 | self.no_new_coverage_steps = 0
133 | self.reward = 0
134 |
135 | def step(self, action):
136 | action = Action(action)
137 |
138 | delta_pose = {
139 | Action.MOVE_RIGHT: [ 0, 1],
140 | Action.MOVE_LEFT: [ 0, -1],
141 | Action.MOVE_UP: [-1, 0],
142 | Action.MOVE_DOWN: [ 1, 0],
143 | Action.NOP: [ 0, 0]
144 | }[action]
145 |
146 | is_valid_pose = lambda p: all([p[c] >= 0 and p[c] < self.world.map.shape[c] for c in [Y, X]])
147 | is_obstacle = lambda p: self.world.map.map[p[Y]][p[X]] == 1
148 |
149 | self.prev_pose = self.pose.copy()
150 | desired_pos = self.pose + delta_pose
151 | if is_valid_pose(desired_pos) and (not self.world.is_occupied(desired_pos, self) or not self.one_agent_per_cell) and not is_obstacle(desired_pos):
152 | self.pose = desired_pos
153 |
154 | if self.world.map.coverage[self.pose[Y], self.pose[X]] == 0:
155 | self.world.map.coverage[self.pose[Y], self.pose[X]] = self.index
156 | self.reward = 1
157 | self.no_new_coverage_steps = 0
158 | else:
159 | self.reward = 0
160 | self.no_new_coverage_steps += 1
161 |
162 | self.coverage[self.pose[Y], self.pose[X]] = True
163 | #self.reward -= 1 # subtract each time step
164 |
165 | def update_state(self):
166 | coverage = self.coverage.copy().astype(np.int)
167 | if self.collapse_state:
168 | yy, xx = np.mgrid[:self.coverage.shape[Y], :self.coverage.shape[X]]
169 | for (cx, cy) in zip(xx[self.coverage]-self.pose[X], yy[self.coverage]-self.pose[Y]):
170 | if abs(cx) < self.state_size/2 and abs(cy) < self.state_size/2:
171 | continue
172 | u = max(abs(cx), abs(cy))
173 | p_sq = np.round(self.pose + int(self.state_size/2)*np.array([cy/u, cx/u])).astype(np.int)
174 | coverage[p_sq[Y], p_sq[X]] += 1
175 |
176 | state_output_shape = np.array([self.state_size]*2, dtype=int)
177 | state_data = [
178 | self.to_coordinate_frame(self.world.map.map, state_output_shape, fill=1),
179 | self.to_coordinate_frame(coverage, state_output_shape, fill=0)
180 | ]
181 | if self.agent_observability_radius is not None:
182 | pose_map = np.zeros(self.world.map.shape, dtype=np.uint8)
183 |
184 | for team in self.world.teams.values():
185 | for r in team:
186 | if not r is self and np.sum((r.pose - self.pose)**2) < self.agent_observability_radius**2:
187 | pose_map[r.pose[Y], r.pose[X]] = 2
188 | pose_map[self.pose[Y], self.pose[X]] = 1
189 | state_data.append(self.to_coordinate_frame(pose_map, state_output_shape, fill=0))
190 | self.state = np.stack(state_data, axis=-1).astype(np.uint8)
191 |
192 | done = self.no_new_coverage_steps == self.termination_no_new_coverage
193 | return self.state, self.reward, done, {}
194 |
195 | def to_abs_frame(self, data):
196 | half_state_size = int(self.state_size / 2)
197 | return np.roll(data, self.pose, axis=(0, 1))[half_state_size:, half_state_size:]
198 |
199 | def to_coordinate_frame(self, m, output_shape, fill=0):
200 | half_out_shape = np.array(output_shape/2, dtype=np.int)
201 | padded = np.pad(m,([half_out_shape[Y]]*2,[half_out_shape[X]]*2), mode='constant', constant_values=fill)
202 |         return padded[self.pose[Y]:self.pose[Y] + output_shape[Y], self.pose[X]:self.pose[X] + output_shape[X]]
203 |
204 | class CoverageEnv(gym.Env, EzPickle):
205 | def __init__(self, env_config):
206 | EzPickle.__init__(self)
207 | self.seed()
208 |
209 | self.cfg = copy.deepcopy(DEFAULT_OPTIONS)
210 | self.cfg.update(env_config)
211 |
212 | self.fig = None
213 | self.map_colormap = colors.ListedColormap(['white', 'black', 'gray']) # free, obstacle, unknown
214 |
215 | hsv = np.ones((self.cfg['n_agents'][1], 3))
216 | hsv[..., 0] = np.linspace(160/360, 250/360, self.cfg['n_agents'][1] + 1)[:-1]
217 | self.teams_agents_color = {
218 | 0: [(1, 0, 0)],
219 | 1: colors.hsv_to_rgb(hsv)
220 | }
221 |
222 | '''
223 | hsv = np.ones((sum(self.cfg['n_agents']), 3))
224 | hsv[..., 0] = np.linspace(0, 1, sum(self.cfg['n_agents']) + 1)[:-1]
225 | self.teams_agents_color = {}
226 | current_index = 0
227 | for i, n_agents in enumerate(self.cfg['n_agents']):
228 | self.teams_agents_color[i] = colors.hsv_to_rgb(hsv[current_index:current_index+n_agents])
229 | current_index += n_agents
230 | '''
231 |
232 | hsv = np.ones((len(self.cfg['n_agents']), 3))
233 | hsv[..., 0] = np.linspace(0, 1, len(self.cfg['n_agents']) + 1)[:-1]
234 | self.teams_colors = ['r', 'b'] #colors.hsv_to_rgb(hsv)
235 |
236 | n_all_agents = sum(self.cfg['n_agents'])
237 | self.observation_space = spaces.Dict({
238 | 'agents': spaces.Tuple((
239 | spaces.Dict({
240 | 'map': spaces.Box(0, np.inf, shape=(self.cfg['state_size'], self.cfg['state_size'], 2 if self.cfg['agent_observability_radius'] is None else 3)),
241 | 'pos': spaces.Box(low=np.array([0,0]), high=np.array([self.cfg['world_shape'][Y], self.cfg['world_shape'][X]]), dtype=np.int),
242 | }),
243 |         )*n_all_agents), # Do not add this as an additional dimension of map and pos since a flat tuple is easier to handle in the model
244 | 'gso': spaces.Box(-np.inf, np.inf, shape=(n_all_agents, n_all_agents)),
245 | 'state': spaces.Box(low=0, high=2, shape=self.cfg['world_shape']+[2+len(self.cfg['n_agents'])]),
246 | })
247 | self.action_space = spaces.Tuple((spaces.Discrete(5),)*sum(self.cfg['n_agents']))
248 |
249 | self.map = WorldMap(self.world_random_state, self.cfg['world_shape'], self.cfg['min_coverable_area_fraction'])
250 | self.teams = {}
251 | agent_index = 1
252 | for i, n_agents in enumerate(self.cfg['n_agents']):
253 | self.teams[i] = []
254 | for j in range(n_agents):
255 | self.teams[i].append(
256 | Robot(
257 | agent_index,
258 | self.agent_random_state,
259 | self,
260 | self.cfg['state_size'],
261 | self.cfg['collapse_state'],
262 | self.cfg['termination_no_new_coverage'],
263 | self.cfg['agent_observability_radius'],
264 | self.cfg['one_agent_per_cell']
265 | )
266 | )
267 | agent_index += 1
268 |
269 | self.reset()
270 |
271 | def is_occupied(self, p, agent_ignore=None):
272 | for team_key, team in self.teams.items():
273 | if self.cfg['disabled_teams_step'][team_key]:
274 | continue
275 | for o in team:
276 | if o is agent_ignore:
277 | continue
278 | if p[X] == o.pose[X] and p[Y] == o.pose[Y]:
279 | return True
280 | return False
281 |
282 | def seed(self, seed=None):
283 | self.agent_random_state, seed_agents = seeding.np_random(seed)
284 | self.world_random_state, seed_world = seeding.np_random(seed)
285 | return [seed_agents, seed_world]
286 |
287 | def reset(self):
288 | self.dones = {key: [False for _ in team] for key, team in self.teams.items()}
289 | self.timestep = 0
290 | self.map.reset(self.world_random_state, self.cfg['map_mode'])
291 |
292 | def random_pos_seed(team_key):
293 | rnd = self.agent_random_state
294 | if self.cfg['map_mode'] == "random":
295 | return np.array([rnd.randint(0, self.map.shape[c]) for c in [Y, X]])
296 | if self.cfg['map_mode'] == "split_half_fixed":
297 | return np.array([
298 | rnd.randint(0, self.map.shape[Y]),
299 | rnd.randint(0, int(self.map.shape[X]/3))
300 | ])
301 | elif self.cfg['map_mode'] == "split_half_fixed_block":
302 | if team_key == 0:
303 | return np.array([
304 | rnd.randint(0, self.map.shape[Y]),
305 | rnd.randint(0, int(self.map.shape[X] / 3))
306 | ])
307 | else:
308 | return np.array([
309 | rnd.randint(0, self.map.shape[Y]),
310 | rnd.randint(2*int(self.map.shape[X] / 3), self.map.shape[X])
311 | ])
312 | elif self.cfg['map_mode'] == "split_half_fixed_block_same_side":
313 | return np.array([
314 | rnd.randint(0, self.map.shape[Y]),
315 | rnd.randint(2*int(self.map.shape[X] / 3), self.map.shape[X])
316 | ])
317 |
318 | pose_seed = None
319 | for team_key, team in self.teams.items():
320 | if not self.cfg['map_mode'] == "random" or pose_seed is None:
321 | # shared pose_seed if random map mode
322 | pose_seed = random_pos_seed(team_key)
323 | while self.map.map[pose_seed[Y], pose_seed[X]] == 1:
324 | pose_seed = random_pos_seed(team_key)
325 | for r in team:
326 | r.reset(self.agent_random_state, pose_mean=pose_seed, pose_var=1)
327 | return self.step([Action.NOP]*sum(self.cfg['n_agents']))[0]
328 |
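    # compute_gso builds the graph shift operator exposed in the observation under 'gso':
    # pairwise squared distances between agents are thresholded by 'communication_range'
    # to obtain an adjacency matrix A (the range is grown until the non-disabled agents
    # form a connected graph if 'ensure_connectivity' is set), links of teams disabled
    # via 'disabled_teams_comms' are masked out, and the result is the symmetrically
    # normalized adjacency D^(-1/2) A D^(-1/2) with a zeroed diagonal.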
329 | def compute_gso(self, team_id=0):
330 | own_team_agents = [(agent, self.cfg['disabled_teams_comms'][team_id]) for agent in self.teams[team_id]]
331 | other_agents = [(agent, self.cfg['disabled_teams_comms'][other_team_id]) for other_team_id, team in self.teams.items() for agent in team if not team_id == other_team_id]
332 |
333 |         all_agents = own_team_agents + other_agents # order is important since the model concatenates the data in the same order
334 | dists = np.zeros((len(all_agents), len(all_agents)))
335 | done_matrix = np.zeros((len(all_agents), len(all_agents)), dtype=np.bool)
336 | for agent_y in range(len(all_agents)):
337 | for agent_x in range(agent_y):
338 | dst = np.sum(np.array(all_agents[agent_x][0].pose - all_agents[agent_y][0].pose)**2)
339 | dists[agent_y, agent_x] = dst
340 | dists[agent_x, agent_y] = dst
341 |
342 | d = all_agents[agent_x][1] or all_agents[agent_y][1]
343 | done_matrix[agent_y, agent_x] = d
344 | done_matrix[agent_x, agent_y] = d
345 |
346 | current_dist = self.cfg['communication_range']
347 | A = dists < (current_dist**2)
348 | active_row = ~np.array([a[1] for a in all_agents])
349 | if self.cfg['ensure_connectivity']:
350 | def is_connected(m):
351 | def walk_dfs(m, index):
352 | for i in range(len(m)):
353 | if m[index][i]:
354 | m[index][i] = False
355 | walk_dfs(m, i)
356 |
357 | m_c = m.copy()
358 | walk_dfs(m_c, 0)
359 | return not np.any(m_c.flatten())
360 |
361 |             # treat disabled (done) teams as already connected so that they are not pulled in when increasing the communication range
362 | while not is_connected(A[active_row][:, active_row]):
363 | current_dist *= 1.1
364 | A = (dists < current_dist**2)
365 |
366 | # Mask out done agents
367 | A = (A & ~done_matrix).astype(np.int)
368 |
369 | # normalization: refer https://github.com/QingbiaoLi/GraphNets/blob/master/Flocking/Utils/dataTools.py#L601
370 | np.fill_diagonal(A, 0)
371 | deg = np.sum(A, axis = 1) # nNodes (degree vector)
372 | D = np.diag(deg)
373 | Dp = np.diag(np.nan_to_num(np.power(deg, -1/2)))
374 | L = A # D-A
375 | gso = Dp @ L @ Dp
376 | return gso
377 |
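    # step applies one action per agent (teams disabled via 'disabled_teams_step' are
    # skipped), updates the per-agent coverage rewards, derives the done flags according
    # to 'episode_termination', redistributes rewards according to 'operation_mode'
    # (e.g. in "adversary_only" the adversarial team receives the negated sum of all
    # rewards) and returns the observation dict, the summed reward of all enabled teams,
    # the done flag and an info dict.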
378 | def step(self, actions):
379 | self.timestep += 1
380 | action_index = 0
381 | for i, team in enumerate(self.teams.values()):
382 | for agent in team:
383 | if not self.cfg['disabled_teams_step'][i]:
384 | agent.step(actions[action_index])
385 | action_index += 1
386 |
387 | states, rewards = {}, {}
388 | for team_key, team in self.teams.items():
389 | states[team_key] = []
390 | rewards[team_key] = {}
391 | for i, agent in enumerate(team):
392 | state, reward, done, _ = agent.update_state()
393 | states[team_key].append(state)
394 | rewards[team_key][i] = reward
395 | if done:
396 | self.dones[team_key][i] = True
397 | dones = {}
398 | world_done = self.timestep == self.cfg['max_episode_len'] or self.map.get_coverage_fraction() == 1.0
399 | for key in self.teams.keys():
400 | dones[key] = world_done
401 | if self.cfg['episode_termination'] == 'early' or self.cfg['episode_termination'] == 'early_any':
402 | dones[key] = world_done or any(self.dones[key])
403 | elif self.cfg['episode_termination'] == 'early_all':
404 | dones[key] = world_done or all(self.dones[key])
405 | elif self.cfg['episode_termination'] == 'early_right':
406 |                 # terminate early only if at least one agent has reached the right side of the environment;
407 |                 # until then, fall back to the fixed episode length (world_done)
408 | agent_is_in_right_half = False
409 | for agent in self.teams[key]:
410 | if agent.pose[X] < self.cfg['world_shape'][X]/2:
411 | agent.no_new_coverage_steps = 0
412 | else:
413 | agent_is_in_right_half = True
414 |
415 | if agent_is_in_right_half:
416 | dones[key] = any(self.dones[key])
417 | else:
418 | dones[key] = world_done
419 | elif self.cfg['episode_termination'] == 'default':
420 | pass
421 | else:
422 | raise NotImplementedError("Unknown termination mode", self.cfg['episode_termination'])
423 |
424 | if self.cfg['operation_mode'] == "all":
425 | pass
426 | elif self.cfg['operation_mode'] == "greedy_only" or self.cfg['operation_mode'] == "adversary_only":
427 | dones[1] = dones[0]
428 | elif self.cfg['operation_mode'] == "coop_only":
429 | dones[0] = dones[1]
430 | else:
431 | raise NotImplementedError("Unknown operation_mode")
432 | done = any(dones.values()) # Currently we cannot run teams independently, all have to stop at the same time
433 |
434 | pose_map = np.zeros(self.map.shape + (len(self.teams),), dtype=np.uint8)
435 | for i, team in enumerate(self.teams.values()):
436 | for r in team:
437 | pose_map[r.pose[Y], r.pose[X], i] = 1
438 | global_state = np.concatenate([np.stack([self.map.map, self.map.coverage > 0], axis=-1), pose_map], axis=-1)
439 | state = {
440 | 'agents': tuple([{
441 | 'map': states[key][agent_i],
442 | 'pos': self.teams[key][agent_i].pose
443 | } for key in self.teams.keys() for agent_i in range(self.cfg['n_agents'][key])]),
444 | 'gso': self.compute_gso(0),
445 | 'state': global_state
446 | }
447 |
448 | for key in self.teams.keys():
449 | if self.cfg['reward_type'] == 'semi_cooperative':
450 | pass
451 | elif self.cfg['reward_type'] == 'split_right':
452 | for agent_key in rewards[key].keys():
453 | if self.teams[key][agent_key].pose[X] < self.cfg['world_shape'][X]/2:
454 | rewards[key][agent_key] = 0
455 | else:
456 | raise NotImplementedError("Unknown reward type", self.cfg['reward_type'])
457 | if self.cfg['operation_mode'] == "all":
458 | pass
459 | elif self.cfg['operation_mode'] == "greedy_only":
460 |                 # assign the summed reward of the greedy team to every cooperative agent
461 | rewards[1] = {agent_key: sum(rewards[0].values()) for agent_key in rewards[1].keys()}
462 | elif self.cfg['operation_mode'] == "adversary_only":
463 |                 # The greedy agent's reward is the negative sum of all agents' rewards
464 | all_negative = -sum([sum(team_rewards.values()) for team_rewards in rewards.values()])
465 | rewards[0] = {agent_key: all_negative for agent_key in rewards[0].keys()}
466 |
467 |                 # assign the summed reward of the greedy team to every cooperative agent
468 | rewards[1] = {agent_key: sum(rewards[0].values()) for agent_key in rewards[1].keys()}
469 | elif self.cfg['operation_mode'] == "coop_only":
470 |                 # assign the summed reward of the cooperative team to every greedy agent
471 | rewards[0] = {agent_key: sum(rewards[1].values()) for agent_key in rewards[0].keys()}
472 | else:
473 | raise NotImplementedError("Unknown operation_mode")
474 |
475 | flattened_rewards = {}
476 | agent_index = 0
477 | for key in self.teams.keys():
478 | for r in rewards[key].values():
479 | flattened_rewards[agent_index] = r
480 | agent_index += 1
481 | info = {
482 | 'current_global_coverage': self.map.get_coverage_fraction(),
483 | 'coverable_area': self.map.get_coverable_area(),
484 | 'rewards_teams': rewards,
485 | 'rewards': flattened_rewards
486 | }
487 | return state, sum([sum(t.values()) for i, t in enumerate(rewards.values()) if not self.cfg['disabled_teams_step'][i]]), done, info
488 |
489 | def clear_patches(self, ax):
490 | [p.remove() for p in reversed(ax.patches)]
491 | [t.remove() for t in reversed(ax.texts)]
492 |
493 | def render_adjacency(self, A, team_id, ax, color='b', stepsize=1.0):
494 | A = A.copy()
495 | own_team_agents = [agent for agent in self.teams[team_id]]
496 | other_agents = [agent for other_team_id, team in self.teams.items() for agent in team if not team_id == other_team_id]
497 | all_agents = own_team_agents + other_agents
498 | for agent_id, agent in enumerate(all_agents):
499 | for connected_agent_id in np.arange(len(A)):
500 | if A[agent_id][connected_agent_id] > 0:
501 | current_agent_pose = agent.prev_pose + (agent.pose - agent.prev_pose) * stepsize
502 | other_agent = all_agents[connected_agent_id]
503 | other_agent_pose = other_agent.prev_pose + (other_agent.pose - other_agent.prev_pose) * stepsize
504 | ax.add_patch(patches.ConnectionPatch(
505 | [current_agent_pose[X], current_agent_pose[Y]],
506 | [other_agent_pose[X], other_agent_pose[Y]],
507 | "data", edgecolor='g', facecolor='none', lw=1, ls=":", alpha=0.3
508 | ))
509 |
510 | A[connected_agent_id][agent_id] = 0 # don't draw same connection again
511 |
512 | def render_global_coverages(self, ax):
513 | if not hasattr(self, 'im_cov_global'):
514 | self.im_cov_global = ax.imshow(np.zeros(self.map.shape), vmin=0, vmax=100)
515 | all_team_colors = [(0, 0, 0, 0)] + [tuple(list(c) + [0.5]) for team_colors in self.teams_agents_color.values() for c in team_colors]
516 | coverage = self.map.coverage.copy()
517 | if self.cfg['map_mode'] == 'split_half_fixed':
518 | # mark coverage on left side as gray
519 | color_index_left_side = len(all_team_colors)
520 | all_team_colors += [(0, 0, 0, 0.5)] # gray
521 | xx, _ = np.meshgrid(
522 | np.arange(0, coverage.shape[X], 1),
523 | np.arange(0, coverage.shape[Y], 1)
524 | )
525 | coverage[(xx < coverage.shape[X]/2) & (coverage > 0)] = color_index_left_side
526 |
527 | self.im_cov_global.set_data(colors.ListedColormap(all_team_colors)(coverage))
528 |
529 | def render_local_coverages(self, ax):
530 | if not hasattr(self, 'im_robots'):
531 | self.im_robots = {}
532 | for team_key, team in self.teams.items():
533 | if self.cfg['disabled_teams_step'][team_key]:
534 | continue
535 | self.im_robots[team_key] = []
536 | for _ in team:
537 | self.im_robots[team_key].append(ax.imshow(np.zeros(self.map.shape), vmin=0, vmax=1, alpha=0.5))
538 |
539 | self.im_map.set_data(self.map_colormap(self.map.map))
540 | for (team_key, team), team_colors in zip(self.teams.items(), self.teams_agents_color.values()):
541 | if self.cfg['disabled_teams_step'][team_key]:
542 | continue
543 | team_im = self.im_robots[team_key]
544 | for (agent_i, agent), color, im in zip(enumerate(team), team_colors, team_im):
545 | im.set_data(colors.ListedColormap([(0, 0, 0, 0), color])(agent.coverage))
546 |
547 | def render_overview(self, ax, stepsize=1.0):
548 | if not hasattr(self, 'im_map'):
549 | ax.set_xticks([])
550 | ax.set_yticks([])
551 | self.im_map = ax.imshow(np.zeros(self.map.shape), vmin=0, vmax=3)
552 |
553 | self.im_map.set_data(self.map_colormap(self.map.map))
554 | for (team_key, team), team_colors in zip(self.teams.items(), self.teams_agents_color.values()):
555 | if self.cfg['disabled_teams_step'][team_key]:
556 | continue
557 | for (agent_i, agent), color in zip(enumerate(team), team_colors):
558 | rect_size = 1
559 | pose_microstep = agent.prev_pose + (agent.pose - agent.prev_pose)*stepsize
560 | rect = patches.Rectangle((pose_microstep[1] - rect_size / 2, pose_microstep[0] - rect_size / 2), rect_size, rect_size,
561 | linewidth=1, edgecolor=self.teams_colors[team_key], facecolor='none')
562 | ax.add_patch(rect)
563 | #ax.text(agent.pose[1]+1, agent.pose[0], f"{agent_i}", color=self.teams_colors[team_key], clip_on=True)
564 |
565 | #last_reward = sum([r.reward for r in self.robots.values()])
566 | #ax.set_title(
567 | # f'Global coverage: {int(self.map.get_coverage_fraction()*100)}%\n'
568 | # #f'Last reward (r): {last_reward:.2f}'
569 | #)
570 |
571 | def render_connectivity(self, ax, agent_id, K):
572 | if K <= 1:
573 | return
574 |
575 | for connected_agent_id in np.arange(self.cfg['n_agents'])[self.A[agent_id] == 1]:
576 | current_agent_pose = self.robots[agent_id].pose
577 | connected_agent_d_pose = self.robots[connected_agent_id].pose - current_agent_pose
578 | ax.add_patch(patches.Arrow(
579 | current_agent_pose[X],
580 | current_agent_pose[Y],
581 | connected_agent_d_pose[X],
582 | connected_agent_d_pose[Y],
583 | edgecolor='b',
584 | facecolor='none'
585 | ))
586 | self.render_connectivity(ax, connected_agent_id, K-1)
587 |
588 | def render(self, mode='human', stepsize=1.0):
589 | if self.fig is None:
590 | plt.ion()
591 | self.fig = plt.figure(figsize=(3, 3))
592 | self.ax_overview = self.fig.add_subplot(1, 1, 1, aspect='equal')
593 |
594 | self.clear_patches(self.ax_overview)
595 | self.render_overview(self.ax_overview, stepsize)
596 | #self.render_local_coverages(self.ax_overview)
597 | self.render_global_coverages(self.ax_overview)
598 | A = self.compute_gso(0)
599 | self.render_adjacency(A, 0, self.ax_overview, stepsize=stepsize)
600 |
601 | self.fig.canvas.draw()
602 | self.fig.subplots_adjust(bottom=0, top=1, left=0, right=1)
603 | return self.fig
604 |
605 | class CoverageEnvExplAdv(CoverageEnv):
606 | def __init__(self, cfg):
607 | super().__init__(cfg)
608 |
609 | def render(self, interpreter_obs, mode='human', stepsize=1.0):
610 | if self.fig is None:
611 | plt.ion()
612 | self.fig = plt.figure(figsize=(6, 3))
613 | gs = self.fig.add_gridspec(ncols=2, nrows=1)
614 | gs.update(wspace=0, hspace=0)
615 | self.ax_overview = self.fig.add_subplot(gs[0])
616 | ax_expl = self.fig.add_subplot(gs[1])
617 | ax_expl.set_xticks([])
618 | ax_expl.set_yticks([])
619 | self.im_expl_cov = ax_expl.imshow(np.zeros((1, 1)), vmin=0, vmax=1)
620 | self.im_expl_map = ax_expl.imshow(np.zeros((1, 1)), vmin=0, vmax=1)
621 |
622 | self.clear_patches(self.ax_overview)
623 | self.render_overview(self.ax_overview, stepsize)
624 | self.render_global_coverages(self.ax_overview)
625 | A = self.compute_gso(0)
626 | self.render_adjacency(A, 0, self.ax_overview, stepsize=stepsize)
627 |
628 | adv_coverage = interpreter_obs[0][0]
629 | cmap_own_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), list(self.teams_agents_color[0][0])+[0.5]])
630 |         cmap_map = colors.ListedColormap([(0,0,0,0), (0,0,0,1)]) # transparent (free), black (obstacle)
631 | self.im_expl_cov.set_data(cmap_own_cov(self.teams[0][0].to_abs_frame(adv_coverage)))
632 | self.im_expl_map.set_data(cmap_map(self.map.map))
633 |
634 | self.fig.canvas.draw()
635 | self.fig.subplots_adjust(bottom=0, top=1, left=0, right=1)
636 | return self.fig
637 |
638 | class CoverageEnvAdvDec(CoverageEnv):
639 | def __init__(self, cfg):
640 | super().__init__(cfg)
641 |
642 | def render(self, interpreter_obs, mode='human', stepsize=1.0):
643 | if self.fig is None:
644 | plt.ion()
645 | self.fig = plt.figure(figsize=(6, 3))
646 | gs = self.fig.add_gridspec(ncols=2, nrows=1)
647 | gs.update(wspace=0, hspace=0)
648 | self.ax_overview = self.fig.add_subplot(gs[0])
649 | ax_expl = self.fig.add_subplot(gs[1])
650 | ax_expl.set_xticks([])
651 | ax_expl.set_yticks([])
652 | self.im_expl_cov = ax_expl.imshow(np.zeros((1, 1)), vmin=0, vmax=1)
653 | self.im_expl_map = ax_expl.imshow(np.zeros((1, 1)), vmin=0, vmax=1)
654 |
655 | self.clear_patches(self.ax_overview)
656 | self.render_overview(self.ax_overview, stepsize)
657 | self.render_global_coverages(self.ax_overview)
658 | A = self.compute_gso(0)
659 | self.render_adjacency(A, 0, self.ax_overview, stepsize=stepsize)
660 |
661 | cmap_own_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), list(self.teams_agents_color[0][0])+[0.5]])
662 |         cmap_map = colors.ListedColormap([(0,0,0,0), (0,0,0,1)]) # transparent (free), black (obstacle)
663 | self.im_expl_cov.set_data(cmap_own_cov(interpreter_obs[0][1]))
664 | #self.im_expl_map.set_data(cmap_map(self.map.map))
665 | self.im_expl_map.set_data(cmap_map(interpreter_obs[0][0]))
666 |
667 | self.fig.canvas.draw()
668 | self.fig.subplots_adjust(bottom=0, top=1, left=0, right=1)
669 | return self.fig
670 |
671 | class CoverageEnvSingleSaliency(CoverageEnv):
672 | def __init__(self, cfg):
673 | super().__init__(cfg)
674 |
675 | def render(self, interpreter_obs, interpr_index=0, mode='human', stepsize=1.0):
676 | if self.fig is None:
677 | plt.ion()
678 | self.fig = plt.figure(figsize=(6, 3))
679 | gs = self.fig.add_gridspec(ncols=3, nrows=1)
680 | gs.update(wspace=0, hspace=0)
681 | self.ax_overview = self.fig.add_subplot(gs[0])
682 | ax_expl_map = self.fig.add_subplot(gs[1])
683 | ax_expl_cov = self.fig.add_subplot(gs[2])
684 | ax_expl_map.set_xticks([])
685 | ax_expl_map.set_yticks([])
686 | ax_expl_cov.set_xticks([])
687 | ax_expl_cov.set_yticks([])
688 | self.im_expl_cov = ax_expl_cov.imshow(np.zeros((1, 1)), vmin=0, vmax=1)
689 | self.im_expl_map = ax_expl_map.imshow(np.zeros((1, 1)), vmin=0, vmax=1)
690 |
691 | self.clear_patches(self.ax_overview)
692 | self.render_overview(self.ax_overview, stepsize)
693 | self.render_global_coverages(self.ax_overview)
694 | A = self.compute_gso(0)
695 | self.render_adjacency(A, 0, self.ax_overview, stepsize=stepsize)
696 |
697 | #cmap_own_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), list(self.teams_agents_color[0][0])+[0.5]])
698 | #cmap_map = colors.ListedColormap([(0,0,0,0), (0,0,0,1)]) # free, obstacle, unknown
699 |
700 | saliency_limits = (np.min(interpreter_obs), np.max(interpreter_obs))
701 | self.im_expl_cov.set_clim(saliency_limits[0], saliency_limits[1])
702 | self.im_expl_map.set_clim(saliency_limits[0], saliency_limits[1])
703 | self.im_expl_cov.set_data(interpreter_obs[interpr_index][:, :, 1])
704 | self.im_expl_map.set_data(interpreter_obs[interpr_index][:, :, 0])
705 |
706 | #self.im_expl_cov.set_data(cmap_own_cov(interpreter_obs[interpr_index][1]))
707 | #self.im_expl_map.set_data(cmap_map(self.map.map))
708 | #self.im_expl_map.set_data(cmap_map(interpreter_obs[interpr_index][0]))
709 |
710 | self.fig.canvas.draw()
711 | self.fig.subplots_adjust(bottom=0, top=1, left=0, right=1)
712 | return self.fig
713 |
714 | class CoverageEnvSaliency(CoverageEnv):
715 | def __init__(self, cfg):
716 | super().__init__(cfg)
717 |
718 | def render(self, mode='human', saliency_obs=None, interpreter_obs=None):
719 | if self.fig is None:
720 | plt.ion()
721 | self.fig = plt.figure(constrained_layout=True, figsize=(16, 10))
722 | grid_spec = self.fig.add_gridspec(ncols=max(self.cfg['n_agents']) * 3,
723 | nrows=1 + 2 * len(self.cfg['n_agents']),
724 | height_ratios=[1] + [1, 1] * len(self.cfg['n_agents']))
725 |
726 | self.ax_overview = self.fig.add_subplot(grid_spec[0, :])
727 |
728 | self.ax_im_agent = {}
729 | for team_key, team in self.teams.items():
730 | self.ax_im_agent[team_key] = []
731 | for i in range(self.cfg['n_agents'][team_key]):
732 | self.ax_im_agent[team_key].append({})
733 | for j, col_id in enumerate(['map', 'coverage']):
734 | self.ax_im_agent[team_key][i][col_id] = {}
735 | for k, row_id in enumerate(['obs', 'sal']):
736 | ax = self.fig.add_subplot(grid_spec[j + 1 + team_key * 2, i * 3 + k])
737 | ax.set_xticks([])
738 | ax.set_yticks([])
739 | self.ax_im_agent[team_key][i][col_id][row_id] = {'ax': ax, 'im': None}
740 | self.ax_im_agent[team_key][i][col_id]['sal']['im'] = self.ax_im_agent[team_key][i][col_id]['sal']['ax'].imshow(np.zeros((1, 1)), vmin=-5, vmax=5)
741 | #self.ax_im_agent[team_key][i][col_id]['int']['im'] = self.ax_im_agent[team_key][i][col_id]['int']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=1)
742 | self.ax_im_agent[team_key][i]['map']['obs']['im'] = self.ax_im_agent[team_key][i]['map']['obs']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=3)
743 | self.ax_im_agent[team_key][i]['coverage']['obs']['im'] = self.ax_im_agent[team_key][i]['coverage']['obs']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=1, alpha=0.3)
744 |
745 | self.clear_patches(self.ax_overview)
746 | self.render_overview(self.ax_overview)
747 | A = self.compute_gso(0)
748 | self.render_adjacency(A, 0, self.ax_overview)
749 |
750 | if saliency_obs is not None:
751 | saliency_limits = (np.min(saliency_obs), np.max(saliency_obs))
752 | img_map_id = 0
753 | for team_key, team in self.teams.items():
754 | for i, robot in enumerate(team):
755 |
756 | self.ax_im_agent[team_key][i]['map']['obs']['im'].set_data(
757 | self.map_colormap(robot.to_abs_frame(robot.state[..., 0])))
758 | this_coverage_colormap = colors.ListedColormap([(0, 0, 0, 0), self.teams_agents_color[team_key][i]])
759 | self.ax_im_agent[team_key][i]['coverage']['obs']['im'].set_data(
760 | this_coverage_colormap(robot.to_abs_frame(robot.state[..., 1])))
761 |
762 | if saliency_obs is not None:
763 | self.ax_im_agent[team_key][i]['map']['sal']['im'].set_data(
764 | robot.to_abs_frame(saliency_obs[img_map_id][..., 0]))
765 | self.ax_im_agent[team_key][i]['map']['sal']['im'].set_clim(saliency_limits[0], saliency_limits[1])
766 | self.ax_im_agent[team_key][i]['coverage']['sal']['im'].set_data(
767 | robot.to_abs_frame(saliency_obs[img_map_id][..., 1]))
768 | self.ax_im_agent[team_key][i]['coverage']['sal']['im'].set_clim(saliency_limits[0], saliency_limits[1])
769 |
770 | if interpreter_obs is not None:
771 | self.ax_im_agent[team_key][i]['map']['int']['im'].set_data(
772 | robot.to_abs_frame(interpreter_obs[img_map_id][1]))
773 | self.ax_im_agent[team_key][i]['coverage']['int']['im'].set_data(
774 | robot.to_abs_frame(interpreter_obs[img_map_id][0]))
775 |
776 | img_map_id += 1
777 |
778 | self.ax_im_agent[team_key][i]['map']['obs']['ax'].set_title(
779 | f'{i}') #\nc: {0:.2f}\nr: {robot.reward:.2f}')
780 |
781 | # self.render_connectivity(self.ax_overview, 0, 3)
782 | self.fig.canvas.draw()
783 | return self.fig
784 |
--------------------------------------------------------------------------------
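A minimal usage sketch for the CoverageEnv defined above, driving it with random actions. The configuration values below are illustrative assumptions rather than prescribed defaults: the constructor indexes n_agents[1], so at least two teams must be configured, and step() reads 'operation_mode', which has no entry in DEFAULT_OPTIONS and therefore has to be supplied explicitly.

from adversarial_comms.environments.coverage import CoverageEnv

env = CoverageEnv({
    'n_agents': [1, 5],                      # one adversarial agent, five cooperative agents (assumption)
    'disabled_teams_step': [False, False],   # both teams act
    'disabled_teams_comms': [False, False],  # both teams communicate
    'operation_mode': 'coop_only',           # required by step(), not part of DEFAULT_OPTIONS
    'max_episode_len': 50,
})
obs = env.reset()
done = False
while not done:
    # one random discrete action per agent, in team order
    obs, reward, done, info = env.step(env.action_space.sample())
print(info['current_global_coverage'])

--------------------------------------------------------------------------------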
/adversarial_comms/environments/path_planning.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | import matplotlib.pyplot as plt
4 | import matplotlib.patches as patches
5 | import gym
6 | from gym import spaces
7 | from gym.utils import seeding, EzPickle
8 | from matplotlib import colors
9 | from functools import partial
10 | from enum import Enum
11 | import copy
12 |
13 | from ray.rllib.env.multi_agent_env import MultiAgentEnv
14 |
15 | # https://bair.berkeley.edu/blog/2018/12/12/rllib/
16 |
17 | DEFAULT_OPTIONS = {
18 | 'world_shape': [12, 12],
19 | 'state_size': 24,
20 | 'max_episode_len': 50,
21 | "n_agents": [8],
22 | "disabled_teams_step": [False],
23 | "disabled_teams_comms": [False],
24 | 'communication_range': 5.0,
25 | 'ensure_connectivity': True,
26 | 'position_mode': 'random', # random or fixed
27 | 'agents': {
28 | 'visibility_distance': 3,
29 | 'relative_coord_frame': True
30 | }
31 | }
32 |
33 | X = 1
34 | Y = 0
35 |
36 | class Dir(Enum):
37 | RIGHT = 0
38 | LEFT = 1
39 | UP = 2
40 | DOWN = 3
41 |
42 | class WorldMap():
43 | def __init__(self, shape, mode):
44 | self.shape = shape
45 | self.mode = mode
46 | self.reset()
47 |
48 | def reset(self):
49 | if self.mode == "traffic":
50 | self.map = np.array([
51 | [1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1],
52 | [1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1],
53 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
54 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
55 | [1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1],
56 | [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0],
57 | [1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1],
58 | [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
59 | [1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1],
60 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
61 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
62 | [1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1],
63 | ])
64 | elif self.mode == "warehouse":
65 | self.map = np.zeros(self.shape, dtype=np.uint8)
66 | for y in range(1, self.shape[Y]-1, 2):
67 | for x in range(1, self.shape[X]-1, 6):
68 | self.map[y:y+1,x:x+4] = True
69 | else:
70 | raise NotImplementedError
71 |
72 | class Action(Enum):
73 | NOP = 0
74 | MOVE_RIGHT = 1
75 | MOVE_LEFT = 2
76 | MOVE_UP = 3
77 | MOVE_DOWN = 4
78 |
79 | class Robot():
80 | def __init__(self,
81 | world,
82 | agent_observability_radius,
83 | state_size,
84 | coordinate_frame_is_local):
85 | self.world = world
86 | self.state_size = state_size
87 | self.coordinate_frame_is_local = coordinate_frame_is_local
88 | self.agent_observability_radius = agent_observability_radius
89 | self.reset([0, 0], [0, 0])
90 |
91 | def reset(self, pose, goal):
92 | self.pose = np.array(pose, dtype=np.int)
93 | self.prev_pose = self.pose.copy()
94 | self.goal = np.array(goal, dtype=np.int)
95 |
96 | def step(self, action):
97 | action = Action(action)
98 |
99 | delta_pose = {
100 | Action.MOVE_RIGHT: [ 0, 1],
101 | Action.MOVE_LEFT: [ 0, -1],
102 | Action.MOVE_UP: [-1, 0],
103 | Action.MOVE_DOWN: [ 1, 0],
104 | Action.NOP: [ 0, 0]
105 | }[action]
106 |
107 | def is_occupied(p):
108 | for team_key, team in self.world.teams.items():
109 | if self.world.cfg['disabled_teams_step'][team_key]:
110 | continue
111 | for o in team:
112 | if p[X] == o.pose[X] and p[Y] == o.pose[Y] and o is not self:
113 | return True
114 | return False
115 |
116 | is_valid_pose = lambda p: all([p[c] >= 0 and p[c] < self.world.map.shape[c] for c in [Y, X]])
117 | is_obstacle = lambda p: self.world.map.map[p[Y]][p[X]]
118 |
119 | self.prev_pose = self.pose.copy()
120 | desired_pos = self.pose + delta_pose
121 | if is_valid_pose(desired_pos) and not is_occupied(desired_pos) and not is_obstacle(desired_pos):
122 | self.pose = desired_pos
123 |
124 | def update_state(self):
125 | pose_map = np.zeros(self.world.map.shape, dtype=np.uint8)
126 | for team in self.world.teams.values():
127 | for r in team:
128 | if not r is self and np.sum((r.pose - self.pose)**2) <= self.agent_observability_radius**2:
129 | pose_map[r.pose[Y], r.pose[X]] = 2
130 |         pose_map[self.pose[Y], self.pose[X]] = 1
131 |
132 | goal_map = np.zeros(self.world.map.shape, dtype=np.bool)
133 | cy, cx = self.goal - self.pose
134 | if abs(cx) < self.state_size/2 and abs(cy) < self.state_size/2:
135 | goal_map[self.goal[Y], self.goal[X]] = True
136 | else:
137 | u = max(abs(cx), abs(cy))
138 | p_sq = np.round(self.pose + int(self.state_size / 2) * np.array([cy / u, cx / u])).astype(np.int)
139 | goal_map[p_sq[Y], p_sq[X]] = True
140 |
141 | self.state = np.stack([self.to_coordinate_frame(self.world.map.map, 1), self.to_coordinate_frame(goal_map, 0), self.to_coordinate_frame(pose_map, 0)], axis=-1).astype(np.uint8)
142 | done = all(self.pose == self.goal)
143 | return self.state, done
144 |
145 | def to_coordinate_frame(self, m, fill=0):
146 | if self.coordinate_frame_is_local:
147 | half_state_shape = np.array([self.state_size/2]*2, dtype=np.int)
148 | padded = np.pad(m,([half_state_shape[Y]]*2,[half_state_shape[X]]*2), mode='constant', constant_values=fill)
149 | return padded[self.pose[Y]:self.pose[Y] + self.state_size, self.pose[X]:self.pose[X] + self.state_size]
150 | else:
151 | return m
152 |
153 | class PathPlanningEnv(gym.Env, EzPickle):
154 | def __init__(self, env_config):
155 | EzPickle.__init__(self)
156 | self.seed()
157 |
158 | self.cfg = copy.deepcopy(DEFAULT_OPTIONS)
159 | self.cfg.update(env_config)
160 |
161 | self.fig = None
162 | self.map_colormap = colors.ListedColormap(['white', 'black', 'gray']) # free, obstacle, unknown
163 |
164 | hsv = np.ones((sum(self.cfg['n_agents']), 3))
165 | hsv[..., 0] = np.linspace(0, 1, sum(self.cfg['n_agents']) + 1)[:-1]
166 | self.teams_agents_color = {}
167 | current_index = 0
168 | for i, n_agents in enumerate(self.cfg['n_agents']):
169 | self.teams_agents_color[i] = colors.hsv_to_rgb(hsv[current_index:current_index+n_agents])
170 | current_index += n_agents
171 |
172 | hsv = np.ones((len(self.cfg['n_agents']), 3))
173 | hsv[..., 0] = np.linspace(0, 1, len(self.cfg['n_agents']) + 1)[:-1]
174 | self.teams_colors = ['r', 'b'] #colors.hsv_to_rgb(hsv)
175 |
176 | n_all_agents = sum(self.cfg['n_agents'])
177 | self.observation_space = spaces.Dict({
178 | 'agents': spaces.Tuple((
179 | spaces.Dict({
180 | 'map': spaces.Box(0, np.inf, shape=(self.cfg['state_size'], self.cfg['state_size'], 3)),
181 | 'pos': spaces.Box(low=np.array([0,0]), high=np.array([self.cfg['world_shape'][Y], self.cfg['world_shape'][X]]), dtype=np.int),
182 | }),
183 |         )*n_all_agents), # Do not add this as an additional dimension of map and pos since a flat tuple is easier to handle in the model
184 | 'gso': spaces.Box(-np.inf, np.inf, shape=(n_all_agents, n_all_agents)),
185 | 'state': spaces.Box(low=0, high=3, shape=self.cfg['world_shape']+[sum(self.cfg['n_agents'])]),
186 | })
187 | self.action_space = spaces.Tuple((spaces.Discrete(5),)*sum(self.cfg['n_agents']))
188 |
189 | self.map = WorldMap(self.cfg['world_shape'], self.cfg['world_mode'])
190 |
191 | self.teams = {
192 | i: [
193 | Robot(
194 | self,
195 | self.cfg['agents']['visibility_distance'],
196 | self.cfg['state_size'],
197 | self.cfg['agents']['relative_coord_frame']
198 | ) for _ in range(n_agents)
199 | ] for i, n_agents in enumerate(self.cfg['n_agents'])
200 | }
201 |
202 | self.reset()
203 |
204 | def seed(self, seed=None):
205 | self.random_state, seed_agents = seeding.np_random(seed)
206 | return [seed_agents]
207 |
208 | def reset(self):
209 | self.timestep = 0
210 | self.dones = {key: [False for _ in team] for key, team in self.teams.items()}
211 | self.map.reset()
212 |
213 | def sample_random_pos():
214 | x = self.random_state.randint(0, self.map.shape[X])
215 | y = self.random_state.randint(0, self.map.shape[Y])
216 | return np.array([y, x])
217 |
218 | def sample_valid_random_pos(up_to=None):
219 | def get_agents():
220 | return [o for team in self.teams.values() for o in team][:up_to]
221 | def is_occupied(p):
222 | return any([all(p == o.pose) for o in get_agents()])
223 | def is_other_goal(p):
224 | return any([all(p == o.goal) for o in get_agents()])
225 | is_obstacle = lambda p: self.map.map[p[Y]][p[X]]
226 |
227 | pose_seed = sample_random_pos()
228 | while is_obstacle(pose_seed) or is_occupied(pose_seed) or is_other_goal(pose_seed):
229 | pose_seed = sample_random_pos()
230 | return pose_seed
231 |
232 | agent_index = 0
233 | for team_key, team in self.teams.items():
234 | if self.cfg['disabled_teams_step'][team_key]:
235 | continue
236 | for agent in team:
237 | agent.reset(sample_valid_random_pos(agent_index), sample_valid_random_pos(agent_index))
238 | agent_index += 1
239 |
240 | return self.step([Action.NOP]*sum(self.cfg['n_agents']))[0]
241 |
242 | def compute_gso(self, team_id=0):
243 | own_team_agents = [(agent, self.cfg['disabled_teams_comms'][team_id]) for agent in self.teams[team_id]]
244 | other_agents = [(agent, self.cfg['disabled_teams_comms'][other_team_id]) for other_team_id, team in self.teams.items() for agent in team if not team_id == other_team_id]
245 |
246 |         all_agents = own_team_agents + other_agents # order is important since the model concatenates the data in the same order
247 | dists = np.zeros((len(all_agents), len(all_agents)))
248 | done_matrix = np.zeros((len(all_agents), len(all_agents)), dtype=np.bool)
249 | for agent_y in range(len(all_agents)):
250 | for agent_x in range(agent_y):
251 | dst = np.sum(np.array(all_agents[agent_x][0].pose - all_agents[agent_y][0].pose)**2)
252 | dists[agent_y, agent_x] = dst
253 | dists[agent_x, agent_y] = dst
254 |
255 | d = all_agents[agent_x][1] or all_agents[agent_y][1]
256 | done_matrix[agent_y, agent_x] = d
257 | done_matrix[agent_x, agent_y] = d
258 |
259 | current_dist = self.cfg['communication_range']
260 | A = dists < (current_dist**2)
261 | active_row = ~np.array([a[1] for a in all_agents])
262 | if self.cfg['ensure_connectivity']:
263 | def is_connected(m):
264 | def walk_dfs(m, index):
265 | for i in range(len(m)):
266 | if m[index][i]:
267 | m[index][i] = False
268 | walk_dfs(m, i)
269 |
270 | m_c = m.copy()
271 | walk_dfs(m_c, 0)
272 | return not np.any(m_c.flatten())
273 |
274 |             # treat disabled (done) teams as already connected so that they are not pulled in when increasing the communication range
275 | while not is_connected(A[active_row][:, active_row]):
276 | current_dist *= 1.1
277 | A = (dists < current_dist**2)
278 |
279 | # Mask out done agents
280 | A = (A & ~done_matrix).astype(np.int)
281 |
282 | # normalization: refer https://github.com/QingbiaoLi/GraphNets/blob/master/Flocking/Utils/dataTools.py#L601
283 | np.fill_diagonal(A, 0)
284 | deg = np.sum(A, axis = 1) # nNodes (degree vector)
285 | D = np.diag(deg)
286 | Dp = np.diag(np.nan_to_num(np.power(deg, -1/2)))
287 | L = A # D-A
288 | gso = Dp @ L @ Dp
289 | return gso
290 |
291 | def step(self, actions):
292 | self.timestep += 1
293 | action_index = 0
294 | for i, team in enumerate(self.teams.values()):
295 | for j, agent in enumerate(team):
296 | if not self.cfg['disabled_teams_step'][i]: # and not self.dones[i][j]:
297 | agent.step(actions[action_index])
298 | action_index += 1
299 |
300 | states, rewards = {}, {}
301 | for team_key, team in self.teams.items():
302 | states[team_key] = []
303 | rewards[team_key] = {}
304 | for i, agent in enumerate(team):
305 | state, done = agent.update_state()
306 | states[team_key].append(state)
307 |                 rewards[team_key][i] = 1 if done else 0 # reward while at goal incentivizes moving to it as quickly as possible
308 | if done:
309 | self.dones[team_key][i] = True
310 |
311 | if self.cfg['reward_type'] == 'local':
312 | pass
313 | elif self.cfg['reward_type'] == 'greedy_only':
314 | rewards[1] = {agent_key: sum(rewards[0].values()) for agent_key in rewards[1].keys()}
315 | elif self.cfg['reward_type'] == 'coop_only':
316 | rewards[0] = {agent_key: sum(rewards[1].values()) for agent_key in rewards[0].keys()}
317 | else:
318 | raise NotImplementedError("Unknown reward type", self.cfg['reward_type'])
319 |
320 | done = self.timestep == self.cfg['max_episode_len'] # or all(self.dones[1])
321 |
322 | global_state = np.stack([self.map.map.copy() for _ in range(sum(self.cfg['n_agents']))], axis=-1).astype(np.uint8)
323 | global_state_layer = 0
324 | for team in self.teams.values():
325 | for r in team:
326 | global_state[r.pose[Y], r.pose[X], global_state_layer] = 2
327 | global_state[r.goal[Y], r.goal[X], global_state_layer] = 3
328 | global_state_layer += 1
329 |
330 | state = {
331 | 'agents': tuple([{
332 | 'map': states[key][agent_i],
333 | 'pos': self.teams[key][agent_i].pose
334 | } for key in self.teams.keys() for agent_i in range(self.cfg['n_agents'][key])]),
335 | 'gso': self.compute_gso(0),
336 | 'state': global_state
337 | }
338 |
339 | flattened_rewards = {}
340 | agent_index = 0
341 | for key in self.teams.keys():
342 | for r in rewards[key].values():
343 | flattened_rewards[agent_index] = r
344 | agent_index += 1
345 | info = {
346 | 'rewards_teams': rewards,
347 | 'rewards': flattened_rewards
348 | }
349 | return state, sum([sum(t.values()) for i, t in enumerate(rewards.values()) if not self.cfg['disabled_teams_step'][i]]), done, info
350 |
351 | def clear_patches(self, ax):
352 | [p.remove() for p in reversed(ax.patches)]
353 | [t.remove() for t in reversed(ax.texts)]
354 |
355 | def render_adjacency(self, A, team_id, ax, color='b', stepsize=1.0):
356 | A = A.copy()
357 | own_team_agents = [agent for agent in self.teams[team_id]]
358 | other_agents = [agent for other_team_id, team in self.teams.items() for agent in team if not team_id == other_team_id]
359 | all_agents = own_team_agents + other_agents
360 | for agent_id, agent in enumerate(all_agents):
361 | for connected_agent_id in np.arange(len(A)):
362 | if A[agent_id][connected_agent_id] > 0:
363 | current_agent_pose = agent.prev_pose + (agent.pose - agent.prev_pose) * stepsize
364 | other_agent = all_agents[connected_agent_id]
365 | other_agent_pose = other_agent.prev_pose + (other_agent.pose - other_agent.prev_pose) * stepsize
366 | ax.add_patch(patches.ConnectionPatch(
367 | [current_agent_pose[X], current_agent_pose[Y]],
368 | [other_agent_pose[X], other_agent_pose[Y]],
369 | "data", edgecolor='g', facecolor='none', lw=1, ls=":"
370 | ))
371 |
372 | A[connected_agent_id][agent_id] = 0 # don't draw same connection again
373 |
374 | def render_overview(self, ax, stepsize=1.0):
375 | if not hasattr(self, 'im_map'):
376 | ax.set_xticks([])
377 | ax.set_yticks([])
378 | self.im_map = ax.imshow(np.zeros(self.map.shape), vmin=0, vmax=1)
379 |
380 | self.im_map.set_data(self.map_colormap(self.map.map))
381 | agent_i = 0
382 | for (team_key, team) in self.teams.items():
383 | if self.cfg['disabled_teams_step'][team_key]:
384 | continue
385 | for agent in team:
386 | rect_size = 1
387 | pose_microstep = agent.prev_pose + (agent.pose - agent.prev_pose)*stepsize
388 | rect = patches.Rectangle((pose_microstep[1] - rect_size / 2, pose_microstep[0] - rect_size / 2), rect_size, rect_size,
389 | linewidth=1, edgecolor=self.teams_colors[team_key], facecolor='none')
390 | ax.add_patch(rect)
391 | ax.text(pose_microstep[1]-0.45, pose_microstep[0], f"{agent_i}", color=self.teams_colors[team_key])
392 | agent_i += 1
393 |
394 | #ax.set_title(
395 | # f'Global coverage: {int(self.map.get_coverage_fraction()*100)}%\n'
396 | #)
397 |
398 | def render_goals(self, ax):
399 | agent_i = 0
400 | for team_key, team in self.teams.items():
401 | if self.cfg['disabled_teams_step'][team_key]:
402 | continue
403 | for agent in team:
404 | rect = patches.Circle((agent.goal[1], agent.goal[0]), 0.1,
405 | linewidth=1, facecolor=self.teams_colors[team_key])
406 | ax.add_patch(rect)
407 | ax.text(agent.goal[1] - 0.45, agent.goal[0] + 0.5, f"{agent_i}", color=self.teams_colors[team_key])
408 | agent_i += 1
409 |
410 | def render_connectivity(self, ax, agent_id, K):
411 | if K <= 1:
412 | return
413 |
414 | for connected_agent_id in np.arange(self.cfg['n_agents'])[self.A[agent_id] == 1]:
415 | current_agent_pose = self.robots[agent_id].pose
416 | connected_agent_d_pose = self.robots[connected_agent_id].pose - current_agent_pose
417 | ax.add_patch(patches.Arrow(
418 | current_agent_pose[X],
419 | current_agent_pose[Y],
420 | connected_agent_d_pose[X],
421 | connected_agent_d_pose[Y],
422 | edgecolor='b',
423 | facecolor='none'
424 | ))
425 | self.render_connectivity(ax, connected_agent_id, K-1)
426 |
427 | def render_future_steps(self, future_steps, ax, stepsize=1.0):
428 | if future_steps is None:
429 | return
430 |
431 | current_agent_i = 0
432 | for team_key, team in self.teams.items():
433 | for agent in team:
434 | if current_agent_i not in future_steps:
435 | current_agent_i += 1
436 | continue
437 |
438 | previous_agent_pose = future_steps[current_agent_i][0].copy()
439 | for i, current_pos in enumerate(future_steps[current_agent_i][1:]):
440 | if i == len(future_steps[current_agent_i][1:])-1:
441 | current_pos = previous_agent_pose + (
442 | future_steps[current_agent_i][-1] - previous_agent_pose) * stepsize
443 | ax.add_patch(
444 | patches.Rectangle((current_pos[1] - 1/2, current_pos[0] - 1/2), 1, 1,
445 | linewidth=1, edgecolor=self.teams_colors[team_key],
446 | facecolor='none', ls=":")
447 | )
448 | ax.add_patch(patches.ConnectionPatch(
449 | [previous_agent_pose[X], previous_agent_pose[Y]],
450 | [current_pos[X], current_pos[Y]],
451 | "data", edgecolor=self.teams_colors[team_key], facecolor='none', lw=2
452 | ))
453 |
454 | previous_agent_pose = current_pos.copy()
455 |
456 | current_agent_i += 1
457 |
458 | def render(self, mode='human', future_steps=None, stepsize=1.0):
459 | if self.fig is None:
460 | self.fig = plt.figure(figsize=(3, 3))
461 | self.ax_overview = self.fig.add_subplot(1, 1, 1, aspect='equal')
462 |
463 | self.clear_patches(self.ax_overview)
464 | self.render_future_steps(future_steps, self.ax_overview, stepsize)
465 | self.render_overview(self.ax_overview, stepsize)
466 | self.render_goals(self.ax_overview)
467 | A = self.compute_gso(0)
468 | self.render_adjacency(A, 0, self.ax_overview, stepsize=stepsize)
469 |
470 | self.fig.canvas.draw()
471 | self.fig.subplots_adjust(bottom=0, top=1, left=0, right=1)
472 | return self.fig
473 |
474 | class PathPlanningEnvSaliency(PathPlanningEnv):
475 | def __init__(self, cfg):
476 | super().__init__(cfg)
477 |
478 | def render(self, mode='human', saliency_obs=None, saliency_pos=None):
479 | if self.fig is None:
480 | plt.ion()
481 | self.fig = plt.figure(constrained_layout=True, figsize=(16, 10))
482 | grid_spec = self.fig.add_gridspec(ncols=max(self.cfg['n_agents']) * 2,
483 | nrows=1 + 3 * len(self.cfg['n_agents']),
484 | height_ratios=[1] + [1, 1, 1] * len(self.cfg['n_agents']))
485 |
486 | self.ax_overview = self.fig.add_subplot(grid_spec[0, :])
487 |
488 | self.ax_im_agent = {}
489 | for team_key, team in self.teams.items():
490 | self.ax_im_agent[team_key] = []
491 | for i in range(self.cfg['n_agents'][team_key]):
492 | self.ax_im_agent[team_key].append({})
493 | for j, col_id in enumerate(['map', 'goal', 'pos']):
494 | self.ax_im_agent[team_key][i][col_id] = {}
495 | for k, row_id in enumerate(['obs', 'sal']):
496 |                         ax = self.fig.add_subplot(grid_spec[j + 1 + team_key * 3, i * 2 + k])
497 | ax.set_xticks([])
498 | ax.set_yticks([])
499 | self.ax_im_agent[team_key][i][col_id][row_id] = {'ax': ax, 'im': None}
500 | self.ax_im_agent[team_key][i][col_id]['sal']['im'] = \
501 | self.ax_im_agent[team_key][i][col_id]['sal']['ax'].imshow(
502 | np.zeros((1, 1)), vmin=-5, vmax=5)
503 | self.ax_im_agent[team_key][i]['map']['obs']['im'] = self.ax_im_agent[team_key][i]['map']['obs']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=3)
504 | self.ax_im_agent[team_key][i]['goal']['obs']['im'] = self.ax_im_agent[team_key][i]['goal']['obs']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=1)
505 | self.ax_im_agent[team_key][i]['pos']['obs']['im'] = self.ax_im_agent[team_key][i]['pos']['obs']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=1)
506 |
507 |
508 | self.clear_patches(self.ax_overview)
509 | self.render_overview(self.ax_overview)
510 | A = self.compute_gso(0)
511 | self.render_adjacency(A, 0, self.ax_overview)
512 |
513 | if saliency_obs is not None:
514 | saliency_limits = (np.min(saliency_obs), np.max(saliency_obs))
515 | saliency_map_id = 0
516 | for team_key, team in self.teams.items():
517 | for i, robot in enumerate(team):
518 | self.ax_im_agent[team_key][i]['map']['obs']['im'].set_data(self.map_colormap(robot.state[..., 0]))
519 | self.ax_im_agent[team_key][i]['goal']['obs']['im'].set_data(self.map_colormap(robot.state[..., 1]))
520 | self.ax_im_agent[team_key][i]['pos']['obs']['im'].set_data(self.map_colormap(robot.state[..., 2]))
521 |
522 | if saliency_obs is not None:
523 | self.ax_im_agent[team_key][i]['map']['sal']['im'].set_data(saliency_obs[saliency_map_id][..., 0])
524 | self.ax_im_agent[team_key][i]['map']['sal']['im'].set_clim(saliency_limits[0], saliency_limits[1])
525 |                         self.ax_im_agent[team_key][i]['goal']['sal']['im'].set_data(saliency_obs[saliency_map_id][..., 1])
526 |                         self.ax_im_agent[team_key][i]['goal']['sal']['im'].set_clim(saliency_limits[0], saliency_limits[1])
527 |
528 | saliency_map_id += 1
529 |
530 | if False: # saliency_pos is not None:
531 | print("T", saliency_pos[i][2:].numpy())
532 | self.ax_im_agent[i]['map']['sal']['ax'].set_title(
533 | f'{saliency_pos[i][0]:.2f}\n{saliency_pos[i][1]:.2f}\n{np.mean(saliency_pos[i][2:].numpy()):.2f}')
534 |
535 | #self.ax_im_agent[team_key][i]['map']['obs']['ax'].set_title(
536 | # f'{i}\nc: {0:.2f}\nr: {robot.reward:.2f}')
537 |
538 | # self.render_connectivity(self.ax_overview, 0, 3)
539 | self.fig.canvas.draw()
540 | return self.fig
541 |
542 | class PathPlanningEnvOverview(PathPlanningEnv):
543 | def __init__(self, cfg):
544 | super().__init__(cfg)
545 |         self.map_colormap = colors.ListedColormap(['white', 'black', 'blue', 'red', 'green']) # free, obstacle, goal, own position, other agents
546 |
547 | def render(self, mode='human'):
548 | if self.fig is None:
549 | plt.ion()
550 | self.fig = plt.figure(constrained_layout=True, figsize=(16, 10))
551 | grid_spec = self.fig.add_gridspec(ncols=max(self.cfg['n_agents']),
552 | nrows=1 + len(self.cfg['n_agents']),
553 | height_ratios=[1] + [1] * len(self.cfg['n_agents']))
554 |
555 | self.ax_overview = self.fig.add_subplot(grid_spec[0, :])
556 |
557 | self.ax_im_agent = {}
558 | for team_key, team in self.teams.items():
559 | self.ax_im_agent[team_key] = []
560 | for i in range(self.cfg['n_agents'][team_key]):
561 | self.ax_im_agent[team_key].append({})
562 | for j, col_id in enumerate(['overview']):
563 | self.ax_im_agent[team_key][i][col_id] = {}
564 | for k, row_id in enumerate(['obs']):
565 |                         ax = self.fig.add_subplot(grid_spec[j + 1 + team_key, i + k])
566 | ax.set_xticks([])
567 | ax.set_yticks([])
568 | self.ax_im_agent[team_key][i][col_id][row_id] = {'ax': ax, 'im': None}
569 | self.ax_im_agent[team_key][i]['overview']['obs']['im'] = self.ax_im_agent[team_key][i]['overview']['obs']['ax'].imshow(np.zeros((1, 1)), vmin=0, vmax=4)
570 |
571 |
572 | self.clear_patches(self.ax_overview)
573 | self.render_overview(self.ax_overview)
574 | self.render_goals(self.ax_overview)
575 | A = self.compute_gso(0)
576 | self.render_adjacency(A, 0, self.ax_overview)
577 |
578 | saliency_map_id = 0
579 | for team_key, team in self.teams.items():
580 | for i, robot in enumerate(team):
581 | state = robot.state[..., 0].copy().astype(np.uint8) # map
582 | state[robot.state[..., 1]==1] = 2
583 | state[robot.state[..., 2]==1] = 3
584 | state[robot.state[..., 2]==2] = 4
585 | #state[robot.state[..., 2]] = 3
586 | self.ax_im_agent[team_key][i]['overview']['obs']['im'].set_data(self.map_colormap(state))
587 |
588 | self.ax_im_agent[team_key][i]['overview']['obs']['ax'].set_title(
589 | f'{i}') #\nc: {0:.2f}\nr: {robot.reward:.2f}')
590 |
591 | # self.render_connectivity(self.ax_overview, 0, 3)
592 | self.fig.canvas.draw()
593 | return self.fig
594 |
--------------------------------------------------------------------------------
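A minimal usage sketch for the PathPlanningEnv defined above. 'world_mode' and 'reward_type' are read by the environment but have no entry in DEFAULT_OPTIONS, so they must be supplied; the values below are illustrative assumptions.

from adversarial_comms.environments.path_planning import PathPlanningEnv

env = PathPlanningEnv({
    'world_mode': 'traffic',   # built-in 12x12 map; 'warehouse' generates shelf-like rows instead
    'reward_type': 'local',    # each agent is rewarded only while it sits on its own goal
})
obs = env.reset()
done = False
while not done:
    # one random discrete action per agent; the episode ends after max_episode_len steps
    obs, reward, done, info = env.step(env.action_space.sample())
print(info['rewards'])

--------------------------------------------------------------------------------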
/adversarial_comms/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections.abc
3 | import json
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import os
7 | import pandas as pd
8 | import ray
9 | import time
10 | import traceback
11 |
12 | from pathlib import Path
13 | from ray.rllib.models import ModelCatalog
14 | from ray.tune.logger import NoopLogger
15 | from ray.tune.registry import register_env
16 | from ray.util.multiprocessing import Pool
17 |
18 | from .environments.coverage import CoverageEnv
19 | from .environments.path_planning import PathPlanningEnv
20 | from .models.adversarial import AdversarialModel
21 | from .trainers.multiagent_ppo import MultiPPOTrainer
22 | from .trainers.random_heuristic import RandomHeuristicTrainer
23 |
24 | def update_dict(d, u):
25 | for k, v in u.items():
26 | if isinstance(v, collections.abc.Mapping):
27 | d[k] = update_dict(d.get(k, {}), v)
28 | else:
29 | d[k] = v
30 | return d
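# Example of the recursive merge performed by update_dict (values in u overwrite
# values in d, nested mappings are merged instead of replaced):
#   update_dict({'env_config': {'reward_type': 'local', 'n_agents': [1, 5]}},
#               {'env_config': {'reward_type': 'coop_only'}})
#   -> {'env_config': {'reward_type': 'coop_only', 'n_agents': [1, 5]}}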
31 |
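# run_trial restores a trainer from the given checkpoint (if any), merges the stored
# params.json with its evaluation_config and the cfg_update override, rolls out a
# single episode of max_episode_len steps and returns a DataFrame with one row per
# step and agent containing the obtained reward.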
32 | def run_trial(trainer_class=MultiPPOTrainer, checkpoint_path=None, trial=0, cfg_update={}, render=False):
33 | try:
34 | t0 = time.time()
35 | cfg = {'env_config': {}, 'model': {}}
36 | if checkpoint_path is not None:
37 | # We might want to run policies that are not loaded from a checkpoint
38 | # (e.g. the random policy) and therefore need this to be optional
39 | with open(Path(checkpoint_path).parent/"params.json") as json_file:
40 | cfg = json.load(json_file)
41 |
42 | if 'evaluation_config' in cfg:
43 |                 # overwrite the environment config with the evaluation one if it exists
44 | cfg = update_dict(cfg, cfg['evaluation_config'])
45 |
46 | cfg = update_dict(cfg, cfg_update)
47 |
48 | trainer = trainer_class(
49 | env=cfg['env'],
50 | logger_creator=lambda config: NoopLogger(config, ""),
51 | config={
52 | "framework": "torch",
53 | "seed": trial,
54 | "num_workers": 0,
55 | "env_config": cfg['env_config'],
56 | "model": cfg['model']
57 | }
58 | )
59 | if checkpoint_path is not None:
60 | checkpoint_file = Path(checkpoint_path)/('checkpoint-'+os.path.basename(checkpoint_path).split('_')[-1])
61 | trainer.restore(str(checkpoint_file))
62 |
63 | envs = {'coverage': CoverageEnv, 'path_planning': PathPlanningEnv}
64 | env = envs[cfg['env']](cfg['env_config'])
65 | env.seed(trial)
66 | obs = env.reset()
67 |
68 | results = []
69 | for i in range(cfg['env_config']['max_episode_len']):
70 | actions = trainer.compute_action(obs)
71 | obs, reward, done, info = env.step(actions)
72 | if render:
73 | env.render()
74 |             for j, agent_reward in enumerate(list(info['rewards'].values())):
75 |                 results.append({
76 |                     'step': i,
77 |                     'agent': j,
78 |                     'trial': trial,
79 |                     'reward': agent_reward
80 |                 })
81 |
82 | print("Done", time.time() - t0)
83 | except Exception as e:
84 | print(e, traceback.format_exc())
85 | raise
86 | df = pd.DataFrame(results)
87 | return df
88 |
89 | def path_to_hash(path):
90 | path_split = path.split('/')
91 | checkpoint_number_string = path_split[-1].split('_')[-1]
92 | path_hash = path_split[-2].split('_')[-2]
93 | return path_hash + '-' + checkpoint_number_string
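# Illustration (using the checkpoint naming seen in generate_dataset.py): for a
# path like ".../MultiPPO_teamworld2_c8d29_00000/checkpoint_410" this returns
# "c8d29-410", i.e. the experiment hash joined with the checkpoint number, which
# is used below to build unique evaluation result filenames.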
94 |
95 | def serve_config(checkpoint_path, trials, cfg_change={}, trainer=MultiPPOTrainer):
96 | with Pool() as p:
97 | results = pd.concat(p.starmap(run_trial, [(trainer, checkpoint_path, t, cfg_change) for t in range(trials)]))
98 | return results
99 |
100 | def initialize():
101 | ray.init()
102 | register_env("coverage", lambda config: CoverageEnv(config))
103 | register_env("path_planning", lambda config: PathPlanningEnv(config))
104 | ModelCatalog.register_custom_model("adversarial", AdversarialModel)
105 |
106 | def eval_nocomm(env_config_func, prefix):
107 | parser = argparse.ArgumentParser()
108 | parser.add_argument("checkpoint")
109 | parser.add_argument("out_path")
110 | parser.add_argument("-t", "--trials", type=int, default=100)
111 | args = parser.parse_args()
112 |
113 | initialize()
114 | results = []
115 | for comm in [False, True]:
116 | cfg_change={'env_config': env_config_func(comm)}
117 | df = serve_config(args.checkpoint, args.trials, cfg_change=cfg_change, trainer=MultiPPOTrainer)
118 | df['comm'] = comm
119 | results.append(df)
120 |
121 | with open(Path(args.checkpoint).parent/"params.json") as json_file:
122 | cfg = json.load(json_file)
123 | if 'evaluation_config' in cfg:
124 | update_dict(cfg, cfg['evaluation_config'])
125 |
126 | df = pd.concat(results)
127 | df.attrs = cfg
128 | filename = prefix + "-" + path_to_hash(args.checkpoint) + ".pkl"
129 | df.to_pickle(Path(args.out_path)/filename)
130 |
131 | def eval_nocomm_coop():
132 | # Cooperative agents can communicate or not (without comm interference from adversarial agent)
133 | eval_nocomm(lambda comm: {
134 | 'disabled_teams_comms': [True, not comm],
135 | 'disabled_teams_step': [True, False]
136 | }, "eval_coop")
137 |
138 | def eval_nocomm_adv():
139 |     # All cooperative agents can still communicate; only the adversarial agent's communication is toggled on/off
140 | eval_nocomm(lambda comm: {
141 | 'disabled_teams_comms': [not comm, False], # en/disable comms for adv and always enabled for coop
142 | 'disabled_teams_step': [False, False] # both teams operating
143 | }, "eval_adv")
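# In both helpers above, the list indices follow the convention that index 0 is
# the adversarial team and index 1 the cooperative team (a reading of the flags:
# disabled_teams_step=[True, False] keeps only the cooperative team stepping).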
144 |
145 | def plot_agent(ax, df, color, step_aggregation='sum', linestyle='-'):
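    # Editor's note: rewards are aggregated per time step over the selected agents
    # (sum or mean), accumulated over the episode per trial and normalised by the
    # maximum coverable area, so the curve reads as a coverage percentage; the
    # shaded band is +/- one standard deviation across trials.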
146 | world_shape = df.attrs['env_config']['world_shape']
147 | max_cov = world_shape[0]*world_shape[1]*df.attrs['env_config']['min_coverable_area_fraction']
148 | d = (df.sort_values(['trial', 'step']).groupby(['trial', 'step'])['reward'].apply(step_aggregation, 'step').groupby('trial').cumsum()/max_cov*100).groupby('step')
149 | ax.plot(d.mean(), color=color, ls=linestyle)
150 | ax.fill_between(np.arange(len(d.mean())), np.clip(d.mean()-d.std(), 0, None), d.mean()+d.std(), alpha=0.1, color=color)
151 |
152 | def plot():
153 | parser = argparse.ArgumentParser()
154 | parser.add_argument("data")
155 | parser.add_argument("-o", "--out_file", default=None)
156 | args = parser.parse_args()
157 |
158 | fig_overview = plt.figure(figsize=[4, 4])
159 | ax = fig_overview.subplots(1, 1)
160 |
161 | df = pd.read_pickle(args.data)
162 | if Path(args.data).name.startswith('eval_adv'):
163 | plot_agent(ax, df[(df['comm'] == False) & (df['agent'] == 0)], 'r', step_aggregation='mean', linestyle=':')
164 | plot_agent(ax, df[(df['comm'] == False) & (df['agent'] > 0)], 'b', step_aggregation='mean', linestyle=':')
165 | plot_agent(ax, df[(df['comm'] == True) & (df['agent'] == 0)], 'r', step_aggregation='mean', linestyle='-')
166 | plot_agent(ax, df[(df['comm'] == True) & (df['agent'] > 0)], 'b', step_aggregation='mean', linestyle='-')
167 | elif Path(args.data).name.startswith('eval_coop'):
168 | plot_agent(ax, df[(df['comm'] == False) & (df['agent'] > 0)], 'b', step_aggregation='sum', linestyle=':')
169 | plot_agent(ax, df[(df['comm'] == True) & (df['agent'] > 0)], 'b', step_aggregation='sum', linestyle='-')
170 | elif Path(args.data).name.startswith('eval_rand'):
171 | plot_agent(ax, df[df['agent'] > 0], 'b', step_aggregation='sum', linestyle='-')
172 |
173 | ax.set_ylabel("Coverage %")
174 | ax.set_ylim(0, 100)
175 | ax.set_xlabel("Episode time steps")
176 | ax.margins(x=0, y=0)
177 | ax.grid()
178 |
179 | fig_overview.tight_layout()
180 | if args.out_file is not None:
181 | fig_overview.savefig(args.out_file, dpi=300)
182 |
183 | plt.show()
184 |
185 | def serve():
186 | parser = argparse.ArgumentParser()
187 | parser.add_argument("checkpoint")
188 | parser.add_argument("-s", "--seed", type=int, default=0)
189 | args = parser.parse_args()
190 |
191 | initialize()
192 | run_trial(checkpoint_path=args.checkpoint, trial=args.seed, render=True)
193 |
194 |
--------------------------------------------------------------------------------
/adversarial_comms/generate_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import ray
3 | from ray.util.multiprocessing import Pool
4 | import json
5 | import os
6 | from ray.tune.registry import register_env
7 | from ray.rllib.models import ModelCatalog
8 | from ray.tune.logger import NoopLogger
9 | #from model_team_adversarial import AdaptedVisionNetwork as AdversarialTeamModel
10 | #from model_team_adversarial_2 import AdaptedVisionNetwork as AdversarialTeamModel2
11 | from model_team_adversarial_2_vaegp import AdaptedVisionNetwork as AdversarialTeamModel2VAEGP
12 | from multiagent_ppo_trainer_2 import MultiPPOTrainer as MultiPPOTrainer2
13 | import matplotlib.style as mplstyle
14 | mplstyle.use('fast')
15 |
16 | from world_teams_2 import World as TeamWorld2
17 | from world_flow import WorldOverview as FlowWorld
18 | import pickle
19 | import torch
20 |
21 | import copy
22 |
23 | def generate(seed, checkpoint_path, sample_iterations, termination_mode, frame_take_prob=0.1, disable_adv_comm=False, ensure_conn=False, t_fac=1.5):
24 | with open(checkpoint_path + '/../params.json') as json_file:
25 | checkpoint_config = json.load(json_file)
26 |
27 | checkpoint_config['env_config']['ensure_connectivity'] = ensure_conn
28 |
29 | checkpoint_config['env_config']['disabled_teams_comms'] = [disable_adv_comm, False]
30 | checkpoint_config['env_config']['disabled_teams_step'] = [False, False]
31 |
32 | trainer_cfg = {
33 | "framework": "torch",
34 | "num_workers": 1,
35 | "num_gpus": 1,
36 | "env_config": checkpoint_config['env_config'],
37 | "model": checkpoint_config['model'],
38 | "seed": seed
39 | }
40 |
41 | trainer = MultiPPOTrainer2(
42 | logger_creator=lambda config: NoopLogger(config, ""),
43 | env=checkpoint_config['env'],
44 | config=trainer_cfg
45 | )
46 | checkpoint_file = checkpoint_path + '/checkpoint-' + os.path.basename(checkpoint_path).split('_')[-1]
47 | trainer.restore(checkpoint_file)
48 |
49 | envs = {
50 | 'flowworld': FlowWorld,
51 | 'teamworld2': TeamWorld2
52 | }
53 | env = envs[checkpoint_config['env']](checkpoint_config['env_config'])
54 | env.seed(seed)
55 | obs = env.reset()
56 |
57 | samples = []
58 | model = trainer.get_policy().model
59 |
60 | cnn_outputs = []
61 | def record_cnn_output(module, input_, output):
62 | cnn_outputs.append(output[0].detach().cpu().numpy())
63 | gnn_outputs = []
64 | def record_gnn_output(module, input_, output):
65 | gnn_outputs.append(output[0].detach().cpu().numpy())
66 | #model.coop_convs[-1].register_forward_hook(record_cnn_output)
67 | #model.greedy_convs[-1].register_forward_hook(record_cnn_output)
68 | model.GFL.register_forward_hook(record_gnn_output)
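    # Editor's note: this hook fires on every forward pass triggered by
    # trainer.compute_action below and appends the post-GNN feature tensor for the
    # whole team; the list is read via gnn_outputs[0][..., j] and cleared once per
    # step, so index 0 should always refer to the current frame.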
69 |
70 | while len(samples) < sample_iterations:
71 | actions = trainer.compute_action(obs)
72 | for j in range(1, sum(checkpoint_config['env_config']['n_agents'])):
73 | #obs['agents'][j]['cnn_out'] = cnn_outputs[j]
74 | z, mu, log = model.coop_vaegp.vae.encode(torch.from_numpy(np.array([obs['agents'][j]['map']])).float().permute(0,3,1,2))
75 | obs['agents'][j]['cnn_out'] = z[0].detach()
76 | obs['agents'][j]['gnn_out'] = gnn_outputs[0][..., j]
77 | cnn_outputs = []
78 | gnn_outputs = []
79 |
80 | if np.random.rand() <= frame_take_prob:
81 | samples.append(copy.deepcopy({'obs': obs, 'actions': actions}))
82 | print(len(samples))
83 |
84 | obs, reward, done, info = env.step(actions)
85 | if (termination_mode == 'path' and done) or (termination_mode == 'cov' and env.timestep == int(t_fac*(info['coverable_area']/checkpoint_config['env_config']['n_agents'][1]))):
86 | obs = env.reset()
87 | return samples
88 |
89 | def run(seed, checkpoint_path, samples, workers, generated_path, termination_mode, frame_take_prob=0.2, disable_adv_comm=False, t_fac=1.5):
90 | results = []
91 | with Pool(workers) as p:
92 |         for res in p.starmap(generate, [(seed+i, checkpoint_path, int(samples/workers), termination_mode, frame_take_prob, disable_adv_comm, False, t_fac) for i in range(workers)]):  # False fills ensure_conn so t_fac lands in its own slot
93 | results += res
94 | print("DONE", len(results))
95 | pickle.dump(results, open(generated_path, "wb"))
96 |
97 | if __name__ == "__main__":
98 | ray.init()
99 | #ModelCatalog.register_custom_model("vis_torch_adv_team", AdversarialTeamModel)
100 | #ModelCatalog.register_custom_model("vis_torch_adv_team_2", AdversarialTeamModel2)
101 | ModelCatalog.register_custom_model("vis_torch_adv_team_2_vaegp", AdversarialTeamModel2VAEGP)
102 |
103 | register_env("teamworld2", lambda config: TeamWorld2(config))
104 | register_env("flowworld", lambda config: FlowWorld(config))
105 |
106 | # cooperative trainings
107 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_teamworld2_0_2020-07-19_01-15-57_hu8xcpq/checkpoint_1560" # coverage
108 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_teamworld2_0_2020-07-18_23-44-12k2_enqa8/checkpoint_150" # split
109 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_flowworld_0_2020-07-16_00-47-53k6vmhzpl/checkpoint_1300" # flow 7x7
110 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_flowworld_0_2020-07-16_10-53-29pe06c7bw/checkpoint_3100" # flow 24x24
111 |
112 | # adversarial
113 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_teamworld2_0_2020-07-20_23-31-52zj__fmp3/checkpoint_4600" # cov
114 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_teamworld2_0_2020-07-19_09-34-02u_h77o5y/checkpoint_1400" # split
115 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_flowworld_0_2020-07-16_10-59-27c8iboc7_/checkpoint_3800" # flow 7x7
116 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_flowworld_0_2020-07-24_00-56-5896e2idut/checkpoint_8100" # flow 24x24
117 |
118 | #re-adapt
119 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_teamworld2_0_2020-07-27_00-38-42vpz3xf0k/checkpoint_5690" # cov
120 | #checkpoint_path = "/local/scratch/jb2270/corl_evaluation/MultiPPO/MultiPPO_teamworld2_0_2020-07-27_00-48-12zecm5uk7/checkpoint_2190" # split
121 |
122 | #checkpoint_path = "/local/scratch/jb2270/vaegp_eval/MultiPPO/MultiPPO_teamworld2_0_2020-08-23_11-27-05u14jlcjb/checkpoint_1560" # simple
123 |
124 | #checkpoint_path = "/local/scratch/jb2270/vaegp_eval/MultiPPO/MultiPPO_teamworld2_0_2020-08-24_20-43-40mnea2uga/checkpoint_1560" # train with frozen VAE
125 | checkpoint_path = "/local/scratch/jb2270/vaegp_eval/MultiPPO/MultiPPO_teamworld2_c8d29_00000/checkpoint_410"
126 |
127 | termination_mode = "cov" # cov/path
128 |
129 | checkpoint_num = checkpoint_path.split("_")[-1]
130 | checkpoint_id = checkpoint_path.split("/")[-2].split("-")[-1]
131 | #generate(0, checkpoint_path, 1000, 0.1, 1.5)
132 | #exit()
133 |     run(0, checkpoint_path, 50000, 32, f"/local/scratch/jb2270/datasets_corl/explainability_data_{checkpoint_id}_{checkpoint_num}_train.pkl", termination_mode, disable_adv_comm=True)
134 | run(1, checkpoint_path, 10000, 32, f"/local/scratch/jb2270/datasets_corl/explainability_data_{checkpoint_id}_{checkpoint_num}_valid.pkl", termination_mode, disable_adv_comm=True)
135 | run(2, checkpoint_path, 1000, 1, f"/local/scratch/jb2270/datasets_corl/explainability_data_{checkpoint_id}_{checkpoint_num}_test.pkl", termination_mode, disable_adv_comm=True, frame_take_prob=1.0, t_fac=4)
136 | #run(2, checkpoint_path, 1000, 1, f"/local/scratch/jb2270/datasets_corl/explainability_data_{checkpoint_id}_{checkpoint_num}_nocomm_test.pkl", termination_mode, disable_adv_comm=True, frame_take_prob=1.0)
137 |
138 |
--------------------------------------------------------------------------------
/adversarial_comms/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proroklab/adversarial_comms/889a41e239958bd519365c6b0c143469cd27e49f/adversarial_comms/models/__init__.py
--------------------------------------------------------------------------------
/adversarial_comms/models/adversarial.py:
--------------------------------------------------------------------------------
1 | from ray.rllib.models.modelv2 import ModelV2
2 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
3 | from ray.rllib.policy.rnn_sequencing import add_time_dimension
4 | from ray.rllib.models.torch.misc import normc_initializer, same_padding, SlimConv2d, SlimFC
5 | from ray.rllib.utils.annotations import override
6 | from ray.rllib.utils import try_import_torch
7 |
8 | from .gnn import adversarialGraphML as gml_adv
9 | from .gnn import graphML as gml
10 | from .gnn import graphTools
11 | import numpy as np
12 | import copy
13 |
14 | torch, nn = try_import_torch()
15 | from torchsummary import summary
16 |
17 | # https://ray.readthedocs.io/en/latest/using-ray-with-pytorch.html
18 |
19 | DEFAULT_OPTIONS = {
20 | "activation": "relu",
21 | "agent_split": 1,
22 | "cnn_compression": 512,
23 | "cnn_filters": [[32, [8, 8], 4], [64, [4, 4], 2], [128, [4, 4], 2]],
24 | "cnn_residual": False,
25 | "freeze_coop": True,
26 | "freeze_coop_value": False,
27 | "freeze_greedy": False,
28 | "freeze_greedy_value": False,
29 | "graph_edge_features": 1,
30 | "graph_features": 512,
31 | "graph_layers": 1,
32 | "graph_tabs": 3,
33 | "relative": True,
34 | "value_cnn_compression": 512,
35 | "value_cnn_filters": [[32, [8, 8], 2], [64, [4, 4], 2], [128, [4, 4], 2]],
36 | "forward_values": True
37 | }
38 |
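# Hedged sketch of how these defaults are consumed (the values below are purely
# illustrative): any subset can be overridden from the RLlib model config, e.g.
#
#   "model": {
#       "custom_model": "adversarial",
#       "custom_model_config": {"graph_tabs": 2, "freeze_coop": False},
#   }
#
# AdversarialModel.__init__ below deep-copies DEFAULT_OPTIONS and applies
# model_config['custom_model_config'] on top of it via cfg.update(...).
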
39 | class AdversarialModel(TorchModelV2, nn.Module):
40 | def __init__(self, obs_space, action_space, num_outputs, model_config, name):#,
41 | #graph_layers, graph_features, graph_tabs, graph_edge_features, cnn_filters, value_cnn_filters, value_cnn_compression, cnn_compression, relative, activation):
42 | TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
43 | nn.Module.__init__(self)
44 |
45 | self.cfg = copy.deepcopy(DEFAULT_OPTIONS)
46 | self.cfg.update(model_config['custom_model_config'])
47 |
48 | #self.cfg = model_config['custom_options']
49 | self.n_agents = len(obs_space.original_space['agents'])
50 | self.graph_features = self.cfg['graph_features']
51 | self.cnn_compression = self.cfg['cnn_compression']
52 | self.activation = {
53 | 'relu': nn.ReLU,
54 | 'leakyrelu': nn.LeakyReLU
55 | }[self.cfg['activation']]
56 |
57 | layers = []
58 | input_shape = obs_space.original_space['agents'][0]['map'].shape
59 | (w, h, in_channels) = input_shape
60 |
61 | in_size = [w, h]
62 | for out_channels, kernel, stride in self.cfg['cnn_filters'][:-1]:
63 | padding, out_size = same_padding(in_size, kernel, [stride, stride])
64 | layers.append(SlimConv2d(in_channels, out_channels, kernel, stride, padding, activation_fn=self.activation))
65 | in_channels = out_channels
66 | in_size = out_size
67 |
68 | out_channels, kernel, stride = self.cfg['cnn_filters'][-1]
69 | layers.append(
70 | SlimConv2d(in_channels, out_channels, kernel, stride, None))
71 | layers.append(nn.Flatten(1, -1))
72 | #if isinstance(cnn_compression, int):
73 | # layers.append(nn.Linear(cnn_compression, self.cfg['graph_features']-2)) # reserve 2 for pos
74 | # layers.append(self.activation{))
75 | self.coop_convs = nn.Sequential(*layers)
76 | self.greedy_convs = copy.deepcopy(self.coop_convs)
77 |
78 | self.coop_value_obs_convs = copy.deepcopy(self.coop_convs)
79 | self.greedy_value_obs_convs = copy.deepcopy(self.coop_convs)
80 |
81 | summary(self.coop_convs, device="cpu", input_size=(input_shape[2], input_shape[0], input_shape[1]))
82 |
83 | gfl = []
84 | for i in range(self.cfg['graph_layers']):
85 | gfl.append(gml_adv.GraphFilterBatchGSOA(self.graph_features, self.graph_features, self.cfg['graph_tabs'], self.cfg['agent_split'], self.cfg['graph_edge_features'], False))
86 | #gfl.append(gml.GraphFilterBatchGSO(self.graph_features, self.graph_features, self.cfg['graph_tabs'], self.cfg['graph_edge_features'], False))
87 | gfl.append(self.activation())
88 |
89 | self.GFL = nn.Sequential(*gfl)
90 |
91 | #gso_sum = torch.zeros(2, 1, 8, 8)
92 | #self.GFL[0].addGSO(gso_sum)
93 | #summary(self.GFL, device="cuda" if torch.cuda.is_available() else "cpu", input_size=(self.graph_features, 8))
94 |
95 | logits_inp_features = self.graph_features
96 | if self.cfg['cnn_residual']:
97 | logits_inp_features += self.cnn_compression
98 |
99 | post_logits = [
100 | nn.Linear(logits_inp_features, 64),
101 | self.activation(),
102 | nn.Linear(64, 32),
103 | self.activation()
104 | ]
105 | logit_linear = nn.Linear(32, 5)
106 | nn.init.xavier_uniform_(logit_linear.weight)
107 | nn.init.constant_(logit_linear.bias, 0)
108 | post_logits.append(logit_linear)
109 | self.coop_logits = nn.Sequential(*post_logits)
110 | self.greedy_logits = copy.deepcopy(self.coop_logits)
111 | summary(self.coop_logits, device="cpu", input_size=(logits_inp_features,))
112 |
113 | ##############################
114 |
115 | layers = []
116 | input_shape = np.array(obs_space.original_space['state'].shape)
117 | (w, h, in_channels) = input_shape
118 |
119 | in_size = [w, h]
120 | for out_channels, kernel, stride in self.cfg['value_cnn_filters'][:-1]:
121 | padding, out_size = same_padding(in_size, kernel, [stride, stride])
122 | layers.append(SlimConv2d(in_channels, out_channels, kernel, stride, padding, activation_fn=self.activation))
123 | in_channels = out_channels
124 | in_size = out_size
125 |
126 | out_channels, kernel, stride = self.cfg['value_cnn_filters'][-1]
127 | layers.append(
128 | SlimConv2d(in_channels, out_channels, kernel, stride, None))
129 | layers.append(nn.Flatten(1, -1))
130 |
131 | self.coop_value_cnn = nn.Sequential(*layers)
132 | self.greedy_value_cnn = copy.deepcopy(self.coop_value_cnn)
133 | summary(self.greedy_value_cnn, device="cpu", input_size=(input_shape[2], input_shape[0], input_shape[1]))
134 |
135 | layers = [
136 | nn.Linear(self.cnn_compression + self.cfg['value_cnn_compression'], 64),
137 | self.activation(),
138 | nn.Linear(64, 32),
139 | self.activation()
140 | ]
141 | values_linear = nn.Linear(32, 1)
142 | normc_initializer()(values_linear.weight)
143 | nn.init.constant_(values_linear.bias, 0)
144 | layers.append(values_linear)
145 |
146 | self.coop_value_branch = nn.Sequential(*layers)
147 | self.greedy_value_branch = copy.deepcopy(self.coop_value_branch)
148 | summary(self.coop_value_branch, device="cpu", input_size=(self.cnn_compression + self.cfg['value_cnn_compression'],))
149 |
150 | self._cur_value = None
151 |
152 | self.freeze_coop_value(self.cfg['freeze_coop_value'])
153 | self.freeze_greedy_value(self.cfg['freeze_greedy_value'])
154 | self.freeze_coop(self.cfg['freeze_coop'])
155 | self.freeze_greedy(self.cfg['freeze_greedy'])
156 |
157 | def freeze_coop(self, freeze):
158 | all_params = \
159 | list(self.coop_convs.parameters()) + \
160 | [self.GFL[0].weight1] + \
161 | list(self.coop_logits.parameters())
162 |
163 | for param in all_params:
164 | param.requires_grad = not freeze
165 |
166 | def freeze_greedy(self, freeze):
167 | all_params = \
168 | list(self.greedy_logits.parameters()) + \
169 | list(self.greedy_convs.parameters()) + \
170 | [self.GFL[0].weight0]
171 |
172 | for param in all_params:
173 | param.requires_grad = not freeze
174 |
175 | def freeze_greedy_value(self, freeze):
176 | all_params = \
177 | list(self.greedy_value_branch.parameters()) + \
178 | list(self.greedy_value_cnn.parameters()) + \
179 |             list(self.greedy_value_obs_convs.parameters())
180 |
181 | for param in all_params:
182 | param.requires_grad = not freeze
183 |
184 | def freeze_coop_value(self, freeze):
185 | all_params = \
186 | list(self.coop_value_cnn.parameters()) + \
187 | list(self.coop_value_branch.parameters()) + \
188 |             list(self.coop_value_obs_convs.parameters())
189 |
190 | for param in all_params:
191 | param.requires_grad = not freeze
192 |
193 | @override(ModelV2)
194 | def forward(self, input_dict, state, seq_lens):
195 | batch_size = input_dict["obs"]['gso'].shape[0]
196 | o_as = input_dict["obs"]['agents']
197 |
198 | gso = input_dict["obs"]['gso'].unsqueeze(1)
199 | device = gso.device
200 |
201 | for i in range(len(self.GFL)//2):
202 | self.GFL[i*2].addGSO(gso)
203 |
204 | greedy_cnn = self.greedy_convs(o_as[0]['map'].permute(0, 3, 1, 2))
205 | coop_agents_cnn = {id_agent: self.coop_convs(o_as[id_agent]['map'].permute(0, 3, 1, 2)) for id_agent in range(1, len(o_as))}
206 |
207 | greedy_value_obs_cnn = self.greedy_value_obs_convs(o_as[0]['map'].permute(0, 3, 1, 2))
208 | coop_value_obs_cnn = {id_agent: self.coop_value_obs_convs(o_as[id_agent]['map'].permute(0, 3, 1, 2)) for id_agent in range(1, len(o_as))}
209 |
210 | extract_feature_map = torch.zeros(batch_size, self.graph_features, self.n_agents).to(device)
211 | extract_feature_map[:, :self.cnn_compression, 0] = greedy_cnn
212 | for id_agent in range(1, len(o_as)):
213 | extract_feature_map[:, :self.cnn_compression, id_agent] = coop_agents_cnn[id_agent]
214 |
215 | shared_feature = self.GFL(extract_feature_map)
216 |
217 | logits = torch.empty(batch_size, self.n_agents, 5).to(device)
218 | values = torch.empty(batch_size, self.n_agents).to(device)
219 |
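        # Editor's reading: the 5 logits per agent match the environment's discrete
        # action space; they are filled per agent below and finally flattened to
        # (batch, n_agents * 5), presumably so that the homogeneous multi-action
        # distribution in trainers/hom_multi_action_dist.py can split them back
        # into per-agent categorical heads.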
220 | logits_inp = shared_feature[..., 0]
221 | if self.cfg['cnn_residual']:
222 | logits_inp = torch.cat([logits_inp, greedy_cnn], dim=1)
223 | logits[:, 0] = self.greedy_logits(logits_inp)
224 | if self.cfg['forward_values']:
225 | greedy_value_cnn = self.greedy_value_cnn(input_dict["obs"]["state"].permute(0, 3, 1, 2))
226 | coop_value_cnn = self.coop_value_cnn(input_dict["obs"]["state"].permute(0, 3, 1, 2))
227 |
228 | values[:, 0] = self.greedy_value_branch(torch.cat([greedy_value_obs_cnn, greedy_value_cnn], dim=1)).squeeze(1)
229 |
230 | for id_agent in range(1, len(o_as)):
231 | this_entity = shared_feature[..., id_agent]
232 | if self.cfg['cnn_residual']:
233 | this_entity = torch.cat([this_entity, coop_agents_cnn[id_agent]], dim=1)
234 | logits[:, id_agent] = self.coop_logits(this_entity)
235 |
236 | if self.cfg['forward_values']:
237 | value_cat = torch.cat([coop_value_cnn, coop_value_obs_cnn[id_agent]], dim=1)
238 | values[:, id_agent] = self.coop_value_branch(value_cat).squeeze(1)
239 |
240 | self._cur_value = values
241 | return logits.view(batch_size, self.n_agents*5), state
242 |
243 | @override(ModelV2)
244 | def value_function(self):
245 | assert self._cur_value is not None, "must call forward() first"
246 | return self._cur_value
247 |
248 |
--------------------------------------------------------------------------------
/adversarial_comms/models/gnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proroklab/adversarial_comms/889a41e239958bd519365c6b0c143469cd27e49f/adversarial_comms/models/gnn/__init__.py
--------------------------------------------------------------------------------
/adversarial_comms/models/gnn/adversarialGraphML.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .graphML import GraphFilterBatchGSO
7 |
8 | def batchLSIGFA(h0, h1, N0, SK, x, bias=None, aggregation=lambda y, dim: torch.sum(y, dim=dim)):
9 | """
10 |     batchLSIGFA(h0, h1, N0, SK, x, bias=None, aggregation=...) computes the output
11 |     of a linear shift-invariant graph filter on the input and then adds the bias.
12 |
13 | In this case, we consider that there is a separate GSO to be used for each
14 | of the signals in the batch. In other words, SK[b] is applied when filtering
15 | x[b] as opposed to applying the same SK to all the graph signals in the
16 | batch.
17 |
18 | Inputs:
19 |         h0, h1: filter taps (h0 for the first N0 nodes, h1 for the rest); size:
20 |             output_features x edge_features x filter_taps x input_features
21 |         N0: number of leading nodes to which h0 is applied
22 |         SK: collection of GSO powers; size:
23 |             batch_size x edge_features x filter_taps x number_nodes x number_nodes
24 |         x: input signal; size: batch_size x input_features x number_nodes
25 | bias: size: output_features x number_nodes
26 | if the same bias is to be applied to all nodes, set number_nodes = 1
27 | so that b_{f} vector becomes b_{f} \mathbf{1}_{N}
28 |
29 | Outputs:
30 | output: filtered signals; size:
31 | batch_size x output_features x number_nodes
32 | """
33 | # Get the parameter numbers:
34 | assert h0.shape == h1.shape
35 | F = h0.shape[0]
36 | E = h0.shape[1]
37 | K = h0.shape[2]
38 | G = h0.shape[3]
39 | B = SK.shape[0]
40 | assert SK.shape[1] == E
41 | assert SK.shape[2] == K
42 | N = SK.shape[3]
43 | assert SK.shape[4] == N
44 | assert x.shape[0] == B
45 | assert x.shape[1] == G
46 | assert x.shape[2] == N
47 | # Or, in the notation I've been using:
48 | # h in F x E x K x G
49 | # SK in B x E x K x N x N
50 | # x in B x G x N
51 | # b in F x N
52 | # y in B x F x N
53 | SK = SK.permute(1, 2, 0, 3, 4)
54 | # Now, SK is of shape E x K x B x N x N so that we can multiply by x of
55 | # size B x G x N to get
56 | z = torch.matmul(x, SK)
57 | # which is of size E x K x B x G x N
58 | # Now, we have already carried out the multiplication across the dimension
59 | # of the nodes. Now we need to focus on the K, F, G.
60 | # Let's start by putting B and N in the front
61 | z = z.permute(1, 2, 4, 0, 3).reshape([K, B, N, E * G])
62 |     # so that we get z in K x B x N x EG.
63 |     # Now adjust the filter taps so they are of the form K x GE x F
64 | h0 = h0.permute(2, 1, 3, 0).reshape([K, G * E, F])
65 | h1 = h1.permute(2, 1, 3, 0).reshape([K, G * E, F])
66 | #h1 = h1.reshape([F, G * E * K]).permute(1, 0)
67 | # Multiply
68 | if N0 == 0:
69 |         y = torch.empty(K, B, N, F).to(z.device)  # each tap produces F output features
70 | for k in range(K):
71 | y[k] = torch.matmul(z[k], h1[k])
72 | y = aggregation(y, 0)
73 | # to get a result of size B x N x F. And permute
74 | y = y.permute(0, 2, 1)
75 | else:
76 | z0 = z[:, :, :N0]
77 | z1 = z[:, :, N0:]
78 |         y0 = torch.empty(K, B, N0, F).to(z.device)
79 |         y1 = torch.empty(K, B, N-N0, F).to(z.device)
80 | for k in range(K):
81 | y0[k] = torch.matmul(z0[k], h0[k])
82 | y1[k] = torch.matmul(z1[k], h1[k])
83 | y0 = aggregation(y0, 0)
84 | y1 = aggregation(y1, 0)
85 | # to get a result of size B x N x F. And permute
86 | y0 = y0.permute(0, 2, 1)
87 | y1 = y1.permute(0, 2, 1)
88 | y = torch.cat([y0, y1], dim = 2) # concat along N
89 | # to get it back in the right order: B x F x N.
90 | # Now, in this case, each element x[b,:,:] has adequately been filtered by
91 | # the GSO S[b,:,:,:]
92 | if bias is not None:
93 | y = y + bias
94 | return y
95 |
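# In equation form (an editor's sketch with E = 1 and the default sum
# aggregation), the adversarial graph filter computed by batchLSIGFA above and
# wrapped by the layer below is, assuming SK[b, 0, k] holds the k-th power of the
# per-sample GSO,
#
#   y[b, :, n] = sum_{k=0}^{K-1} H_k @ (x S^k)[b, :, n]  (+ bias)
#
# where H_k is the (F x G) slice h0[:, 0, k, :] if n < N0 (the adversarial split)
# and h1[:, 0, k, :] otherwise, i.e. the first N0 nodes use a separate set of
# filter taps from the remaining nodes.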
96 | class GraphFilterBatchGSOA(GraphFilterBatchGSO):
97 | def __init__(self, G, F, K, N0, E = 1, bias = True, aggregation='sum'):
98 | super().__init__(G, F, K, E, bias)
99 | self.weight0 = self.weight
100 | self.weight1 = nn.parameter.Parameter(torch.Tensor(self.F, self.E, self.K, self.G))
101 | self.N0 = N0
102 | self.reset_parameters()
103 | self.aggregation = {
104 | "sum": lambda y, dim: torch.sum(y, dim=dim),
105 | "median": lambda y, dim: torch.median(y, dim=dim)[0],
106 | "min": lambda y, dim: torch.min(y, dim=dim)[0]
107 | }[aggregation]
108 |
109 | def reset_parameters(self):
110 | super().reset_parameters()
111 | if hasattr(self, 'weight1'):
112 | stdv = 1. / math.sqrt(self.G * self.K)
113 | self.weight1.data.uniform_(-stdv, stdv)
114 |
115 | def forward(self, x):
116 | return self.forward_gpvae(x) if self.K == 2 else batchLSIGFA(self.weight0, self.weight1, self.N0, self.SK, x, self.bias, aggregation=self.aggregation)
117 |
118 | def forward_gpvae(self, x):
119 | # K=1
120 | hx_0_0 = torch.matmul(self.weight0[:, 0, 0, :], x[:, :, :self.N0])
121 | hx_0_1 = torch.matmul(self.weight1[:, 0, 0, :], x[:, :, self.N0:])
122 | hx_0 = torch.cat([hx_0_0, hx_0_1], dim=2)
123 |
124 | # K=2
125 | neighbors = self.aggregation(x[:, :, :, None] * self.S, dim=2)
126 | hx_1_0 = torch.matmul(self.weight0[:, 0, 1, :], neighbors[:, :, :self.N0])
127 | hx_1_1 = torch.matmul(self.weight1[:, 0, 1, :], neighbors[:, :, self.N0:])
128 | hx_1 = torch.cat([hx_1_0, hx_1_1], dim=2)
129 |
130 | output = hx_0 + hx_1
131 | return output
132 |
133 | def forward_naive(self, x):
134 | bs, features, n_agents = x.shape
135 | output = torch.zeros(bs, features, n_agents)
136 | for b in range(bs):
137 | sxas = torch.zeros(self.K, features, n_agents)
138 | sk = torch.eye(n_agents).expand(n_agents, n_agents)
139 | for k in range(self.K):
140 | sx = torch.matmul(x[b], sk)
141 | h0 = self.weight0[:, 0, k, :]
142 | h1 = self.weight1[:, 0, k, :]
143 | if self.N0 == 0:
144 | sxas[k] = torch.matmul(h1, sx)
145 | else:
146 | sxa0 = torch.matmul(h0, sx[:, :self.N0])
147 | sxa1 = torch.matmul(h1, sx[:, self.N0:])
148 | sxas[k] = torch.cat([sxa0, sxa1], dim=1) # concat along N
149 | sk = torch.matmul(self.S[b, 0], sk)
150 |
151 | output[b] = self.aggregation(sxas, 0)
152 | return output
153 |
--------------------------------------------------------------------------------
/adversarial_comms/train_interpreter.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import torch
3 |
4 | from ray.rllib.models.modelv2 import ModelV2
5 | from ray.rllib.utils.annotations import override
6 | from ray.rllib.utils import try_import_torch
7 |
8 | from .models.gnn import graphML as gml
9 | from .models.gnn import graphTools
10 | import numpy as np
11 |
12 | torch, nn = try_import_torch()
13 | from torch.utils.data import Dataset
14 |
15 | from torch.optim import SGD, Adam
16 | import torch.nn.functional as F
17 |
18 | from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
19 | from ignite.metrics import Precision, Recall, Fbeta, Loss, RunningAverage
20 | #from ignite.contrib.metrics import ROC_AUC, AveragePrecision
21 | from ignite.handlers import ModelCheckpoint, global_step_from_engine, EarlyStopping, TerminateOnNan
22 | from ignite.contrib.handlers import ProgressBar
23 | from ignite.contrib.handlers.tensorboard_logger import *
24 |
25 | from torchsummary import summary
26 |
27 | from collections import OrderedDict
28 | import matplotlib.pyplot as plt
29 | import matplotlib.patches as patches
30 | from matplotlib import colors
31 | import json
32 | import random
33 | from pathlib import Path
34 | import time
35 | import os
36 | import copy
37 |
38 | # https://ray.readthedocs.io/en/latest/using-ray-with-pytorch.html
39 |
40 | X = 1
41 | Y = 0
42 |
43 | def get_transpose_cnn(inp_features, out_shape, out_classes):
44 | return [
45 | nn.ConvTranspose2d(in_channels=inp_features, out_channels=64, kernel_size=3, stride=1),
46 | nn.LeakyReLU(inplace=True),
47 | nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2),
48 | nn.LeakyReLU(inplace=True),
49 | nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=4, stride=2),
50 | nn.LeakyReLU(inplace=True),
51 | nn.ZeroPad2d([1,1,1,1]),
52 | nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=4, stride=2),
53 | nn.LeakyReLU(inplace=True),
54 | nn.ZeroPad2d([1,1,1,1]),
55 | nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3, stride=1),
56 | nn.LeakyReLU(inplace=True),
57 | nn.Conv2d(8, out_classes, 3, 1),
58 | nn.Sigmoid(),
59 | ]
60 |
61 | def get_upsampling_cnn(inp_features, out_shape, out_classes):
62 | if out_shape == 7:
63 | return [
64 | nn.ZeroPad2d([2]*4),
65 | nn.Conv2d(in_channels=inp_features, out_channels=16, kernel_size=3),
66 | nn.LeakyReLU(inplace=True),
67 | nn.Upsample(scale_factor=2),
68 | nn.ZeroPad2d([1]*4),
69 | nn.Conv2d(in_channels=16, out_channels=8, kernel_size=4),
70 | nn.LeakyReLU(inplace=True),
71 | nn.Upsample(scale_factor=2),
72 | nn.Conv2d(in_channels=8, out_channels=out_classes, kernel_size=4),
73 | nn.Sigmoid(),
74 | ]
75 | elif out_shape == 12:
76 | return [
77 | nn.ZeroPad2d([2]*4),
78 | nn.Conv2d(in_channels=inp_features, out_channels=16, kernel_size=3),
79 | nn.LeakyReLU(inplace=True),
80 | nn.Upsample(scale_factor=2),
81 | nn.ZeroPad2d([1]*4),
82 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=4),
83 | nn.LeakyReLU(inplace=True),
84 | nn.Upsample(scale_factor=2),
85 | nn.Conv2d(in_channels=16, out_channels=8, kernel_size=4),
86 | nn.LeakyReLU(inplace=True),
87 | nn.Upsample(scale_factor=2),
88 | nn.Conv2d(in_channels=8, out_channels=out_classes, kernel_size=3),
89 | nn.Sigmoid(),
90 | ]
91 | elif out_shape == 24:
92 | return [
93 | nn.ZeroPad2d([2]*4),
94 | nn.Conv2d(in_channels=inp_features, out_channels=32, kernel_size=3),
95 | nn.LeakyReLU(inplace=True),
96 | nn.Upsample(scale_factor=2),
97 | nn.ZeroPad2d([1]*4),
98 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3),
99 | nn.LeakyReLU(inplace=True),
100 | nn.Upsample(scale_factor=2),
101 | nn.ZeroPad2d([1]*4),
102 | nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3),
103 | nn.LeakyReLU(inplace=True),
104 | nn.Upsample(scale_factor=2),
105 | nn.ZeroPad2d([1]*4),
106 | nn.Conv2d(in_channels=8, out_channels=out_classes, kernel_size=3),
107 | nn.Sigmoid(),
108 | ]
109 | elif out_shape == 48:
110 | return [
111 | nn.ZeroPad2d([2]*4),
112 | nn.Conv2d(in_channels=inp_features, out_channels=64, kernel_size=3),
113 | nn.LeakyReLU(inplace=True),
114 | nn.Upsample(scale_factor=2),
115 | nn.ZeroPad2d([1]*4),
116 | nn.Conv2d(in_channels=64, out_channels=48, kernel_size=3),
117 | nn.LeakyReLU(inplace=True),
118 | nn.Upsample(scale_factor=2),
119 | nn.ZeroPad2d([1]*4),
120 | nn.Conv2d(in_channels=48, out_channels=32, kernel_size=3),
121 | nn.LeakyReLU(inplace=True),
122 | nn.Upsample(scale_factor=2),
123 | nn.ZeroPad2d([1]*4),
124 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3),
125 | nn.LeakyReLU(inplace=True),
126 | nn.Upsample(scale_factor=2),
127 | nn.ZeroPad2d([1]*4),
128 | nn.Conv2d(in_channels=16, out_channels=out_classes, kernel_size=3),
129 | nn.Sigmoid()
130 | ]
131 |     raise ValueError(f"Unsupported output shape: {out_shape}")
132 |
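# Shape sanity check for the decoders above (an illustrative sketch; the channel
# count 128 is arbitrary):
#
#   layers = get_upsampling_cnn(inp_features=128, out_shape=24, out_classes=2)
#   net = nn.Sequential(*layers)
#   out = net(torch.zeros(1, 128, 1, 1))  # feature vector treated as a 1x1 image
#   assert out.shape == (1, 2, 24, 24)
#
# i.e. each decoder maps a per-agent feature vector (CNN and/or GNN output) to a
# class-probability map of the requested spatial size via repeated upsampling.
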
133 | class Model(nn.Module):
134 | def __init__(self, dataset, config):
135 | nn.Module.__init__(self)
136 |
137 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
138 | self.config = config
139 | self.inp_features = {
140 | 'gnn': dataset.gnn_features,
141 | 'cnn': dataset.cnn_features,
142 | 'gnn_cnn': dataset.cnn_features+dataset.gnn_features
143 | }[self.config['nn_mode']]
144 |
145 | if self.config['format']=='relative':
146 | if self.config['pred_mode'] == 'global':
147 | self.out_size = dataset.world_shape[0]*2
148 | elif self.config['pred_mode'] == 'local':
149 | self.out_size = dataset.obs_size
150 | else:
151 | raise NotImplementedError
152 | elif self.config['format']=='absolute':
153 | self.out_size = dataset.world_shape[0]
154 | else:
155 | raise NotImplementedError
156 |
157 | if self.config['type'] == 'cov':
158 | self.classes = 1 if self.config['prediction']=='cov_only' else 2
159 | elif self.config['type'] == 'path':
160 | self.classes = 3 if self.config['prediction']=='all' else 1
161 | else:
162 | raise NotImplementedError("Invalid type")
163 |
164 | layers = get_upsampling_cnn(self.inp_features, self.out_size, self.classes)
165 | cnn = nn.Sequential(*layers)
166 | #summary(cnn, device="cpu", input_size=(self.inp_features, 1, 1))
167 | self._post_cnn = cnn.to(self.device)
168 |
169 | @override(ModelV2)
170 | def forward(self, input_dict):
171 | agent_observations = input_dict["obs"]['agents']
172 | batch_size = input_dict["obs"]['gso'].shape[0]
173 |
174 | prediction = torch.empty(batch_size, len(agent_observations), self.classes, self.out_size, self.out_size).to(
175 | self.device)
176 | for this_id, this_state in enumerate(agent_observations):
177 | gnn_out = this_state['gnn_out']
178 | cnn_out = this_state['cnn_out']
179 | this_entity = {
180 | 'gnn': gnn_out,
181 | 'cnn': cnn_out,
182 | 'gnn_cnn': torch.cat([gnn_out, cnn_out], dim=1)
183 | }[self.config['nn_mode']]
184 | prediction[:, this_id] = self._post_cnn(this_entity.view(batch_size, self.inp_features, 1, 1))
185 |
186 | return prediction.double()
187 |
188 | class BaseDataset(Dataset):
189 | def __init__(self, path):
190 | try:
191 | with open(Path(path), "rb") as f:
192 | self.data = pickle.load(f)
193 | assert (len(self.data) > 0)
194 | except TypeError:
195 | self.data = [{'obs': path}]
196 | self.world_shape = self.data[0]['obs']['state'].shape[:2]
197 | self.obs_size = self.data[0]['obs']['agents'][0]['map'].shape[0]
198 | self.cnn_features = self.data[0]['obs']['agents'][-1]['cnn_out'].shape[0]
199 | self.gnn_features = self.data[0]['obs']['agents'][-1]['gnn_out'].shape[0]
200 |
201 | def __len__(self):
202 | return len(self.data)
203 |
204 | def get_coverable_area(self, idx):
205 | coverable_area = ~(self.data[idx]['obs']['state'][...,0] > 0)
206 | return np.sum(coverable_area)
207 |
208 | def get_coverage_fraction(self, idx):
209 | coverable_area = ~(self.data[idx]['obs']['state'][...,0] > 0)
210 | covered_area = self.data[idx]['obs']['state'][...,1] & coverable_area
211 | return np.sum(covered_area) / np.sum(coverable_area)
212 |
213 | def to_agent_coord_frame(self, m, state_size, pose, fill=0):
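        # Illustration (hypothetical numbers): with a 24x24 world, state_size=24
        # and pose=(3, 5), the map is padded by 12 cells of 'fill' on every side
        # and a 24x24 window starting at the pose is cut out, which places the
        # agent's own cell at index (12, 12), the centre of the agent-relative view.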
214 |         half_state_shape = np.array([state_size / 2] * 2, dtype=int)
215 | padded = np.pad(m, ([half_state_shape[Y]] * 2, [half_state_shape[X]] * 2), mode='constant',
216 | constant_values=fill)
217 | return padded[pose[Y]:pose[Y] + state_size, pose[X]:pose[X] + state_size]
218 |
219 | class CoverageDataset(BaseDataset):
220 | def __init__(self, path, config):
221 | # is_relative: Agent relative or world absolute prediction
222 | # cov_only: Predict only coverage or predict both coverage and map
223 | # is_global: Predict local coverage and map or global coverage and map
224 |
225 | super().__init__(path)
226 | self.is_relative=config['format']=='relative'
227 | self.cov_only=config['prediction']=='cov_only'
228 | self.is_global=config['pred_mode']=='global'
229 | self.skip_agents=config['skip_agents'] if 'skip_agents' in config else 0
230 | self.stop_agents=config['stop_agents'] if 'stop_agents' in config else None
231 |
232 | def __getitem__(self, idx):
233 | if torch.is_tensor(idx):
234 | idx = idx.tolist()
235 |
236 | y = []
237 | weights = []
238 | for agent_obs in self.data[idx]['obs']['agents'][self.skip_agents:self.stop_agents]:
239 | if self.is_global:
240 | obs_cov = self.data[idx]['obs']['state'][...,1]
241 | obs_map = self.data[idx]['obs']['state'][...,0]
242 | if self.is_relative:
243 | obs_cov = self.to_agent_coord_frame(obs_cov, self.obs_size, agent_obs['pos'], fill=0)
244 | obs_map = self.to_agent_coord_frame(obs_map, self.obs_size, agent_obs['pos'], fill=1)
245 | else:
246 | if self.is_relative:
247 | # directly use agent's relative view coverage
248 | obs_cov = agent_obs['map'][..., 1]
249 | obs_map = agent_obs['map'][..., 0]
250 | else:
251 | # shift the agent's local coverage to an absolute view
252 | m = np.roll(agent_obs['map'], agent_obs['pos'], axis=(0,1))[int(self.obs_size/2):,int(self.obs_size/2):]
253 | obs_cov = m[...,1]
254 | obs_map = m[...,0]
255 |
256 | if self.cov_only:
257 | # only predict local coverage and use world map as mask
258 | y.append([obs_cov])
259 |                     weights.append([(~obs_map.astype(bool)).astype(int)])
260 | else:
261 |                     # predict both local coverage and world map, but mask out everything outside the world shifted to the agent's position
262 | d = np.stack([obs_cov, obs_map], axis=0)
263 | y.append(d)
264 | weight = np.ones(obs_cov.shape)
265 | if self.is_relative:
266 | weight = self.to_agent_coord_frame(weight, self.obs_size, agent_obs['pos'], fill=0)
267 | weight = np.stack([weight]*2, axis=0)
268 | #print(d.shape, weight.shape)
269 | weights.append(weight)
270 |
271 | y = np.array(y, dtype=np.double)
272 | w = np.array(weights, dtype=np.double)
273 | obs = self.data[idx]['obs']
274 | obs['agents'] = obs['agents'][self.skip_agents:self.stop_agents]
275 | return {'obs': self.data[idx]['obs']}, {'y': y, 'w': w}
276 |
277 | class PathplanningDataset(BaseDataset):
278 | def __init__(self, path, config):
279 |         # pred_mode: Predict only the goal ('goal') or map, goal and position ('all')
280 |         # is_global: Predict from the agent's local view or from the global state
281 |
282 | super().__init__(path)
283 | self.is_relative=config['format']=='relative'
284 | self.pred_mode=config['prediction']
285 | self.is_global=config['pred_mode']=='global'
286 | self.skip_agents=config['skip_agents'] if 'skip_agents' in config else 0
287 | self.stop_agents=config['stop_agents'] if 'stop_agents' in config else None
288 |
289 | def __getitem__(self, idx):
290 | if torch.is_tensor(idx):
291 | idx = idx.tolist()
292 |
293 | y = []
294 | weights = []
295 | for agent_obs in self.data[idx]['obs']['agents'][self.skip_agents:self.stop_agents]:
296 | if self.is_global:
297 |                 obs_map = np.zeros(self.data[idx]['obs']['state'].shape[:2], dtype=float)
298 |                 obs_pos = np.zeros(self.data[idx]['obs']['state'].shape[:2], dtype=float)
299 |                 obs_goal = np.zeros(self.data[idx]['obs']['state'].shape[:2], dtype=float)
300 | obs_map[self.data[idx]['obs']['state'][..., 0] == 1] = 1
301 | for i in range(self.data[idx]['obs']['state'].shape[-1]):
302 | obs_pos[self.data[idx]['obs']['state'][..., i] == 2] = 1
303 | obs_goal[self.data[idx]['obs']['state'][..., i] == 3] = 1
304 |
305 | if self.is_relative:
306 | obs_goal = self.to_agent_coord_frame(obs_goal, self.world_shape[0]*2, agent_obs['pos'], fill=0)
307 | obs_pos = self.to_agent_coord_frame(obs_pos, self.world_shape[0]*2, agent_obs['pos'], fill=0)
308 | obs_map = self.to_agent_coord_frame(obs_map, self.world_shape[0]*2, agent_obs['pos'], fill=1)
309 | else:
310 | # directly use agent's relative view coverage
311 | obs_map = agent_obs['map'][..., 0]
312 | obs_goal = agent_obs['map'][..., 1]
313 | obs_pos = agent_obs['map'][..., 2]
314 |
315 | if self.pred_mode == "goal":
316 |                 # only predict the goal and use the world map as mask
317 | y.append(np.stack([obs_goal], axis=0))
318 |                 weight = (~obs_map.astype(bool)).astype(int)
319 | # goal can generally be on the margin if it is projected!
320 | for row in [0, -1]:
321 | weight[row] = 1
322 | weight[:, row] = 1
323 | weights.append(copy.deepcopy([weight]))
324 | elif self.pred_mode == "all":
325 |                     # predict map, goal and agent position together, with uniform (all-ones) weights
326 | d = np.stack([obs_map, obs_goal, obs_pos], axis=0)
327 | y.append(d)
328 | weight = np.ones(obs_map.shape)
329 | weight = np.stack([weight]*3, axis=0)
330 | weights.append(weight)
331 |
332 | y = np.array(y, dtype=np.double)
333 | w = np.array(weights, dtype=np.double)
334 | obs = self.data[idx]['obs']
335 | obs['agents'] = obs['agents'][self.skip_agents:self.stop_agents]
336 | return {'obs': self.data[idx]['obs']}, {'y': y, 'w': w}
337 |
338 | dataset_classes = {
339 | "path": PathplanningDataset,
340 | "cov": CoverageDataset
341 | }
342 |
343 | def inference(model_checkpoint_path,
344 | data_path,
345 | seed=None, run_eval=False, save_dirname=None):
346 | if seed is None:
347 | seed = time.time()
348 | torch.manual_seed(seed)
349 | random.seed(seed)
350 | batch_size = 1
351 |
352 | checkpoint_file = Path(model_checkpoint_path)
353 | with open(checkpoint_file.parent / 'config.json', 'r') as config_file:
354 | config = json.load(config_file)
355 | config['skip_agents'] = 0
356 | dataset = dataset_classes[config['type']](data_path, config)
357 | loader = torch.utils.data.DataLoader(
358 | dataset,
359 | batch_size=batch_size,
360 | shuffle=True,
361 | num_workers=1
362 | )
363 | model = load_model(checkpoint_file, Model(dataset, config))
364 |
365 | cmap_map = colors.LinearSegmentedColormap.from_list("cmap_map", [(0, 0, 0, 0), (0, 0, 0, 1)])
366 | cmap_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), (0, 1, 0, 1)])
367 | cmap_own_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), (1, 1, 0, 1)])
368 |
369 | def transform_rel_abs(data, pos):
370 | return np.roll(data, pos, axis=(0, 1))[int(dataset.obs_size / 2):, int(dataset.obs_size / 2):]
371 |
372 | if run_eval:
373 | evaluator = create_evaluator(model)
374 | evaluator.run(loader)
375 | metrics = evaluator.state.metrics
376 | print(metrics)
377 | identifier = f"{config['format']}-{config['nn_mode']}-{config['pred_mode']}"
378 | batch_index = 0
379 | for x, y_true in loader:
380 | fig, axs = plt.subplots(batch_size*2, 5, figsize=[8, 3.2])
381 | if run_eval:
382 | #axs[0][0].set_title(f"ap: {metrics['ap']:.4f}, auc: {metrics['auc']:.4f}")
383 | fig.suptitle(f"{identifier} f1: {metrics['f1']:.4f}, ap: {metrics['ap']:.4f}, auc: {metrics['auc']:.4f}")
384 |
385 | y_pred = model(x).detach().numpy()
386 | #y_pred = torch.round(model(x).detach()).numpy()
387 | for i in range(batch_size):
388 | for j in range(5):
389 | #for agent in x['obs']['agents']:
390 | # print(agent['pos'])
391 |
392 | #axs[i*2][j].imshow(y_true['y'][i][j][1, :, :], cmap=cmap_cov)
393 | agent_obs = x['obs']['agents'][j]
394 | pos = agent_obs['pos'][i]
395 | agent_map = agent_obs['map'][i]
396 |
397 | if config['format'] == 'relative':
398 | #axs[i*2][j].imshow(transform_rel_abs(agent_map[...,0], pos), cmap=cmap_map) # obstacles
399 | #axs[i*2][j].imshow(transform_rel_abs(y_true['y'][i][j][0, :, :], pos), cmap=cmap_cov)
400 | #axs[i*2][j].imshow(transform_rel_abs(agent_map[...,1], pos), cmap=cmap_own_cov)
401 |
402 | axs[i*2][j].imshow(agent_map[...,0], cmap=cmap_map) # obstacles
403 | #axs[i*2][j].imshow(y_true['y'][i][j][1, :, :], cmap=cmap_map)
404 | axs[i*2][j].imshow(y_true['y'][i][j][0, :, :], cmap=cmap_cov)
405 |
406 | #print(y_pred[i][j][1, :, :])
407 | #axs[i*2+1][j].imshow(y_pred[i][j][1, :, :], cmap=cmap_map)
408 | #axs[i*2+1][j].imshow(y_pred[i][j][0, :, :], cmap=cmap_cov)
409 |
410 | #axs[i*2][j].imshow(agent_map[...,1], cmap=cmap_own_cov)
411 |
412 | axs[i*2+1][j].imshow(y_pred[i][j][0, :, :], cmap=cmap_cov)
413 | #axs[i*2+1][j].imshow(y_pred[i][j][1, :, :], cmap=cmap_map)
414 | axs[i * 2+1][j].imshow(agent_map[..., 0], cmap=cmap_map) # obstacles
415 | #axs[i*2+1][j].imshow(y_true['w'][i][j][0, :, :], cmap=cmap_map) # weighting
416 | #axs[i*2+1][j].imshow(transform_rel_abs(y_pred[i][j][0, :, :], pos), cmap=cmap_own_cov if config['pred_mode'] == 'local' else cmap_cov)
417 |
418 | #map_data = transform_rel_abs(agent_obs['map'][i][...,0], pos) if len(y_pred[i][j]) == 1 else y_pred[i][j][1, :, :]
419 | #axs[i*2+1][j].imshow(map_data, cmap=cmap_map) # obstacles
420 |
421 | else:
422 | axs[i*2][j].imshow(x['obs']['state'][i][...,0], cmap=cmap_map) # obstacles
423 | axs[i*2][j].imshow(x['obs']['state'][i][...,1], cmap=cmap_cov)
424 | m = np.roll(agent_obs['map'][i], agent_obs['pos'][i], axis=(0,1))[int(dataset.obs_size/2):,int(dataset.obs_size/2):]
425 | axs[i*2][j].imshow(m[...,1], cmap=cmap_own_cov)
426 |
427 | axs[i*2+1][j].imshow(y_pred[i][j][0, :, :], cmap=cmap_own_cov if config['pred_mode'] == 'local' else cmap_cov)
428 | axs[i*2+1][j].imshow(x['obs']['state'][i][...,0], cmap=cmap_map) # obstacles
429 |
430 | '''
431 | for k in range(2):
432 | rect = patches.Rectangle((agent_obs['pos'][i][1] - 1 / 2, agent_obs['pos'][i][0] - 1 / 2), 1, 1,
433 | linewidth=1, edgecolor='r', facecolor='none')
434 | axs[i*2+k][j].add_patch(rect)
435 | '''
436 | for k in range(2):
437 | axs[i * 2 + k][j].set_xticks([])
438 | axs[i * 2 + k][j].set_yticks([])
439 |
440 | fig.tight_layout() #rect=[0, 0.03, 1, 0.95])
441 | if save_dirname is not None:
442 | img_path = checkpoint_file.parent/save_dirname
443 | img_path.mkdir(exist_ok=True)
444 | frame_path = img_path/f"{batch_index:05d}.png"
445 | print("Frame", frame_path)
446 | plt.savefig(frame_path, dpi=300)
447 | else:
448 | plt.show()
449 | plt.close()
450 | batch_index += 1
451 | if batch_index == 300:
452 | break
453 |
454 | def inference_gnn_cnn(cnn_model_checkpoint_path,
455 | gnn_model_checkpoint_path,
456 | data_path,
457 | seed=None, run_eval=False, save_dirname=None):
458 | if seed is None:
459 | seed = time.time()
460 | torch.manual_seed(seed)
461 | random.seed(seed)
462 | batch_size = 1
463 |
464 | checkpoint_file = Path(gnn_model_checkpoint_path)
465 | with open(checkpoint_file.parent / 'config.json', 'r') as config_file:
466 | config = json.load(config_file)
467 | config['skip_agents'] = 0
468 | gnn_dataset = dataset_classes[config['type']](data_path, config)
469 | gnn_loader = torch.utils.data.DataLoader(
470 | gnn_dataset,
471 | batch_size=batch_size,
472 | shuffle=False,
473 | num_workers=1
474 | )
475 | gnn_model = load_model(checkpoint_file, Model(gnn_dataset, config))
476 |
477 | checkpoint_file = Path(cnn_model_checkpoint_path)
478 | with open(checkpoint_file.parent / 'config.json', 'r') as config_file:
479 | config = json.load(config_file)
480 | config['skip_agents'] = 0
481 | cnn_dataset = dataset_classes[config['type']](data_path, config)
482 | cnn_loader = torch.utils.data.DataLoader(
483 | cnn_dataset,
484 | batch_size=batch_size,
485 | shuffle=False,
486 | num_workers=1
487 | )
488 | cnn_model = load_model(checkpoint_file, Model(cnn_dataset, config))
489 |
490 |
491 | cmap_map = colors.LinearSegmentedColormap.from_list("cmap_map", [(0, 0, 0, 0), (0, 0, 0, 1)])
492 | cmap_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), (0, 1, 0, 1)])
493 | cmap_own_cov = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), (0, 0, 1, 1)])
494 |
495 | def transform_rel_abs(data, pos):
496 | return np.roll(data, pos, axis=(0, 1))[int(gnn_dataset.obs_size / 2):, int(gnn_dataset.obs_size / 2):]
497 |
498 | identifier = f"{config['format']}-{config['nn_mode']}-{config['pred_mode']}"
499 | batch_index = 0
500 | for (x, y_true), (x_cnn, y_cnn_true) in zip(gnn_loader, cnn_loader):
501 | fig, axs = plt.subplots(batch_size*2, 5, figsize=[8, 3.2])
502 |
503 | y_pred = gnn_model(x).detach().numpy()
504 | y_pred_cnn = cnn_model(x_cnn).detach().numpy()
505 | for i in range(batch_size):
506 | for j in range(5):
507 | if j == 0:
508 | agent_obs = x['obs']['agents'][j]
509 | pos = agent_obs['pos'][i]
510 | agent_map = agent_obs['map'][i]
511 | else:
512 | agent_obs = x_cnn['obs']['agents'][j]
513 | pos = agent_obs['pos'][i]
514 | agent_map = agent_obs['map'][i]
515 |
516 | axs[i*2][j].imshow(transform_rel_abs(agent_map[...,0], pos), cmap=cmap_map) # obstacles
517 | axs[i*2][j].imshow(transform_rel_abs(y_true['y'][i][j][0, :, :], pos), cmap=cmap_cov)
518 | axs[i*2][j].imshow(transform_rel_abs(agent_map[...,1], pos), cmap=cmap_own_cov)
519 |
520 | if j == 0:
521 | axs[i*2+1][j].imshow(transform_rel_abs(y_pred_cnn[i][j][0, :, :], pos), cmap=cmap_own_cov)
522 | else:
523 | axs[i*2+1][j].imshow(transform_rel_abs(y_pred[i][j][0, :, :], pos), cmap=cmap_cov)
524 |
525 | map_data = agent_obs['map'][i][...,0] if len(y_pred[i][j]) == 1 else y_pred[i][j][1, :, :]
526 | axs[i*2+1][j].imshow(transform_rel_abs(map_data, pos), cmap=cmap_map) # obstacles
527 |
528 | for k in range(2):
529 | rect = patches.Rectangle((agent_obs['pos'][i][1] - 1 / 2, agent_obs['pos'][i][0] - 1 / 2), 1, 1,
530 | linewidth=1, edgecolor='r', facecolor='none')
531 | axs[i*2+k][j].add_patch(rect)
532 |
533 | for k in range(2):
534 | axs[i * 2 + k][j].set_xticks([])
535 | axs[i * 2 + k][j].set_yticks([])
536 |
537 | fig.tight_layout() #rect=[0, 0.03, 1, 0.95])
538 | if save_dirname is not None:
539 | img_path = checkpoint_file.parent/save_dirname
540 | img_path.mkdir(exist_ok=True)
541 | frame_path = img_path/f"{batch_index:05d}.png"
542 | print("Frame", frame_path)
543 | plt.savefig(frame_path, dpi=300)
544 | else:
545 | plt.show()
546 | plt.close()
547 | batch_index += 1
548 | if batch_index == 10:
549 | break
550 |
551 | def inference_path(model_checkpoint_path,
552 | data_path,
553 | seed=None, run_eval=False, save_dirname=None):
554 | if seed is None:
555 | seed = time.time()
556 | torch.manual_seed(seed)
557 | random.seed(seed)
558 | batch_size = 1
559 |
560 | checkpoint_file = Path(model_checkpoint_path)
561 | with open(checkpoint_file.parent / 'config.json', 'r') as config_file:
562 | config = json.load(config_file)
563 | config['skip_agents'] = 0
564 | dataset = dataset_classes[config['type']](data_path, config)
565 | loader = torch.utils.data.DataLoader(
566 | dataset,
567 | batch_size=batch_size,
568 | shuffle=True,
569 | num_workers=1
570 | )
571 | model = load_model(checkpoint_file, Model(dataset, config))
572 |
573 | cmap_map = colors.LinearSegmentedColormap.from_list("cmap_map", [(0, 0, 0, 0), (0, 0, 0, 1)])
574 | cmap_pos = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), (0, 1, 0, 1)])
575 | cmap_goal = colors.LinearSegmentedColormap.from_list("cmap_cov", [(0, 0, 0, 0), (1, 1, 0, 1)])
576 |
577 | def transform_rel_abs(data, pos):
578 | return data #np.roll(data, pos, axis=(0, 1))#[int(dataset.obs_size / 2):, int(dataset.obs_size / 2):]
579 |
580 | if run_eval:
581 | evaluator = create_evaluator(model)
582 | evaluator.run(loader)
583 | metrics = evaluator.state.metrics
584 | print(metrics)
585 | identifier = f"{config['format']}-{config['nn_mode']}-{config['pred_mode']}"
586 | batch_index = 0
587 | for x, y_true in loader:
588 | n_agents = len(x['obs']['agents'])
589 | fig, axs = plt.subplots(batch_size*2, 5, figsize=[8, 3.2])
590 | if run_eval:
591 | #axs[0][0].set_title(f"ap: {metrics['ap']:.4f}, auc: {metrics['auc']:.4f}")
592 | fig.suptitle(f"{identifier} f1: {metrics['f1']:.4f}, ap: {metrics['ap']:.4f}, auc: {metrics['auc']:.4f}")
593 |
594 | y_pred = model(x).detach().numpy()
595 | #y_pred = torch.round(model(x).detach()).numpy()
596 | for i in range(batch_size):
597 | for j in range(5):
598 | #for agent in x['obs']['agents']:
599 | # print(agent['pos'])
600 |
601 | #axs[i*2][j].imshow(y_true['y'][i][j][1, :, :], cmap=cmap_cov)
602 | agent_obs = x['obs']['agents'][j]
603 | pos = agent_obs['pos'][i]
604 | agent_map = agent_obs['map'][i]
605 |
606 | axs[i*2][j].imshow(transform_rel_abs(agent_map[...,0], pos), cmap=cmap_map) # obstacles
607 | axs[i*2][j].imshow(transform_rel_abs(agent_map[...,2], pos), cmap=cmap_pos) # obstacles
608 | axs[i*2][j].imshow(transform_rel_abs(agent_map[...,1], pos), cmap=cmap_goal) # obstacles
609 |
610 | axs[i*2+1][j].imshow(transform_rel_abs(agent_map[...,0], pos), cmap=cmap_map) # obstacles
611 | axs[i*2+1][j].imshow(transform_rel_abs(agent_map[...,2], pos), cmap=cmap_pos) # obstacles
612 | axs[i*2+1][j].imshow(y_pred[i][j][0, :, :], cmap=cmap_goal)
613 |
614 | #axs[i*2+1][j].imshow(y_pred[i][j][1, :, :], cmap=cmap_goal)
615 | #axs[i*2+1][j].imshow(y_pred[i][j][2, :, :], cmap=cmap_pos)
616 |
617 | #axs[i*2+1][j].imshow(transform_rel_abs(y_pred[i][j][0, :, :], pos), cmap=cmap_own_cov if config['pred_mode'] == 'local' else cmap_cov)
618 | #axs[i*2+1][j].imshow(transform_rel_abs(agent_obs['map'][i][...,0], pos), cmap=cmap_map) # obstacles
619 |
620 | for k in range(2):
621 | axs[i * 2 + k][j].set_xticks([])
622 | axs[i * 2 + k][j].set_yticks([])
623 |
624 | fig.tight_layout() #rect=[0, 0.03, 1, 0.95])
625 | if save_dirname is not None:
626 | img_path = checkpoint_file.parent/save_dirname
627 | img_path.mkdir(exist_ok=True)
628 | frame_path = img_path/f"{batch_index:05d}.png"
629 | print("Frame", frame_path)
630 | plt.savefig(frame_path, dpi=300)
631 | else:
632 | plt.show()
633 | plt.close()
634 | batch_index += 1
635 | if batch_index == 300:
636 | break
637 |
638 | def thresholded_output_transform(output):
639 | y_pred, y = output
640 | y_pred = torch.round(y_pred)
641 | return y_pred, y
642 |
643 | def apply_weight_output_transform(x):
644 | y_pred_raw, y_raw = x[0], x[1] # shape each (batch size, agent, channel, x, y)
645 |
646 | classes = y_pred_raw.shape[2]
647 |
648 |     w = y_raw['w'].permute([2, 0, 1, 3, 4]).flatten()
649 |     y = y_raw['y'].permute([2, 0, 1, 3, 4]).flatten()[w == 1].reshape(-1, classes)
650 |     y_pred = y_pred_raw.permute([2, 0, 1, 3, 4]).flatten()[w == 1].reshape(-1, classes)
651 |
652 | return (y_pred, y)
653 |
654 | def apply_weight_threshold_output_transform(x):
655 | return apply_weight_output_transform(thresholded_output_transform(x))
656 |
657 | def weighted_binary_cross_entropy(y_pred, y):
658 | return F.binary_cross_entropy(y_pred, y['y'], weight=y['w'])
659 |
660 | from ignite.metrics import EpochMetric
661 |
662 | class AveragePrecision(EpochMetric):
663 | def __init__(self, output_transform=lambda x: x):
664 | def average_precision_compute_fn(y_preds, y_targets):
665 | try:
666 | from sklearn.metrics import average_precision_score
667 | except ImportError:
668 | raise RuntimeError("This contrib module requires sklearn to be installed.")
669 |
670 | y_true = y_targets.numpy()
671 | y_pred = y_preds.numpy()
672 | return average_precision_score(y_true, y_pred, average='micro')
673 |
674 | super(AveragePrecision, self).__init__(average_precision_compute_fn, output_transform=output_transform)
675 |
676 | def create_evaluator(model):
677 | return create_supervised_evaluator(
678 | model,
679 | metrics={
680 | #"p": Precision(apply_weight_threshold_output_transform),
681 | #"r": Recall(apply_weight_threshold_output_transform),
682 | #"f1": Fbeta(1, output_transform=apply_weight_threshold_output_transform),
683 | #"auc": ROC_AUC(output_transform=apply_weight_output_transform),
684 | "ap": AveragePrecision(output_transform=apply_weight_output_transform)
685 | },
686 | device=model.device
687 | )
688 |
689 | def load_model(checkpoint_path, model):
690 | model_state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
691 |
692 |     # The checkpoint's parameter names do not match this model's state dict, so remap the keys positionally
693 | mapping = {k: v for k, v in zip(model_state.keys(), model.state_dict().keys())}
694 | mapped_model_state = OrderedDict([(mapping[k], v) for k, v in model_state.items()])
695 | model.load_state_dict(mapped_model_state, strict=False)
696 | return model
697 |
698 | def train(train_data_path,
699 | valid_data_path,
700 | config,
701 | out_dir="./explainability",
702 | batch_size=64,
703 | lr=1e-4,
704 | epochs=100):
705 | train_dataset = dataset_classes[config['type']](train_data_path, config)
706 | valid_dataset = dataset_classes[config['type']](valid_data_path, config)
707 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
708 | val_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
709 |
710 | model = Model(valid_dataset, config)
711 | path_cp = "./explainability_checkpoints/"+out_dir
712 | os.makedirs(path_cp, exist_ok=True)
713 | with open(path_cp+"/config.json", 'w') as config_file:
714 | json.dump(config, config_file)
715 |
716 | optimizer = Adam(model.parameters(), lr=lr)
717 | trainer = create_supervised_trainer(model, optimizer, weighted_binary_cross_entropy, device=model.device)
718 |
719 | validation_evaluator = create_evaluator(model)
720 |
721 | RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
722 |
723 | pbar = ProgressBar(persist=True)
724 | pbar.attach(trainer, metric_names="all")
725 |
726 | trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
727 |
728 | best_model_handler = ModelCheckpoint(dirname="./explainability_checkpoints/"+out_dir,
729 | filename_prefix="best",
730 | n_saved=1,
731 | global_step_transform=global_step_from_engine(trainer),
732 | score_name="val_ap",
733 | score_function=lambda engine: engine.state.metrics['ap'],
734 | require_empty=False)
735 | validation_evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {'model': model, })
736 |
737 | tb_logger = TensorboardLogger(log_dir='./explainability_tensorboard/'+out_dir)
738 | tb_logger.attach(
739 | trainer,
740 | log_handler=OutputHandler(
741 | tag="training", output_transform=lambda loss: {"batchloss": loss}, metric_names="all"
742 | ),
743 | event_name=Events.ITERATION_COMPLETED(every=100),
744 | )
745 |
746 | tb_logger.attach(
747 | validation_evaluator,
748 | log_handler=OutputHandler(tag="validation", metric_names=["ap"], another_engine=trainer),
749 | event_name=Events.EPOCH_COMPLETED,
750 | )
751 | #tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_COMPLETED(every=100))
752 | #tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
753 | #tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))
754 | #tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
755 | #tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))
756 |
757 | @trainer.on(Events.EPOCH_COMPLETED(every=5))
758 | def log_validation_results(engine):
759 | validation_evaluator.run(val_loader)
760 | metrics = validation_evaluator.state.metrics
761 | pbar.log_message(
762 | f"Validation Results - Epoch: {engine.state.epoch} ap: {metrics['ap']}" # f1: {metrics['f1']}, p: {metrics['p']}, r: {metrics['r']}
763 | )
764 |
765 | pbar.n = pbar.last_print_n = 0
766 |
767 | trainer.run(train_loader, max_epochs=epochs)
768 |
769 | def evaluate(model_checkpoint, data_path, **kwargs):
770 | checkpoint_file = Path(model_checkpoint)
771 | with open(checkpoint_file.parent/'config.json', 'r') as config_file:
772 | config = json.load(config_file)
773 |
774 | ap = []
775 | for start, end in [[0, 1], [1, None]]:
776 | config['stop_agents'] = end
777 | config['skip_agents'] = start
778 |
779 | dataset = dataset_classes[config['type']](data_path, config)
780 | loader = torch.utils.data.DataLoader(
781 | dataset,
782 | batch_size=64,
783 | shuffle=True,
784 | num_workers=2
785 | )
786 | model = load_model(checkpoint_file, Model(dataset, config))
787 | evaluator = create_evaluator(model)
788 | t0 = time.time()
789 | evaluator.run(loader)
790 | #print("T", time.time() - t0)
791 | m = evaluator.state.metrics['ap']
792 | if not isinstance(m, list):
793 | m = [m]
794 | ap.append(m)
795 | print(ap)
796 | return ap
797 |
798 | def analyse_dataset(path):
799 | with open(path, "rb") as f:
800 | data = pickle.load(f)
801 | analyse(data)
802 |
803 | def get_coverage_fraction(map, coverage):
804 | coverable_area = ~(map > 0)
805 | covered_area = coverage & coverable_area
806 | return np.sum(covered_area) / np.sum(coverable_area)
807 |
808 | coverage_fractions = []
809 | for sample in data:
810 | world_cov = sample['obs']['state'][..., 1]
811 | world_map = sample['obs']['state'][..., 0]
812 | coverage_fractions.append(get_coverage_fraction(world_map, world_cov))
813 |
814 | print(np.mean(coverage_fractions), np.std(coverage_fractions))
815 | # plt.hist(coverage_fractions)
816 | # plt.show()
817 |
818 |
819 | if __name__ == "__main__":
820 | #train("explainability_data_k3_268ugnliyw_2735_train.pkl", "explainability_data_k3_268ugnliyw_2735_valid.pkl", "./explainability_k3_sgd", epochs=10000, batch_size=32, lr=0.1, sgd_momentum=0.9)
821 | if False:
822 | #dataset_id = "031g2r0u73_3070"
823 | #dataset_id = "096dwc6v_g_2990"
824 | #dataset_id = "22hfzad070_3050"
825 | #dataset_id = "280kl0nkhl_2960"
826 | #dataset_id = "44cdqovnq4_1540" # split
827 | #dataset_id = "27c8iboc7__2250" # flow
828 |
829 | #dataset_id = "101jtn1ssr_2450"
830 | #dataset_id = "46t6qhvrxf_5400"
831 |
832 | dataset_id = "300d3g1xqj_3120"
833 | #dataset_id = "303qg71k5o_3120"
834 | train(
835 | f"/local/scratch/jb2270/datasets_corl/explainability_data_{dataset_id}_train.pkl",
836 | f"/local/scratch/jb2270/datasets_corl/explainability_data_{dataset_id}_valid.pkl",
837 | {
838 | 'format': 'relative', # absolute/relative
839 | 'nn_mode': 'cnn', # cnn/gnn/gnn_cnn
840 | 'pred_mode': 'local', # local (own coverage and map)/global (global coverage and map)
841 | 'prediction': 'cov_map', #cov_only/cov_map
842 | 'type': 'cov' # path/cov
843 | },
844 | f"explainability_cov_map_local_cnn_{dataset_id}",
845 | epochs=10000,
846 | batch_size=64,
847 | lr=5e-3,
848 | )
849 |
850 | #evaluate("explainability_data_56uhj2ync9_2650_valid.pkl", "./explainability_checkpoints/explainability_56uhj2ync9_rel_glob/best_model_636_val_auc=0.8963244891793261.pth")
851 |
852 | #inference("./results/0610/explainability_checkpoints/explainability_228pbizcxq_1955_glob/best_model_973_val_auc=0.8896503235407915.pth", "./explainability_data_228pbizcxq_1955_test.pkl", 11, False)
853 |
854 | #inference("./results/0610/explainability_checkpoints/explainability_228pbizcxq_1955_loc/best_model_1017_val_auc=0.9856628585466523.pth", "./explainability_data_228pbizcxq_1955_test.pkl", 11, False)
855 |
856 | # flow
857 | #evaluate("./results/0712/explainability_checkpoints/explainability_glob_27c8iboc7__2250/best_model_162_val_auc=0.9998915352938512.pth", "./results/0712/explainability_data_27c8iboc7__3500_comm_test.pkl", save_dirname="rendering_comm")
858 | #evaluate("./results/0712/explainability_checkpoints/explainability_glob_27c8iboc7__2250/best_model_162_val_auc=0.9998915352938512.pth", "./results/0712/explainability_data_27c8iboc7__3500_nocomm_test.pkl", save_dirname="rendering_nocomm")
859 | #inference_path("./results/0721/explainability_checkpoints/explainability_path_goal_only_local_cnn_27c8iboc7__2250/best_model_30_val_ap=1.0.pth", "./results/0712/explainability_data_27c8iboc7__3500_nocomm_test.pkl")
860 |
861 | # split
862 | #inference("./results/0712/explainability_checkpoints/explainability_cov_map_44cdqovnq4_1540/best_model_77_val_auc=0.999701756785555.pth", "./results/0712/explainability_data_44cdqovnq4_1540_nocomm_test.pkl")
863 | #evaluate("./results/0712/explainability_checkpoints/explainability_cov_map_44cdqovnq4_1540/best_model_77_val_auc=0.999701756785555.pth", "./results/0712/explainability_data_44cdqovnq4_1540_comm_test.pkl", save_dirname="rendering_comm")
864 | #evaluate("./results/0712/explainability_checkpoints/explainability_cov_map_global_gnn_44cdqovnq4_1540/best_model_67_val_auc=0.9980259594481737.pth", "./results/0712/explainability_data_44cdqovnq4_1540_comm_test.pkl", save_dirname="rendering_comm")
865 | #evaluate("./results/0712/explainability_checkpoints/explainability_cov_map_global_gnn_44cdqovnq4_1540/best_model_67_val_auc=0.9980259594481737.pth", "./results/0712/explainability_data_44cdqovnq4_1540_nocomm_test.pkl", save_dirname="rendering_nocomm")
866 |
867 | # coverage normal
868 | #evaluate("./results/0712/explainability_checkpoints/explainability_cov_map_local_cnn_300d3g1xqj_3120/best_model_139_val_auc=0.9885640826508932.pth", "./results/0712/explainability_data_300d3g1xqj_3120_nocomm_test.pkl", save_dirname="rendering_nocomm")
869 | #inference("./results/0712/explainability_checkpoints/explainability_cov_map_local_cnn_300d3g1xqj_3120/best_model_139_val_auc=0.9885640826508932.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl") #, save_dirname="rendering_comm")
870 | #inference("./results/0721/explainability_checkpoints/explainability_cov_cov_only_local_cnn_300d3g1xqj_3120/best_model_190_val_ap=0.8773743058356964.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl") #, save_dirname="rendering_comm")
871 | #evaluate("./results/0721/explainability_checkpoints/explainability_cov_cov_only_local_cnn_300d3g1xqj_3120/best_model_190_val_ap=0.8773743058356964.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl") #, save_dirname="rendering_comm")
872 |
873 | #inference("./results/0721/explainability_checkpoints/explainability_cov_cov_only_global_gnn_300d3g1xqj_3120/best_model_140_val_ap=0.8570488698746233.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl") #, save_dirname="rendering_comm")
874 | #inference_gnn_cnn(
875 | # "./results/0721/explainability_checkpoints/explainability_cov_cov_only_local_cnn_300d3g1xqj_3120/best_model_190_val_ap=0.8773743058356964.pth",
876 | # "./results/0721/explainability_checkpoints/explainability_cov_cov_only_global_gnn_300d3g1xqj_3120/best_model_140_val_ap=0.8570488698746233.pth",
877 | # "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl",
878 | # save_dirname="rendering_comm"
879 | #)
880 | #evaluate("./results/0712/explainability_checkpoints/explainability_cov_map_global_gnn_300d3g1xqj_3120/best_model_69_val_auc=0.961077006252264.pth", "./results/0712/explainability_data_300d3g1xqj_3120_nocomm_test.pkl", save_dirname="rendering_nocomm")
881 | #inference("./results/0712/explainability_checkpoints/explainability_cov_map_global_gnn_300d3g1xqj_3120/best_model_69_val_auc=0.961077006252264.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl") #,save_dirname="rendering_comm")
882 | #inference("./results/0712/explainability_checkpoints/explainability_cov_map_global_gnn_300d3g1xqj_3120/best_model_69_val_auc=0.961077006252264.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl") #,save_dirname="rendering_comm")
883 |
884 | #evaluate("./results/0712/explainability_checkpoints/explainability_300d3g1xqj_3120/best_model_384_val_auc=0.9779040424681612.pth", "./results/0712/explainability_data_300d3g1xqj_3120_nocomm_test.pkl", save_dirname="rendering_nocomm")
885 | #evaluate("./results/0712/explainability_checkpoints/explainability_300d3g1xqj_3120/best_model_384_val_auc=0.9779040424681612.pth", "./results/0712/explainability_data_300d3g1xqj_3120_comm_test.pkl", save_dirname="rendering_comm")
886 |
887 | #inference("./results/0823/expl_checkpoints/explainability_cov_cov_map_local_271e7f5bc3_1560/best_model_30_val_ap=0.9611180740842954.pth", "../../Internship/gpvae/data/explainability_data_271e7f5bc3_1560_test.pkl") #,save_dirname="rendering_comm")
888 | inference("./results/0823/expl_checkpoints/explainability_cov_cov_only_local_271e7f5bc3_1560/best_model_85_val_ap=0.8428215464016102.pth", "../../Internship/gpvae/data/explainability_data_271e7f5bc3_1560_test.pkl") #,save_dirname="rendering_comm")
889 |
--------------------------------------------------------------------------------
/adversarial_comms/train_policy.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections.abc
3 | import yaml
4 | import json
5 | import os
6 | import ray
7 |
8 | import numpy as np
9 |
10 | from pathlib import Path
11 | from ray import tune
12 | from ray.rllib.utils import try_import_torch
13 | from ray.rllib.models import ModelCatalog
14 | from ray.tune.registry import register_env
15 | from ray.tune.logger import pretty_print, DEFAULT_LOGGERS, TBXLogger
16 | from ray.rllib.utils.schedules import PiecewiseSchedule
17 | from ray.rllib.agents.callbacks import DefaultCallbacks
18 |
19 | from .environments.coverage import CoverageEnv
20 | from .environments.path_planning import PathPlanningEnv
21 | from .models.adversarial import AdversarialModel
22 | from .trainers.multiagent_ppo import MultiPPOTrainer
23 | from .trainers.hom_multi_action_dist import TorchHomogeneousMultiActionDistribution
24 |
25 | torch, _ = try_import_torch()
26 |
27 | def update_dict(d, u):
28 | for k, v in u.items():
29 | if isinstance(v, collections.abc.Mapping):
30 | d[k] = update_dict(d.get(k, {}), v)
31 | else:
32 | d[k] = v
33 | return d
34 |
35 | def trial_dirname_creator(trial):
36 | return str(trial) #f"{ray.tune.trial.date_str()}_{trial}"
37 |
38 | def dir_path(string):
39 | if os.path.isdir(string):
40 | return string
41 | else:
42 | raise NotADirectoryError(string)
43 |
44 | def check_file(string):
45 | if os.path.isfile(string):
46 | return string
47 | else:
48 | raise FileNotFoundError(string)
49 |
50 | def get_config_base():
51 | return Path(os.path.dirname(os.path.realpath(__file__))) / "config"
52 |
53 | class EvaluationCallbacks(DefaultCallbacks):
54 | def on_episode_start(self, worker, base_env, policies, episode, **kwargs):
55 | episode.user_data["reward_greedy"] = []
56 | episode.user_data["reward_coop"] = []
57 |
58 | def on_episode_step(self, worker, base_env, episode, **kwargs):
59 | ep_info = episode.last_info_for()
60 | if ep_info is not None and ep_info:
61 | episode.user_data["reward_greedy"].append(sum(ep_info['rewards_teams'][0].values()))
62 | episode.user_data["reward_coop"].append(sum(ep_info['rewards_teams'][1].values()))
63 |
64 | def on_episode_end(self, worker, base_env, policies, episode, **kwargs):
65 | episode.custom_metrics["reward_greedy"] = np.sum(episode.user_data["reward_greedy"])
66 | episode.custom_metrics["reward_coop"] = np.sum(episode.user_data["reward_coop"])
67 |
68 | '''
69 | def on_train_result(self, trainer, result, **kwargs):
70 | greedy_mse_fac = trainer.config['model']['custom_model_config']['greedy_mse_fac']
71 | if isinstance(greedy_mse_fac, list):
72 | s = PiecewiseSchedule(greedy_mse_fac[0], "torch", outside_value=greedy_mse_fac[1])
73 | trainer.workers.foreach_worker(
74 | lambda w: w.foreach_policy(
75 | lambda p, p_id: p.model.update_config({'greedy_mse_fac': s(result['timesteps_total'])})))
76 | '''
77 |
78 | def initialize():
79 | ray.init()
80 | register_env("coverage", lambda config: CoverageEnv(config))
81 | register_env("path_planning", lambda config: PathPlanningEnv(config))
82 | ModelCatalog.register_custom_model("adversarial", AdversarialModel)
83 | ModelCatalog.register_custom_action_dist("hom_multi_action", TorchHomogeneousMultiActionDistribution)
84 |
85 | def start_experiment():
86 | parser = argparse.ArgumentParser()
87 | parser.add_argument("experiment")
88 | parser.add_argument("-o", "--override", help='Key in alternative_config from which to take data to override main config', default=None)
89 | parser.add_argument("-t", "--timesteps", help="Number of total time steps for training stop condition in millions", type=int, default=20)
90 | args = parser.parse_args()
91 |
92 | try:
93 | config_path = check_file(args.experiment)
94 | except FileNotFoundError:
95 | config_path = get_config_base() / (args.experiment + ".yaml")
96 |
97 | with open(config_path, "rb") as config_file:
98 |         config = yaml.safe_load(config_file)
99 |     if args.override is not None:
100 |         if args.override not in config['alternative_config']:
101 |             print("Invalid alternative config key! Choose one from:")
102 |             print(config['alternative_config'].keys())
103 |             exit()
104 |         update_dict(config, config['alternative_config'][args.override])
105 | config.pop('alternative_config', None)
106 | config['callbacks'] = EvaluationCallbacks
107 |
108 | initialize()
109 | tune.run(
110 | MultiPPOTrainer,
111 | checkpoint_freq=10,
112 | stop={"timesteps_total": args.timesteps*1e6},
113 | keep_checkpoints_num=1,
114 | config=config,
115 | #local_dir="/tmp",
116 | trial_dirname_creator=trial_dirname_creator,
117 | )
118 |
119 | def continue_experiment():
120 | parser = argparse.ArgumentParser()
121 | parser.add_argument("checkpoint", type=dir_path)
122 | parser.add_argument("-t", "--timesteps", help="Number of total time steps for training stop condition in millions", type=int, default=20)
123 | parser.add_argument("-e", "--experiment", help="Path/id to training config", default=None)
124 | parser.add_argument("-o", "--override", help='Key in alternative_config from which to take data to override main config', default=None)
125 |
126 | args = parser.parse_args()
127 |
128 | with open(Path(args.checkpoint) / '..' / 'params.json', "rb") as config_file:
129 | config = json.load(config_file)
130 |
131 | if args.experiment is not None:
132 | try:
133 | config_path = check_file(args.experiment)
134 | except FileNotFoundError:
135 | config_path = get_config_base() / (args.experiment + ".yaml")
136 |
137 | with open(config_path, "rb") as config_file:
138 |             update_dict(config, yaml.safe_load(config_file)['alternative_config'][args.override])
139 |
140 | config['callbacks'] = EvaluationCallbacks
141 |
142 | checkpoint_file = Path(args.checkpoint) / ('checkpoint-' + os.path.basename(args.checkpoint).split('_')[-1])
143 |
144 | initialize()
145 | tune.run(
146 | MultiPPOTrainer,
147 | checkpoint_freq=20,
148 | stop={"timesteps_total": args.timesteps*1e6},
149 | restore=checkpoint_file,
150 | keep_checkpoints_num=1,
151 | config=config,
152 | #local_dir="/tmp",
153 | trial_dirname_creator=trial_dirname_creator,
154 | )
155 |
156 | if __name__ == '__main__':
157 | start_experiment()
158 |     exit()
159 | 
160 |     # Sketch of the full experiment sequence from the paper; these calls are unreachable and kept only for reference.
161 |     ### Cooperative
162 |     # run_experiment("./config/coverage.yaml", {"timesteps_total": 20e6}, None)
163 |     # run_experiment("./config/coverage_split.yaml", {"timesteps_total": 3e6}, None)
164 |     # run_experiment("./config/path_planning.yaml", {"timesteps_total": 20e6}, None)
165 | 
166 |     ### Adversarial
167 |     # continue_experiment("checkpoint_cov", {"timesteps_total": 60e6}, "./config/coverage.yaml", "adversarial")
168 |     # continue_experiment("checkpoint_split", {"timesteps_total": 20e6}, "./config/coverage_split.yaml", "adversarial")
169 |     # continue_experiment("checkpoint_flow", {"timesteps_total": 60e6}, "./config/path_planning.yaml", "adversarial")
170 | 
171 |     ### Re-adapt
172 |     # continue_experiment("checkpoint_cov_adv", {"timesteps_total": 90e6}, "./config/coverage.yaml", "cooperative")
173 |     # continue_experiment("checkpoint_split_adv", {"timesteps_total": 30e6}, "./config/coverage_split.yaml", "cooperative")
174 |     # continue_experiment("checkpoint_flow_adv", {"timesteps_total": 90e6}, "./config/path_planning.yaml", "cooperative")
175 |
176 |
177 |
--------------------------------------------------------------------------------
/adversarial_comms/trainers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proroklab/adversarial_comms/889a41e239958bd519365c6b0c143469cd27e49f/adversarial_comms/trainers/__init__.py
--------------------------------------------------------------------------------
/adversarial_comms/trainers/hom_multi_action_dist.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | import tree
4 | from ray.rllib.models.torch.torch_action_dist import TorchMultiActionDistribution
5 | from ray.rllib.utils.annotations import override
6 | from ray.rllib.utils.framework import try_import_torch
7 |
8 | torch, nn = try_import_torch()
9 |
10 |
11 | class InvalidActionSpace(Exception):
12 | """Raised when the action space is invalid"""
13 |
14 | pass
15 |
16 |
17 | class TorchHomogeneousMultiActionDistribution(TorchMultiActionDistribution):
18 | @override(TorchMultiActionDistribution)
19 | def logp(self, x):
20 | logps = []
21 | for i, (d, action_space) in enumerate(
22 | zip(self.flat_child_distributions, self.action_space_struct)
23 | ):
24 | if isinstance(action_space, gym.spaces.box.Box):
25 | assert len(action_space.shape) == 1
26 | a_w = action_space.shape[0]
27 | x_sel = x[:, a_w * i : a_w * (i + 1)]
28 | elif isinstance(action_space, gym.spaces.discrete.Discrete):
29 | x_sel = x[:, i]
30 | else:
31 | raise InvalidActionSpace(
32 | "Expect gym.spaces.box or gym.spaces.discrete action space"
33 | )
34 | logps.append(d.logp(x_sel))
35 |
36 | return torch.stack(logps, axis=1)
37 |
38 | @override(TorchMultiActionDistribution)
39 | def entropy(self):
40 | return torch.stack(
41 | [d.entropy() for d in self.flat_child_distributions], axis=-1
42 | )
43 |
44 | @override(TorchMultiActionDistribution)
45 | def sampled_action_logp(self):
46 | return torch.stack(
47 | [d.sampled_action_logp() for d in self.flat_child_distributions], axis=-1
48 | )
49 |
50 | @override(TorchMultiActionDistribution)
51 | def kl(self, other):
52 | return torch.stack(
53 | [
54 | d.kl(o)
55 | for d, o in zip(
56 | self.flat_child_distributions, other.flat_child_distributions
57 | )
58 | ],
59 | axis=-1,
60 | )
61 |
--------------------------------------------------------------------------------
/adversarial_comms/trainers/multiagent_ppo.py:
--------------------------------------------------------------------------------
1 | """
2 | PyTorch policy class used for PPO.
3 | """
4 | import gym
5 | import logging
6 | import numpy as np
7 | from typing import Dict, List, Optional, Type, Union
8 |
9 | import ray
10 | from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
11 | from ray.rllib.agents.ppo.ppo_torch_policy import kl_and_loss_stats, \
12 | vf_preds_fetches, setup_mixins, KLCoeffMixin, ValueNetworkMixin
13 | from ray.rllib.agents.trainer_template import build_trainer
14 | from ray.rllib.evaluation.episode import MultiAgentEpisode
15 | from ray.rllib.evaluation.postprocessing import compute_advantages, \
16 | Postprocessing
17 | from ray.rllib.models.modelv2 import ModelV2
18 | from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
19 | from ray.rllib.policy.policy import Policy
20 | from ray.rllib.policy.policy_template import build_policy_class
21 | from ray.rllib.policy.sample_batch import SampleBatch
22 | from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
23 | LearningRateSchedule
24 | from ray.rllib.utils.framework import try_import_torch
25 | from ray.rllib.utils.torch_ops import apply_grad_clipping, \
26 | convert_to_torch_tensor, explained_variance, sequence_mask
27 | from ray.rllib.utils.typing import TensorType, TrainerConfigDict, AgentID
28 |
29 | torch, nn = try_import_torch()
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 | class InvalidActionSpace(Exception):
34 | """Raised when the action space is invalid"""
35 | pass
36 |
37 |
38 | def compute_gae_for_sample_batch(
39 | policy: Policy,
40 | sample_batch: SampleBatch,
41 | other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
42 | episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
43 | """Adds GAE (generalized advantage estimations) to a trajectory.
44 | The trajectory contains only data from one episode and from one agent.
45 | - If `config.batch_mode=truncate_episodes` (default), sample_batch may
46 | contain a truncated (at-the-end) episode, in case the
47 | `config.rollout_fragment_length` was reached by the sampler.
48 | - If `config.batch_mode=complete_episodes`, sample_batch will contain
49 | exactly one episode (no matter how long).
50 | New columns can be added to sample_batch and existing ones may be altered.
51 | Args:
52 | policy (Policy): The Policy used to generate the trajectory
53 | (`sample_batch`)
54 | sample_batch (SampleBatch): The SampleBatch to postprocess.
55 | other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
56 | dict of AgentIDs mapping to other agents' trajectory data (from the
57 | same episode). NOTE: The other agents use the same policy.
58 | episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
59 | object in which the agents operated.
60 | Returns:
61 | SampleBatch: The postprocessed, modified SampleBatch (or a new one).
62 | """
63 |
64 |     # The trajectory view API will populate the info dict with a np.zeros((n,))
65 |     # array in the first call; in that case the dtype will be float32 and we
66 |     # have to ignore it. For regular calls, we extract the rewards from the info
67 |     # dict into the samplebatch_infos_rewards dict, which then holds the rewards
68 |     # for all agents as a dict.
69 | samplebatch_infos_rewards = {'0': sample_batch[SampleBatch.INFOS]}
70 | if not sample_batch[SampleBatch.INFOS].dtype == "float32":
71 | samplebatch_infos = SampleBatch.concat_samples([
72 | SampleBatch({k: [v] for k, v in s.items()})
73 | for s in sample_batch[SampleBatch.INFOS]
74 | ])
75 | samplebatch_infos_rewards = SampleBatch.concat_samples([
76 | SampleBatch({str(k): [v] for k, v in s.items()})
77 | for s in samplebatch_infos["rewards"]
78 | ])
79 |
80 | if not isinstance(policy.action_space, gym.spaces.tuple.Tuple):
81 | raise InvalidActionSpace("Expect tuple action space")
82 |
83 |     # one sample batch per agent
84 | batches = []
85 | for key, action_space in zip(samplebatch_infos_rewards.keys(), policy.action_space):
86 | i = int(key)
87 | sample_batch_agent = sample_batch.copy()
88 | sample_batch_agent[SampleBatch.REWARDS] = (samplebatch_infos_rewards[key])
89 | if isinstance(action_space, gym.spaces.box.Box):
90 | assert len(action_space.shape) == 1
91 | a_w = action_space.shape[0]
92 | elif isinstance(action_space, gym.spaces.discrete.Discrete):
93 | a_w = 1
94 | else:
95 | raise InvalidActionSpace("Expect gym.spaces.box or gym.spaces.discrete action space")
96 |
97 | sample_batch_agent[SampleBatch.ACTIONS] = sample_batch[SampleBatch.ACTIONS][:, a_w * i : a_w * (i + 1)]
98 | sample_batch_agent[SampleBatch.VF_PREDS] = sample_batch[SampleBatch.VF_PREDS][:, i]
99 |
100 | # Trajectory is actually complete -> last r=0.0.
101 | if sample_batch[SampleBatch.DONES][-1]:
102 | last_r = 0.0
103 | # Trajectory has been truncated -> last r=VF estimate of last obs.
104 | else:
105 | # Input dict is provided to us automatically via the Model's
106 | # requirements. It's a single-timestep (last one in trajectory)
107 | # input_dict.
108 | # Create an input dict according to the Model's requirements.
109 | input_dict = policy.model.get_input_dict(
110 | sample_batch, index="last")
111 | all_values = policy._value(**input_dict, seq_lens=input_dict.seq_lens)
112 | last_r = all_values[i].item()
113 |
114 | # Adds the policy logits, VF preds, and advantages to the batch,
115 | # using GAE ("generalized advantage estimation") or not.
116 | batches.append(
117 | compute_advantages(
118 | sample_batch_agent,
119 | last_r,
120 | policy.config["gamma"],
121 | policy.config["lambda"],
122 | use_gae=policy.config["use_gae"],
123 | use_critic=policy.config.get("use_critic", True)
124 | )
125 | )
126 |
127 |     # Now take the original sample batch and overwrite these columns with the per-agent results stacked along the last axis
128 | for k in [
129 | SampleBatch.REWARDS,
130 | SampleBatch.VF_PREDS,
131 | Postprocessing.ADVANTAGES,
132 | Postprocessing.VALUE_TARGETS,
133 | ]:
134 | sample_batch[k] = np.stack([b[k] for b in batches], axis=-1)
135 |
136 | return sample_batch
137 |
138 |
139 | def ppo_surrogate_loss(
140 | policy: Policy, model: ModelV2,
141 | dist_class: Type[TorchDistributionWrapper],
142 | train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
143 | """Constructs the loss for Proximal Policy Objective.
144 | Args:
145 | policy (Policy): The Policy to calculate the loss for.
146 | model (ModelV2): The Model to calculate the loss for.
147 |         dist_class (Type[ActionDistribution]): The action distribution class.
148 | train_batch (SampleBatch): The training data.
149 | Returns:
150 | Union[TensorType, List[TensorType]]: A single loss tensor or a list
151 | of loss tensors.
152 | """
153 | logits, state = model.from_batch(train_batch, is_training=True)
154 | curr_action_dist = dist_class(logits, model)
155 |
156 | # RNN case: Mask away 0-padded chunks at end of time axis.
157 | if state:
158 | B = len(train_batch["seq_lens"])
159 | max_seq_len = logits.shape[0] // B
160 | mask = sequence_mask(
161 | train_batch["seq_lens"],
162 | max_seq_len,
163 | time_major=model.is_time_major())
164 | mask = torch.reshape(mask, [-1])
165 | num_valid = torch.sum(mask)
166 |
167 | def reduce_mean_valid(t):
168 | return torch.sum(t[mask]) / num_valid
169 |
170 | # non-RNN case: No masking.
171 | else:
172 | mask = None
173 | reduce_mean_valid = torch.mean
174 |
175 | loss_data = []
176 |
177 | curr_action_dist = dist_class(logits, model)
178 | prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
179 | model)
180 | logps = curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
181 | entropies = curr_action_dist.entropy()
182 |
183 | action_kl = prev_action_dist.kl(curr_action_dist)
184 | mean_kl = reduce_mean_valid(torch.sum(action_kl, axis=1))
185 |
186 | for i in range(len(train_batch[SampleBatch.VF_PREDS][0])):
187 | logp_ratio = torch.exp(
188 | logps[:, i] -
189 | train_batch[SampleBatch.ACTION_LOGP][:, i])
190 |
191 | mean_entropy = reduce_mean_valid(entropies[:, i])
192 |
193 | surrogate_loss = torch.min(
194 | train_batch[Postprocessing.ADVANTAGES][..., i] * logp_ratio,
195 | train_batch[Postprocessing.ADVANTAGES][..., i] * torch.clamp(
196 | logp_ratio, 1 - policy.config["clip_param"],
197 | 1 + policy.config["clip_param"]))
198 | mean_policy_loss = reduce_mean_valid(-surrogate_loss)
199 |
200 | if policy.config["use_gae"]:
201 | prev_value_fn_out = train_batch[SampleBatch.VF_PREDS][..., i]
202 | value_fn_out = model.value_function()[..., i]
203 | vf_loss1 = torch.pow(
204 | value_fn_out - train_batch[Postprocessing.VALUE_TARGETS][..., i], 2.0)
205 | vf_clipped = prev_value_fn_out + torch.clamp(
206 | value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"],
207 | policy.config["vf_clip_param"])
208 | vf_loss2 = torch.pow(
209 | vf_clipped - train_batch[Postprocessing.VALUE_TARGETS][..., i], 2.0)
210 | vf_loss = torch.max(vf_loss1, vf_loss2)
211 | mean_vf_loss = reduce_mean_valid(vf_loss)
212 | total_loss = reduce_mean_valid(
213 | -surrogate_loss + policy.kl_coeff * action_kl[:, i] +
214 | policy.config["vf_loss_coeff"] * vf_loss -
215 | policy.entropy_coeff * entropies[:, i])
216 | else:
217 | mean_vf_loss = 0.0
218 | total_loss = reduce_mean_valid(-surrogate_loss +
219 | policy.kl_coeff * action_kl[:, i] -
220 | policy.entropy_coeff * entropies[:, i])
221 |
222 | # Store stats in policy for stats_fn.
223 | loss_data.append(
224 | {
225 | "total_loss": total_loss,
226 | "mean_policy_loss": mean_policy_loss,
227 | "mean_vf_loss": mean_vf_loss,
228 | "mean_entropy": mean_entropy,
229 | }
230 | )
231 |
232 | policy._total_loss = (torch.sum(torch.stack([o["total_loss"] for o in loss_data])),)
233 | policy._mean_policy_loss = torch.mean(
234 | torch.stack([o["mean_policy_loss"] for o in loss_data])
235 | )
236 | policy._mean_vf_loss = torch.mean(
237 | torch.stack([o["mean_vf_loss"] for o in loss_data])
238 | )
239 | policy._mean_entropy = torch.mean(
240 | torch.stack([o["mean_entropy"] for o in loss_data])
241 | )
242 | policy._vf_explained_var = explained_variance(
243 | train_batch[Postprocessing.VALUE_TARGETS],
244 | policy.model.value_function())
245 | policy._mean_kl = mean_kl
246 |
247 | return policy._total_loss
248 |
249 |
250 | class ValueNetworkMixin:
251 | """This is exactly the same mixin class as in ppo_torch_policy,
252 | but that one calls .item() on self.model.value_function()[0],
253 | which will not work for us since our value function returns
254 | multiple values. Instead, we call .item() in
255 | compute_gae_for_sample_batch above.
256 | """
257 |
258 | def __init__(self, obs_space, action_space, config):
259 | if config["use_gae"]:
260 |
261 | def value(**input_dict):
262 | input_dict = SampleBatch(input_dict)
263 | input_dict = self._lazy_tensor_dict(input_dict)
264 | model_out, _ = self.model(input_dict)
265 | # [0] = remove the batch dim.
266 | return self.model.value_function()[0]
267 |
268 | else:
269 |
270 | def value(*args, **kwargs):
271 | return 0.0
272 |
273 | self._value = value
274 |
275 |
276 | def setup_mixins_override(policy: Policy, obs_space: gym.spaces.Space,
277 | action_space: gym.spaces.Space,
278 | config: TrainerConfigDict) -> None:
279 | """Have to initialize the custom ValueNetworkMixin
280 | """
281 | setup_mixins(policy, obs_space, action_space, config)
282 | ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
283 |
284 |
285 | # Build a child class of `TorchPolicy`, given the custom functions defined
286 | # above.
287 | MultiPPOTorchPolicy = build_policy_class(
288 | name="MultiPPOTorchPolicy",
289 | framework="torch",
290 | get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
291 | loss_fn=ppo_surrogate_loss,
292 | stats_fn=kl_and_loss_stats,
293 | extra_action_out_fn=vf_preds_fetches,
294 | postprocess_fn=compute_gae_for_sample_batch,
295 | extra_grad_process_fn=apply_grad_clipping,
296 | before_init=setup_config,
297 | before_loss_init=setup_mixins_override,
298 | mixins=[
299 | LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
300 | ValueNetworkMixin
301 | ],
302 | )
303 |
304 | def get_policy_class(config):
305 | return MultiPPOTorchPolicy
306 |
307 | MultiPPOTrainer = build_trainer(
308 | name="MultiPPO",
309 | default_config=ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
310 | validate_config=ray.rllib.agents.ppo.ppo.validate_config,
311 | default_policy=MultiPPOTorchPolicy,
312 | get_policy_class=get_policy_class,
313 | execution_plan=ray.rllib.agents.ppo.ppo.execution_plan
314 | )
315 |
--------------------------------------------------------------------------------
/adversarial_comms/trainers/random_heuristic.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | import ray
4 | import random
5 |
6 | import numpy as np
7 |
8 | from enum import Enum
9 | from gym import spaces
10 | from ray.rllib import Policy
11 | from ray.rllib.agents import with_common_config
12 | from ray.rllib.agents.trainer_template import build_trainer
13 | from ray.rllib.evaluation.worker_set import WorkerSet
14 | from ray.rllib.execution.metric_ops import StandardMetricsReporting
15 | from ray.rllib.execution.rollout_ops import ParallelRollouts, SelectExperiences
16 | from ray.rllib.models.modelv2 import restore_original_dimensions
17 | from ray.rllib.utils import override
18 | from ray.rllib.utils.typing import TrainerConfigDict
19 | from ray.util.iter import LocalIterator
20 | from ray.tune.registry import register_env
21 |
22 | DEFAULT_CONFIG = with_common_config({})
23 |
24 | class Action(Enum):
25 | NOP = 0
26 | MOVE_RIGHT = 1
27 | MOVE_LEFT = 2
28 | MOVE_UP = 3
29 | MOVE_DOWN = 4
30 |
31 | X = 1
32 | Y = 0
33 |
34 | class RandomHeuristicPolicy(Policy, ABC):
35 | """
36 | Based on
37 | https://github.com/ray-project/ray/blob/releases/1.0.1/rllib/examples/policy/random_policy.py
38 |     Moves to a random uncovered free neighboring cell, to a random free neighboring cell if all neighbors are covered, or stays put (NOP) if all neighbors are blocked
39 | """
40 |
41 | def __init__(self, *args, **kwargs):
42 | super().__init__(*args, **kwargs)
43 |
44 | def single_random_heuristic(self, obs):
45 | state_obstacles, state_coverage = (obs[:, :, i] for i in range(2))
46 | half_state_shape = (np.array(state_obstacles.shape)/2).astype(int)
47 | actions_deltas = {
48 | Action.MOVE_RIGHT.value: [ 0, 1],
49 | Action.MOVE_LEFT.value: [ 0, -1],
50 | Action.MOVE_UP.value: [-1, 0],
51 | Action.MOVE_DOWN.value: [ 1, 0],
52 | }
53 |
54 | options_free = []
55 | options_uncovered = []
56 | for a, dp in actions_deltas.items():
57 | p = half_state_shape + dp
58 | if state_obstacles[p[Y], p[X]] > 0:
59 | continue
60 | options_free.append(a)
61 |
62 | if state_coverage[p[Y], p[X]] > 0:
63 | continue
64 | options_uncovered.append(a)
65 |
66 | if len(options_uncovered) > 0:
67 | return random.choice(options_uncovered)
68 | elif len(options_free) > 0:
69 | return random.choice(options_free)
70 |         return Action.NOP.value
71 |
72 | @override(Policy)
73 | def compute_actions(self,
74 | obs_batch,
75 | state_batches=None,
76 | prev_action_batch=None,
77 | prev_reward_batch=None,
78 | info_batch=None,
79 | episodes=None,
80 | **kwargs):
81 |
82 | obs_batch = restore_original_dimensions(
83 | np.array(obs_batch, dtype=np.float32),
84 | self.observation_space,
85 | tensorlib=np)
86 |
87 | r = np.array([[self.single_random_heuristic(map_batch) for map_batch in agent['map']] for agent in obs_batch['agents']])
88 | return r.transpose(), [], {}
89 |
90 | def learn_on_batch(self, samples):
91 | pass
92 |
93 | def get_weights(self):
94 | pass
95 |
96 | def set_weights(self, weights):
97 | pass
98 |
99 |
100 | def execution_plan(workers: WorkerSet,
101 | config: TrainerConfigDict) -> LocalIterator[dict]:
102 | rollouts = ParallelRollouts(workers, mode="async")
103 |
104 | # Collect batches for the trainable policies.
105 | rollouts = rollouts.for_each(
106 | SelectExperiences(workers.trainable_policies()))
107 |
108 | # Return training metrics.
109 | return StandardMetricsReporting(rollouts, workers, config)
110 |
111 |
112 | RandomHeuristicTrainer = build_trainer(
113 | name="RandomHeuristic",
114 | default_config=DEFAULT_CONFIG,
115 | default_policy=RandomHeuristicPolicy,
116 | execution_plan=execution_plan)
117 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.1
2 | ray[rllib]==1.3.0
3 | matplotlib==3.4.1
4 | scikit-learn==0.24.1
5 | torchsummary==1.5.1
6 |
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import setuptools
3 |
4 | package_dir = os.path.dirname(os.path.realpath(__file__))
5 |
6 | with open(package_dir + "/README.md", "r") as fh:
7 | long_description = fh.read()
8 |
9 | requirements_dir = package_dir + '/requirements.txt'
10 | install_requires = []
11 | with open(requirements_dir) as f:
12 | install_requires = f.read().splitlines()
13 |
14 | setuptools.setup(
15 | name="adversarial-comms",
16 | version="1.1",
17 | author="Jan Blumenkamp",
18 | author_email="jb2270@cam.ac.uk",
19 | description="Package accompanying the paper 'The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning'",
20 | long_description=long_description,
21 | long_description_content_type="text/markdown",
22 | url="https://github.com/proroklab/adversarial_comms",
23 | packages=setuptools.find_packages(),
24 | classifiers=[
25 | "Programming Language :: Python :: 3",
26 | "License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)",
27 | "Operating System :: OS Independent",
28 | ],
29 | install_requires=install_requires,
30 | entry_points = {
31 | 'console_scripts': [
32 | 'train_policy=adversarial_comms.train_policy:start_experiment',
33 | 'continue_policy=adversarial_comms.train_policy:continue_experiment',
34 | 'evaluate_coop=adversarial_comms.evaluate:eval_nocomm_coop',
35 | 'evaluate_adv=adversarial_comms.evaluate:eval_nocomm_adv',
36 | 'evaluate_random=adversarial_comms.evaluate:eval_random',
37 | 'evaluate_plot=adversarial_comms.evaluate:plot',
38 | 'evaluate_serve=adversarial_comms.evaluate:serve'
39 | ],
40 | },
41 | python_requires='>=3.7',
42 | )
43 |
--------------------------------------------------------------------------------