├── HMM
│   └── tmp
├── LICENSE
├── README.md
├── environment.yml
├── figures
│   └── paper_figures.ipynb
├── pyproject.toml
├── requirements.txt
├── scripts
│   ├── MultiRunner.py
│   ├── classify.py
│   ├── create_cross_validation_data.py
│   ├── cross_validate.py
│   ├── optimize_ML.py
│   ├── predict_hypothetical.py
│   ├── rarefaction_analysis.py
│   └── utils.py
└── src
    └── genomic_embeddings
        ├── Embeddings.py
        ├── Gff.py
        ├── __init__.py
        ├── corpus.py
        ├── data.py
        ├── gene2vec.py
        ├── models.py
        └── plot.py
/HMM/tmp:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |     <one line to give the program's name and a brief idea of what it does.>
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program> Copyright (C) <year> <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # genomic-nlp
2 |
3 | This repository contains the code used for compiling and analyzing the "biological corpus" presented in the paper:
4 |
5 | **Deciphering microbial gene function using natural language processing**
6 |
7 | [](https://zenodo.org/badge/latestdoi/449665025)
8 | [](https://doi.org/10.5281/zenodo.7047944)
9 |
10 | :round_pushpin: The model developed in the paper is available as a web service [here](https://gnlp.bursteinlab.org/).
11 |
12 | ## Getting the data
13 |
14 | Start by downloading the data files from the Zenodo database.
15 |
16 | 1. Click on the Zenodo link at the top of the repository or use [this link](https://zenodo.org/record/7047944) to download the data archive (`models_and_data.tar.gz`)
17 | 2. Alternatively, use the command line as follows:
18 | ```
19 | mkdir data
20 | cd data
21 |
22 | wget -O models_and_data.tar.gz "https://zenodo.org/record/7047944/files/models_and_data.tar.gz?download=1"
23 | tar -zxvf models_and_data.tar.gz
24 | rm models_and_data.tar.gz
25 | ```
26 |
27 | ## Setting up the working environment
28 | First, set up the Python environment and its dependencies.
29 | #### using pip
30 | ```
31 | python3 -m venv g2v-env
32 | source g2v-env/bin/activate
33 | pip install -r requirements.txt
34 | ```
35 | #### using conda
36 | ```
37 | conda env create -f environment.yml
38 | conda activate g2v-env
39 | ```
40 |
41 | The setup was tested on Python 3.7.
42 | Versions of all required packages appear in `requirements.txt` (for pip) and `environment.yml` (for conda).
43 |
44 | ### Code availability
45 | The source code used to train the word2vec model, extract its embeddings, and train the functional classifiers can be
46 | installed using pip:
47 |
48 | ```
49 | pip install genomic-embeddings
50 | ```
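
The package exposes the modules used throughout this README, so a quick post-install sanity check (a minimal sketch) is simply to import them:
```
# All of these imports appear in the examples below; if they succeed,
# the genomic-embeddings package is installed correctly.
from genomic_embeddings import Embeddings
from genomic_embeddings.data import Embedding
from genomic_embeddings.models import NNClf
from genomic_embeddings.plot import ModelPlots
```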
51 |
52 | ### Trained gene annotation embedding
53 | The word2vec model trained on the entire genomic corpus is available in `models_and_data` as a gensim model.
54 | To further use it for downstream analysis, set up your working environment and load the model.
55 |
56 | In python:
57 | ```
58 | from genomic_embeddings import Embeddings
59 |
60 | model_path = "models_and_data/gene2vec_w5_v300_tf24_annotation_extended/gene2vec_w5_v300_tf24_annotation_extended_2021-10-03.w2v"
61 | gene_embeddings = Embeddings.load_embeddings(model_path)
62 | ```
63 |
64 | From here you may use the [gensim API](https://radimrehurek.com/gensim/models/word2vec.html) to extract word embeddings,
65 | calculate distances between words, and more.
66 | For example:
67 | ```
68 | gene_embeddings.wv["K09140.2"]
69 | ```
70 | will return the embedding of the word `K09140.2`, a sub-cluster of the KO identifier `K09140` in KEGG.
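
Other gensim queries work the same way. A minimal sketch of exploring the neighborhood of a word in the embedding space (the identifiers passed to `similarity` are only illustrative examples):
```
# Annotations closest to K09140.2 in the embedding space
gene_embeddings.wv.most_similar("K09140.2", topn=5)

# Cosine similarity between two annotation words (example identifiers)
gene_embeddings.wv.similarity("K09140.2", "K09140.1")
```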
71 |
72 | ### Two-dimensional embedding space
73 | Gene embeddings after dimension reduction using UMAP are available as a pickle file.
74 |
75 | In python:
76 | ```
77 | from genomic_embeddings import Embeddings
78 |
79 | embeddings_2d_rep_path = "models_and_data/gene2vec_w5_v300_tf24_annotation_extended/words_umap_2021-10-03"
80 | embeddings_2d = Embeddings.get_2d_mapping(embeddings_2d_rep_path)
81 | ```
82 |
83 | ### Functional classifier
84 | To re-train all functional classifiers and generate performance plots:
85 |
86 | ```
87 | from genomic_embeddings.models import NNClf
88 | from genomic_embeddings.data import Embedding
89 | from genomic_embeddings.plot import ModelPlots
90 |
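# model_path is the word2vec model path defined in the loading example above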
91 | metadata_path = 'models_and_data/metadata.csv'
92 | labels = ['Prokaryotic defense system','Ribosome','Secretion system'] # example labels
93 |
94 | # load embedding
95 | emb = Embedding(mdl=model_path, metadata=metadata_path, labels=labels)
96 | emb.process_data_pipeline(label='label', q='', add_other=True)
97 | X, y = emb.data.drop(columns=['label']).values, emb.data['label'].values
98 |
99 | # classify
100 | clf = NNClf(X=X, y=y, out_dir='./')
101 | clf.classification_pipeline('label', alias='DNN')
102 |
103 | # plot
104 | plotter = ModelPlots(mdl=clf)
105 | plotter.plot_precision_recall()
106 | plotter.plot_roc()
107 | ```
108 | ### Function classification model validation
109 | Function classification validations are available in:
110 | `models_and_data/gene2vec_w5_v300_tf24_annotation_extended/predictions`.
111 | To re-run the validations and generate AUC and AUPR graphs, run the following script:
112 | ```
113 | python scripts/classify.py --model PATH_TO_W2V_MDL --output PATH_TO_OUT_DIR --metadata PATH_TO_METADATA
114 | ```
115 | The CSV file `metadata.csv` can be found in `models_and_data`.
116 | Running this script will produce all data found under the folder:
117 | `models_and_data/gene2vec_w5_v300_tf24_annotation_extended/predictions`
118 |
119 | ### Function classification of all hypothetical proteins
120 | All predictions of hypothetical proteins in the corpus can be found here:
121 | `models_and_data/gene2vec_w5_v300_tf24_annotation_extended/predictions/hypothetical_predictions.pkl`
122 |
123 | To load the file as a table, run the following in Python:
124 | ```
125 | import
126 | preds_path = "models_and_data/gene2vec_w5_v300_tf24_annotation_extended/predictions/hypothetical_predictions.pkl"
127 | preds = get_functional_prediction(preds_path)
128 | ```
129 | or, alternatively:
130 | ```
131 | import pandas as pd
132 | table = pd.read_pickle("models_and_data/gene2vec_w5_v300_tf24_annotation_extended/predictions/hypothetical_predictions.pkl")
133 | ```
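
A few generic pandas calls (which assume nothing about specific column names) give a quick overview of the loaded table:
```
print(table.shape)    # rows x columns
print(table.columns)  # available prediction fields
print(table.head())   # first few entries
```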
134 | To **regenerate** the model predictions, run:
135 | ```
136 | cd models_and_data/gene2vec_w5_v300_tf24_annotation_extended/
137 | python scripts/predict_hypothetical.py --model PATH_TO_W2V_MDL --output PATH_TO_OUT_DIR --metadata ../metadata.csv
138 | ```
139 |
140 |
141 |
142 | ### Re-training word embeddings using the corpus
143 | Word embeddings can be re-trained on the corpus with different parameters as follows:
144 | 1. First, go to the `models_and_data` folder and extract the corpus files:
145 | ```
146 | cd models_and_data
147 | tar -zxvf corpus.tar.gz
148 | ```
149 | 2. Train the model
150 | ```
151 | python src/gene2vec.py --input 'corpus/*.txt'
152 | ```
153 | To change specific parameters of the algorithm, run
154 | `python src/gene2vec.py --help` and configure accordingly.
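
For orientation, here is a minimal sketch of the kind of gensim call such a training script performs (an assumption, not the actual code of `gene2vec.py`; the window, vector-size, and min-count values merely mirror the `w5_v300_tf24` naming of the published model):
```
import glob
from gensim.models import Word2Vec

# Assumption: each corpus file is plain text with one "sentence" of gene annotations per line
sentences = []
for path in glob.glob("corpus/*.txt"):
    with open(path) as fh:
        sentences.extend(line.split() for line in fh)

# gensim 3.8 (pinned in environment.yml) uses `size` rather than `vector_size`
model = Word2Vec(sentences, window=5, size=300, min_count=24, workers=20)
model.save("gene2vec_retrained.w2v")
```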
155 |
156 |
157 | ### Running times
158 | Model loading, result generation, and analysis scripts are expected to run within a few seconds to 4-5 minutes.\
159 | Re-training the language model and dimensionality reduction can take up to 10 hours with 20 CPUs.
160 |
161 |
162 | ### Paper figure reproducibility
163 | All paper figures (excluding illustrations) are available as a Jupyter notebook.
164 | To run the notebook on your computer, go to `figures/` and type `jupyter notebook` in your command line.
165 | The notebook `paper_figures.ipynb` will be available on your local machine.
166 |
167 | *Note:* running the notebook requires the `models_and_data` folder; configure paths accordingly.
168 |
169 | ### HMM DB
170 | The HMM database used to annotate the KEGG orthologs (KOs) can be found here:
171 | [**kg.05_21.ren4prok.2.hmm.db.gz** (2.7GB)](https://drive.google.com/file/d/1am-9fxYXtoZ_RGyzJ-UXW2Qbpfv2srX1/view?usp=sharing).
172 |
173 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: g2v-env
2 | channels:
3 | - pytorch
4 | - plotly
5 | - pyviz
6 | - etetoolkit
7 | - conda-forge
8 | - bioconda
9 | - defaults
10 | dependencies:
11 | - _ipyw_jlab_nb_ext_conf=0.1.0=py37_0
12 | - _py-xgboost-mutex=2.0=cpu_0
13 | - _pytorch_select=0.1=cpu_0
14 | - _tflow_select=2.3.0=mkl
15 | - absl-py=0.12.0=py37hecd8cb5_0
16 | - aiohttp=3.7.4=py37h9ed2024_1
17 | - alabaster=0.7.12=py37_0
18 | - anaconda-client=1.7.2=py37_0
19 | - anaconda-navigator=1.10.0=py37_0
20 | - appnope=0.1.2=py37hecd8cb5_1001
21 | - argon2-cffi=20.1.0=py37h9ed2024_1
22 | - astor=0.8.1=py37hecd8cb5_0
23 | - async-timeout=3.0.1=py37hecd8cb5_0
24 | - async_generator=1.10=py37h28b3542_0
25 | - attrs=20.3.0=pyhd3eb1b0_0
26 | - babel=2.9.0=pyhd3eb1b0_0
27 | - backcall=0.2.0=pyhd3eb1b0_0
28 | - backports=1.0=pyhd3eb1b0_2
29 | - backports.functools_lru_cache=1.6.1=pyhd3eb1b0_0
30 | - backports.tempfile=1.0=pyhd3eb1b0_1
31 | - backports.weakref=1.0.post1=py_1
32 | - beautifulsoup4=4.9.3=pyha847dfd_0
33 | - biopython=1.78=py37haf1e3a3_0
34 | - blas=1.0=mkl
35 | - bleach=3.3.0=pyhd3eb1b0_0
36 | - blosc=1.21.0=h2842e9f_0
37 | - bokeh=2.2.3=py37_0
38 | - boto=2.49.0=py37_0
39 | - boto3=1.17.27=pyhd3eb1b0_0
40 | - botocore=1.20.28=pyhd3eb1b0_1
41 | - branca=0.4.0=py_0
42 | - brotlipy=0.7.0=py37h9ed2024_1003
43 | - bzip2=1.0.8=h1de35cc_0
44 | - c-ares=1.17.1=h9ed2024_0
45 | - ca-certificates=2021.5.30=h033912b_0
46 | - cachetools=4.2.1=pyhd3eb1b0_0
47 | - cairo=1.14.12=hc4e6be7_4
48 | - certifi=2021.5.30=py37hf985489_0
49 | - cffi=1.14.0=py37hb5b8e2f_0
50 | - chardet=3.0.4=py37hecd8cb5_1003
51 | - click=7.1.2=pyhd3eb1b0_0
52 | - cloudpickle=1.6.0=py_0
53 | - clyent=1.2.2=py37_1
54 | - colorcet=2.0.6=pyhd3eb1b0_0
55 | - conda=4.10.1=py37hf985489_0
56 | - conda-build=3.20.5=py37_1
57 | - conda-env=2.6.0=1
58 | - conda-package-handling=1.7.2=py37h22f3db7_0
59 | - conda-verify=3.4.2=py_1
60 | - coverage=5.5=py37h9ed2024_2
61 | - cryptography=3.4.6=py37h2fd3fbb_0
62 | - curl=7.69.1=ha441bb4_0
63 | - cycler=0.10.0=py37_0
64 | - cython=0.29.22=py37h23ab428_0
65 | - cytoolz=0.11.0=py37haf1e3a3_0
66 | - dask=2021.3.0=pyhd3eb1b0_0
67 | - dask-core=2021.3.0=pyhd3eb1b0_0
68 | - datashader=0.11.1=py_0
69 | - datashape=0.5.4=py37hecd8cb5_1
70 | - dbus=1.13.18=h18a8e69_0
71 | - decorator=4.4.2=pyhd3eb1b0_0
72 | - defusedxml=0.7.1=pyhd3eb1b0_0
73 | - distributed=2021.3.0=py37hecd8cb5_0
74 | - docutils=0.16=py37_1
75 | - entrypoints=0.3=py37_0
76 | - ete3=3.1.1=pyhf5214e1_0
77 | - expat=2.2.10=hb1e8313_2
78 | - filelock=3.0.12=pyhd3eb1b0_1
79 | - folium=0.10.1=py_0
80 | - fontconfig=2.13.1=ha9ee91d_0
81 | - freetype=2.10.4=ha233b18_0
82 | - fribidi=1.0.10=haf1e3a3_0
83 | - fsspec=0.8.3=py_0
84 | - future=0.18.2=py37_1
85 | - gast=0.2.2=py37_0
86 | - gensim=3.8.0=py37h6440ff4_0
87 | - gettext=0.19.8.1=h15daf44_3
88 | - glib=2.63.1=hd977a24_0
89 | - glob2=0.7=pyhd3eb1b0_0
90 | - gmp=6.2.1=h23ab428_2
91 | - google-api-core=1.25.1=pyhd3eb1b0_0
92 | - google-auth=1.27.1=pyhd3eb1b0_0
93 | - google-cloud-core=1.6.0=pyhd3eb1b0_0
94 | - google-cloud-storage=1.36.2=pyhd3eb1b0_0
95 | - google-crc32c=1.1.2=py37h9ed2024_0
96 | - google-pasta=0.2.0=py_0
97 | - google-resumable-media=1.2.0=pyhd3eb1b0_1
98 | - googleapis-common-protos=1.52.0=py37hecd8cb5_0
99 | - graphite2=1.3.14=h38d11af_0
100 | - graphviz=2.40.1=hefbbd9a_2
101 | - grpcio=1.36.1=py37h97de6d8_1
102 | - h5py=2.10.0=py37h3134771_0
103 | - harfbuzz=1.8.8=hb8d4a28_0
104 | - hdf5=1.10.4=hfa1e0ec_0
105 | - heapdict=1.0.1=py_0
106 | - holoviews=1.14.1=py_0
107 | - htslib=1.9=h3a161e8_7
108 | - icu=58.2=h0a44026_3
109 | - idna=2.10=pyhd3eb1b0_0
110 | - igraph=0.7.1=h0a67f88_1005
111 | - imagesize=1.2.0=pyhd3eb1b0_0
112 | - importlib-metadata=3.7.3=py37hecd8cb5_1
113 | - importlib_metadata=3.7.3=hd3eb1b0_1
114 | - intel-openmp=2019.4=233
115 | - ipykernel=5.3.4=py37h5ca1d4c_0
116 | - ipython=7.21.0=py37h01d92e1_0
117 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
118 | - ipywidgets=7.6.3=pyhd3eb1b0_1
119 | - itypes=1.1.0=py_0
120 | - jedi=0.17.0=py37_0
121 | - jinja2=2.11.3=pyhd3eb1b0_0
122 | - jmespath=0.10.0=py_0
123 | - joblib=1.0.1=pyhd8ed1ab_0
124 | - jpeg=9b=he5867d9_2
125 | - json5=0.9.5=py_0
126 | - jsonschema=3.2.0=py_2
127 | - jupyter_client=6.1.7=py_0
128 | - jupyter_core=4.7.1=py37hecd8cb5_0
129 | - jupyterlab=2.2.6=pyhd3eb1b0_1
130 | - jupyterlab_pygments=0.1.2=py_0
131 | - jupyterlab_server=1.2.0=py_0
132 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
133 | - jupyterthemes=0.20.0=py_1
134 | - keras=2.3.1=0
135 | - keras-applications=1.0.8=py_1
136 | - keras-base=2.3.1=py37_0
137 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0
138 | - kiwisolver=1.3.1=py37h23ab428_0
139 | - krb5=1.17.1=hddcf347_0
140 | - lcms2=2.11=h92f6f08_0
141 | - leidenalg=0.7.0=py37h570ac47_2
142 | - lesscpy=0.14.0=pyhd8ed1ab_0
143 | - libarchive=3.4.2=haa3ed63_0
144 | - libblas=3.8.0=14_mkl
145 | - libcblas=3.8.0=14_mkl
146 | - libcrc32c=1.1.1=hb1e8313_2
147 | - libcurl=7.69.1=h051b688_0
148 | - libcxx=11.1.0=habf9029_0
149 | - libdeflate=1.0=h1de35cc_1
150 | - libedit=3.1.20181209=hb402a30_0
151 | - libffi=3.2.1=h0a44026_1007
152 | - libgfortran=3.0.1=h93005f0_2
153 | - libiconv=1.16=h1de35cc_0
154 | - liblapack=3.8.0=14_mkl
155 | - liblief=0.10.1=h0a44026_0
156 | - libllvm10=10.0.1=h76017ad_5
157 | - libmklml=2019.0.5=0
158 | - libpng=1.6.37=ha441bb4_0
159 | - libprotobuf=3.14.0=h2842e9f_0
160 | - libsodium=1.0.18=h1de35cc_0
161 | - libssh2=1.9.0=ha12b0ac_1
162 | - libtiff=4.2.0=h9da4c3f_0
163 | - libwebp-base=1.2.0=h9ed2024_0
164 | - libxgboost=1.3.3=he49afe7_2
165 | - libxml2=2.9.10=h7cdb67c_3
166 | - libxslt=1.1.34=h83b36ba_0
167 | - llvm-openmp=11.0.1=h7c73e74_0
168 | - llvmlite=0.34.0=py37h739e7dc_4
169 | - locket=0.2.1=py37hecd8cb5_1
170 | - lxml=4.6.2=py37h26b266a_0
171 | - lz4-c=1.9.3=h23ab428_0
172 | - lzo=2.10=haf1e3a3_2
173 | - markdown=3.3.4=py37hecd8cb5_0
174 | - markupsafe=1.1.1=py37h1de35cc_0
175 | - matplotlib=3.3.4=py37hecd8cb5_0
176 | - matplotlib-base=3.3.4=py37h8b3ea08_0
177 | - mistune=0.8.4=py37h1de35cc_0
178 | - mkl=2019.4=233
179 | - mkl-service=2.3.0=py37h9ed2024_0
180 | - mkl_fft=1.3.0=py37ha059aab_0
181 | - mkl_random=1.1.1=py37h959d312_0
182 | - mock=4.0.3=pyhd3eb1b0_0
183 | - msgpack-python=1.0.2=py37hf7b0b51_1
184 | - multidict=5.1.0=py37h9ed2024_2
185 | - multipledispatch=0.6.0=py37_0
186 | - navigator-updater=0.2.1=py37_0
187 | - nbclient=0.5.3=pyhd3eb1b0_0
188 | - nbconvert=6.0.7=py37_0
189 | - nbformat=5.1.2=pyhd3eb1b0_1
190 | - ncurses=6.1=h0a44026_1
191 | - nest-asyncio=1.5.1=pyhd3eb1b0_0
192 | - networkx=2.5=py_0
193 | - ninja=1.10.2=py37hf7b0b51_0
194 | - nltk=3.5=py_0
195 | - notebook=6.2.0=py37hecd8cb5_0
196 | - numba=0.51.2=py37h959d312_1
197 | - numexpr=2.7.3=py37h16bde0e_0
198 | - numpy=1.20.0=py37ha9839cc_0
199 | - olefile=0.46=py37_0
200 | - openapi-codec=1.3.2=py_0
201 | - openssl=1.1.1k=h0d85af4_0
202 | - opt_einsum=3.1.0=py_0
203 | - packaging=20.9=pyhd3eb1b0_0
204 | - pandas=1.2.2=py37hb2f4e1b_0
205 | - pandoc=2.12=hecd8cb5_0
206 | - pandocfilters=1.4.3=py37hecd8cb5_1
207 | - panel=0.10.3=pyhd3eb1b0_1
208 | - pango=1.42.4=h7e27002_1
209 | - param=1.10.1=pyhd3eb1b0_0
210 | - parso=0.8.1=pyhd3eb1b0_0
211 | - partd=1.1.0=py_0
212 | - patsy=0.5.1=py37_0
213 | - pcre=8.44=hb1e8313_0
214 | - pexpect=4.8.0=pyhd3eb1b0_3
215 | - pickleshare=0.7.5=pyhd3eb1b0_1003
216 | - pillow=8.1.2=py37h5270095_0
217 | - pixman=0.40.0=haf1e3a3_0
218 | - pkginfo=1.7.0=py37hecd8cb5_0
219 | - plotly=4.14.1=pyhd3eb1b0_0
220 | - plotly-orca=1.3.1=1
221 | - ply=3.11=py_1
222 | - prometheus_client=0.9.0=pyhd3eb1b0_0
223 | - prompt-toolkit=3.0.8=py_0
224 | - protobuf=3.14.0=py37h23ab428_1
225 | - psutil=5.8.0=py37h9ed2024_1
226 | - ptyprocess=0.7.0=pyhd3eb1b0_2
227 | - py-lief=0.10.1=py37haf313ee_0
228 | - py-xgboost=1.3.3=py37hf985489_2
229 | - pyasn1=0.4.8=py_0
230 | - pyasn1-modules=0.2.8=py_0
231 | - pycairo=1.19.1=py37h06c6e95_0
232 | - pycosat=0.6.3=py37h9ed2024_0
233 | - pycparser=2.20=py_2
234 | - pyct=0.4.8=py37_0
235 | - pygments=2.8.1=pyhd3eb1b0_0
236 | - pygraphviz=1.3=py37h1de35cc_1
237 | - pynndescent=0.5.2=pyhd3eb1b0_0
238 | - pyopenssl=20.0.1=pyhd3eb1b0_1
239 | - pyparsing=2.4.7=pyhd3eb1b0_0
240 | - pyqt=5.9.2=py37h655552a_2
241 | - pyrsistent=0.17.3=py37haf1e3a3_0
242 | - pysocks=1.7.1=py37hecd8cb5_0
243 | - pytables=3.6.1=py37h5bccee9_0
244 | - python=3.7.5=h359304d_0
245 | - python-coreapi=2.3.3=py_0
246 | - python-coreschema=0.0.4=py_0
247 | - python-dateutil=2.8.1=pyhd3eb1b0_0
248 | - python-igraph=0.7.1.post7=py37h0b31af3_0
249 | - python-libarchive-c=2.9=pyhd3eb1b0_0
250 | - python.app=3=py37h9ed2024_0
251 | - python_abi=3.7=1_cp37m
252 | - pytorch=1.6.0=cpu_py37hd70000b_0
253 | - pytz=2021.1=pyhd3eb1b0_0
254 | - pyviz_comms=2.0.1=pyhd3eb1b0_0
255 | - pyyaml=5.4.1=py37h9ed2024_1
256 | - pyzmq=20.0.0=py37h23ab428_1
257 | - qt=5.9.7=h468cd18_1
258 | - qtpy=1.9.0=py_0
259 | - readline=7.0=h1de35cc_5
260 | - regex=2020.11.13=py37h9ed2024_0
261 | - requests=2.25.1=pyhd3eb1b0_0
262 | - retrying=1.3.3=py37_2
263 | - ripgrep=12.1.1=0
264 | - rsa=4.7.2=pyhd3eb1b0_1
265 | - ruamel_yaml=0.15.87=py37haf1e3a3_1
266 | - s3transfer=0.3.4=pyhd3eb1b0_0
267 | - samtools=1.9=h8aa4d43_12
268 | - scikit-learn=0.24.1=py37hb2f4e1b_0
269 | - scipy=1.6.0=py37h2515648_0
270 | - seaborn=0.11.1=pyhd3eb1b0_0
271 | - send2trash=1.5.0=pyhd3eb1b0_1
272 | - setuptools=52.0.0=py37hecd8cb5_0
273 | - simplejson=3.17.2=py37h9ed2024_2
274 | - sip=4.19.8=py37h0a44026_0
275 | - six=1.15.0=py37hecd8cb5_0
276 | - smart_open=4.2.0=pyhd3eb1b0_0
277 | - snowballstemmer=2.1.0=pyhd3eb1b0_0
278 | - sortedcontainers=2.3.0=pyhd3eb1b0_0
279 | - soupsieve=2.2=pyhd3eb1b0_0
280 | - sphinx=3.2.1=py_0
281 | - sphinx_rtd_theme=0.4.3=py_0
282 | - sphinxcontrib-applehelp=1.0.2=pyhd3eb1b0_0
283 | - sphinxcontrib-devhelp=1.0.2=pyhd3eb1b0_0
284 | - sphinxcontrib-htmlhelp=1.0.3=pyhd3eb1b0_0
285 | - sphinxcontrib-jsmath=1.0.1=pyhd3eb1b0_0
286 | - sphinxcontrib-qthelp=1.0.3=pyhd3eb1b0_0
287 | - sphinxcontrib-serializinghtml=1.1.4=pyhd3eb1b0_0
288 | - sqlite=3.31.1=ha441bb4_0
289 | - statsmodels=0.12.2=py37h9ed2024_0
290 | - tblib=1.7.0=py_0
291 | - tensorboard=2.0.0=pyhb38c66f_1
292 | - tensorflow=2.0.0=mkl_py37hda344b4_0
293 | - tensorflow-base=2.0.0=mkl_py37h66b1bf0_0
294 | - tensorflow-estimator=2.0.0=pyh2649769_0
295 | - termcolor=1.1.0=py37hecd8cb5_1
296 | - terminado=0.9.3=py37hecd8cb5_0
297 | - testpath=0.4.4=pyhd3eb1b0_0
298 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
299 | - tk=8.6.10=hb0a8c7a_0
300 | - toolz=0.11.1=pyhd3eb1b0_0
301 | - torchtext=0.6.0=py_1
302 | - tornado=6.1=py37h9ed2024_0
303 | - tqdm=4.59.0=pyhd3eb1b0_1
304 | - traitlets=5.0.5=pyhd3eb1b0_0
305 | - typing-extensions=3.7.4.3=hd3eb1b0_0
306 | - typing_extensions=3.7.4.3=pyh06a4308_0
307 | - umap-learn=0.5.1=py37hf985489_0
308 | - uritemplate=3.0.0=py_1
309 | - urllib3=1.26.4=pyhd3eb1b0_0
310 | - viennarna=2.4.13=py37hd9629dc_2
311 | - vincent=0.4.4=py_1
312 | - wcwidth=0.2.5=py_0
313 | - webencodings=0.5.1=py37_1
314 | - werkzeug=1.0.1=pyhd3eb1b0_0
315 | - wheel=0.36.2=pyhd3eb1b0_0
316 | - widgetsnbextension=3.5.1=py37_0
317 | - wrapt=1.12.1=py37h1de35cc_1
318 | - xarray=0.17.0=pyhd3eb1b0_0
319 | - xmltodict=0.12.0=py_0
320 | - xz=5.2.5=h1de35cc_0
321 | - yaml=0.2.5=haf1e3a3_0
322 | - yarl=1.5.1=py37haf1e3a3_0
323 | - zeromq=4.3.3=hb1e8313_3
324 | - zict=2.0.0=pyhd3eb1b0_0
325 | - zipp=3.4.1=pyhd3eb1b0_0
326 | - zlib=1.2.11=h1de35cc_3
327 | - zstd=1.4.5=h41d2c2f_0
328 | - pip:
329 | - anndata==0.7.5
330 | - ansi2html==1.6.0
331 | - bcbio-gff==0.6.6
332 | - biom-format==2.1.10
333 | - brotli==1.0.9
334 | - budgitree==0.0.9
335 | - cachecontrol==0.12.6
336 | - chart-studio==1.1.0
337 | - colour==0.1.5
338 | - dash==1.19.0
339 | - dash-auth==1.4.1
340 | - dash-bootstrap-components==0.11.1
341 | - dash-core-components==1.15.0
342 | - dash-html-components==1.1.2
343 | - dash-renderer==1.9.0
344 | - dash-table==4.11.2
345 | - distance==0.1.3
346 | - dna-features-viewer==3.1.0
347 | - dovpanda==0.0.5
348 | - dtreeviz==1.1.3
349 | - et-xmlfile==1.1.0
350 | - explainerdashboard==0.2.19.1
351 | - fa2==0.3.5
352 | - flask==1.1.2
353 | - flask-compress==1.8.0
354 | - flask-seasurf==0.3.0
355 | - flask-simplelogin==0.0.7
356 | - flask-wtf==0.14.3
357 | - get-version==2.1
358 | - hdbscan==0.8.27
359 | - hdmedians==0.14.1
360 | - iniconfig==1.1.1
361 | - itsdangerous==1.1.0
362 | - jupyter-dash==0.3.1
363 | - kraken-biom==1.0.1
364 | - legacy-api-wrap==1.2
365 | - lockfile==0.12.2
366 | - natsort==7.1.1
367 | - openpyxl==3.0.7
368 | - oyaml==1.0
369 | - pdpbox==0.2.0
370 | - pip==19.0.3
371 | - pluggy==0.13.1
372 | - py==1.10.0
373 | - py4j==0.10.9
374 | - pyahocorasick==1.4.1
375 | - pyastronomy==0.14.0
376 | - pyfastaq==3.17.0
377 | - pymatch==0.3.4
378 | - pypdf2==1.26.0
379 | - pyspark==3.0.1
380 | - pytest==6.2.1
381 | - python-graphviz==0.16
382 | - pyvolve==1.0.0
383 | - pywaffle==0.6.3
384 | - scanpy==1.7.1
385 | - scikit-bio==0.5.6
386 | - shap==0.37.0
387 | - shortuuid==1.0.1
388 | - sinfo==0.3.1
389 | - slicer==0.0.3
390 | - squarify==0.4.3
391 | - stdlib-list==0.8.0
392 | - suffix-trees==0.3.0
393 | - tbb==2021.1.1
394 | - toml==0.10.2
395 | - ua-parser==0.10.0
396 | - waitress==1.4.4
397 | - weblogo==3.5.0
398 | - wtforms==2.3.3
399 | - xgboost==1.3.1
400 | - xlrd==2.0.1
401 |
--------------------------------------------------------------------------------
/figures/paper_figures.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "pycharm": {
8 | "name": "#%%\n"
9 | }
10 | },
11 | "outputs": [],
12 | "source": [
13 | "import numpy as np\n",
14 | "import pandas as pd \n",
15 | "\n",
16 | "import matplotlib\n",
17 | "import matplotlib.pyplot as plt\n",
18 | "import seaborn as sns\n",
19 | "import matplotlib as mpl\n",
20 | "from matplotlib.cm import ScalarMappable\n",
21 | "from matplotlib.lines import Line2D\n",
22 | "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
23 | "from textwrap import wrap\n",
24 | "\n",
25 | "import codecs\n",
26 | "import glob\n",
27 | "from tqdm.notebook import tqdm\n",
28 | "import itertools\n",
29 | "from collections import Counter\n",
30 | "import os\n",
31 | "import sys\n",
32 | "\n",
33 | "\n",
34 | "matplotlib.rcParams['pdf.fonttype'] = 42\n",
35 | "matplotlib.rcParams['ps.fonttype'] = 42\n",
36 | "\n",
37 | "import pickle\n",
38 | "\n",
39 |     "# silence warnings\n",
40 | "import warnings\n",
41 | "warnings.filterwarnings('ignore')"
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {
47 | "pycharm": {
48 | "name": "#%% md\n"
49 | }
50 | },
51 | "source": [
52 | "## Paper Figures\n",
53 | "\n",
54 |     "This notebook allows reconstruction of most of the paper's figures (except the illustration figures)."
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {
60 | "pycharm": {
61 | "name": "#%% md\n"
62 | }
63 | },
64 | "source": [
65 | "### Load data\n",
66 | "First you'll need to load the data for generating the figures.\n",
67 | "\n",
68 | "The data needed can be found in the supplementary material provided with the paper.\n",
69 | "\n",
70 |     "Download instructions can be found in the main README file."
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {
77 | "pycharm": {
78 | "name": "#%%\n"
79 | }
80 | },
81 | "outputs": [],
82 | "source": [
83 | "w = pd.read_csv('/models_and_data/figures_data.csv')\n",
84 | "known_system_table = pd.read_csv('/models_and_data/novel_defense_mapping/defense_hypothetical_system_predictions.csv')\n",
85 | "rarefaction_path = '/models_and_data/rarefaction/*pkl'\n",
86 | "all_metrics = pd.read_csv('/models_and_data/benchmark and optimization/all_metrics.csv')\n",
87 | "acc_by_mdl = pd.read_csv('/models_and_data/benchmark and optimization/model_comp.csv')\n",
88 | "known_unknown = pd.read_csv('/models_and_data/hypothetical_prediction_count.csv')"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "metadata": {
95 | "pycharm": {
96 | "name": "#%%\n"
97 | }
98 | },
99 | "outputs": [],
100 | "source": [
101 | "labels = ['Prokaryotic defense system', 'Secretion system',\n",
102 | " 'Benzoate degradation', 'Oxidative phosphorylation',\n",
103 | " 'Two-component system', 'Ribosome',\n",
104 | " 'Porphyrin and chlorophyll metabolism', 'Energy metabolism',\n",
105 | " 'Other', 'Amino sugar and nucleotide sugar metabolism']"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "metadata": {
111 | "pycharm": {
112 | "name": "#%% md\n"
113 | }
114 | },
115 | "source": [
116 | "### Figure 2"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "metadata": {
123 | "pycharm": {
124 | "name": "#%%\n"
125 | }
126 | },
127 | "outputs": [],
128 | "source": [
129 | "f, ax = plt.subplots(1,3,figsize=(16, 5))\n",
130 | "\n",
131 | "# All space\n",
132 | "all_space = w[(w['x'] <= 8) & (w['y'] < 19) & (w['y'] >= 2) & (w['x'] > -10)]\n",
133 | "sns.scatterplot(x='x', y='y', data=all_space[all_space['label'] == 'unknown'],color='grey',\n",
134 | " alpha=0.01, linewidth=0, s=2, ax=ax[0], legend=False)\n",
135 | "sns.scatterplot(x='x', y='y', data=all_space[all_space['label']!= 'unknown'] ,color='#D29380',\n",
136 | " alpha=0.05, linewidth=0, s=4, ax=ax[0], legend=False)\n",
137 | "\n",
138 | "\n",
139 | "# CRISPR zoom-in\n",
140 | "crspr = w[(w['x'] > -2) & (w['x'] < -0.4) & (w['y'] > 2) & (w['y'] <4.4)]\n",
141 | "crspr[\"label\"] = crspr.apply(lambda w: \"CRISPR\" if (w['x'] > -1) and (w['x'] < -0.4) and\n",
142 | " (w['y'] > 0) and (w['y'] <2.8) else w['label'], axis=1)\n",
143 | "crspr['label'] = crspr.apply(lambda w: \"Prokaryotic defense system\" if w[\"hmm_type\"] == \"defense\" else w[\"label\"], axis=1)\n",
144 | "cmap = sns.color_palette(['cornflowerblue', 'tomato'])\n",
145 | "sns.scatterplot(x='x', y='y', data=crspr[crspr[\"hmm_type\"] != 'defense'],\n",
146 | " color='grey', alpha=0.09, linewidth=0, s=4, ax=ax[1], label=\"Non-Defense\", legend=False)\n",
147 | "sns.scatterplot(x='x', y='y', data=crspr[crspr[\"label\"].isin([\"Prokaryotic defense system\", \"CRISPR\"])],\n",
148 | " palette=cmap, alpha=0.6, linewidth=0, s=14, ax=ax[1],hue='label',\n",
149 | " label=\"Prokaryotic defense system\", legend=False)\n",
150 | "\n",
151 | "\n",
152 | "# Secretion zoom-in\n",
153 | "secr = w[(w['x'] > -.5) & (w['x'] < 8.5) & (w['y'] > 7) & (w['y'] < 14)]\n",
154 | "cmap = sns.color_palette([\"tomato\", \"darkmagenta\" ,\"cornflowerblue\",\"seagreen\",\"deeppink\"])\n",
155 | "sns.scatterplot(x='x', y='y', data=secr[(secr['label'] == 'unknown')],color='grey',\n",
156 | " alpha=0.009, linewidth=0, s=4, ax=ax[2], legend=False)\n",
157 | "sns.scatterplot(x='x', y='y', data=secr[(secr[\"label\"] == \"Secretion system\") & (secr[\"secretion_type\"] != \"other\")],\n",
158 | " hue='secretion_type',palette=cmap, alpha=0.8, linewidth=0, s=14, ax=ax[2], legend=False)\n",
159 | "\n",
160 | "ax[0].set_xlabel(\"UMAP1\")\n",
161 | "ax[1].set_xlabel(\"UMAP1\")\n",
162 | "ax[2].set_xlabel(\"UMAP1\")\n",
163 | "\n",
164 | "ax[0].set_ylabel(\"UMAP2\")\n",
165 | "ax[1].set_ylabel(\"\")\n",
166 | "ax[2].set_ylabel(\"\")\n",
167 | "\n",
168 | "for i in range(3):\n",
169 | " ax[i].axes.get_xaxis().set_visible(False)\n",
170 | " ax[i].axes.get_yaxis().set_visible(False)\n",
171 | " plt.setp(ax[i].spines.values(), color=\"#D2D7DA\", lw=2)\n",
172 | "\n",
173 | "plt.savefig(\"figure2.png\", format='png', dpi=350)"
174 | ]
175 | },
176 | {
177 | "cell_type": "markdown",
178 | "metadata": {
179 | "pycharm": {
180 | "name": "#%% md\n"
181 | }
182 | },
183 | "source": [
184 | "### Figure 3 "
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": null,
190 | "metadata": {
191 | "pycharm": {
192 | "name": "#%%\n"
193 | }
194 | },
195 | "outputs": [],
196 | "source": [
197 | "# Figure 3a + c\n",
198 | "\n",
199 | "sns.set_context(\"poster\")\n",
200 | "fig, ax = plt.subplots(2,1,figsize=(22,10))\n",
201 | "melted = pd.melt(all_metrics[all_metrics['class'] != 'overall'],\n",
202 | " id_vars=['classifier'], value_vars=['f1-score', 'accuracy', 'precision','recall'])\n",
203 | "\n",
204 | "\n",
205 | "sns.pointplot(x='label', y='f1-score', hue='model', data=acc_by_mdl,\n",
206 | " alpha=1, marker=True, palette=['#7F9ACF', '#F9B233', '#F3CCB8', '#EF856A'], ax=ax[0])\n",
207 | "sns.barplot(x='variable', y='value', hue='classifier', data=melted, ax=ax[1], palette='Reds_r',\\\n",
208 | " capsize=.06, errwidth=4)\n",
209 | "sns.stripplot(x='variable', y='value', hue='classifier', data=melted, ax=ax[1], palette='Reds_r')\n",
210 | "for i in [0,1]:\n",
211 | " ax[i].set_ylim(0,1)\n",
212 | " ax[i].legend(bbox_to_anchor=[1,0.86])\n",
213 | " ax[i].set_xlabel('')\n",
214 | "fig.tight_layout()\n",
215 | "plt.savefig(\"figure3.pdf\", bbox_inches='tight')\n"
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "metadata": {
221 | "pycharm": {
222 | "name": "#%% md\n"
223 | }
224 | },
225 | "source": [
226 | "### Figure 4"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": null,
232 | "metadata": {
233 | "pycharm": {
234 | "name": "#%%\n"
235 | }
236 | },
237 | "outputs": [],
238 | "source": [
239 | "# Figure 4a\n",
240 | "# candidates\n",
241 | "f, ax = plt.subplots(figsize=(7, 7))\n",
242 | "\n",
243 | "all_space = w[(w['x'] <= 10) & (w['y'] >= 1.2)]\n",
244 | "all_space = all_space[(all_space['predicted_class'].isin(labels))].sort_values(by=\"predicted_class\")\n",
245 | "cmap = sns.color_palette(['deeppink', '#3F681C', 'lightcoral', 'gainsboro', 'indianred', 'aqua','#FB6542', 'lightgreen', 'dodgerblue', 'gold'])\n",
246 | "sns.scatterplot(x='x', y='y', data=all_space ,hue='predicted_class', palette=cmap,\n",
247 | " alpha=0.1, linewidth=0, s=4, ax=ax)\n",
248 | "\n",
249 | "ax.axes.get_xaxis().set_visible(False)\n",
250 | "ax.axes.get_yaxis().set_visible(False)\n",
251 | "plt.setp(ax.spines.values(), color=\"#D2D7DA\", lw=2)\n",
252 | "plt.legend(bbox_to_anchor=[1,1])\n",
253 | "\n",
254 | "plt.savefig(\"candidates.png\", format='png', dpi=350,bbox_inches=\"tight\")\n"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "metadata": {
261 | "pycharm": {
262 | "name": "#%%\n"
263 | }
264 | },
265 | "outputs": [],
266 | "source": [
267 | "# Figure 4b\n",
268 | "\n",
269 |     "# This figure was adapted from:\n",
270 | "# https://www.python-graph-gallery.com/web-circular-barplot-with-matplotlib\n",
271 | "\n",
272 | "\n",
273 | "w_preds = w[(w['label'].isin(labels)) | (w['predicted_class'].isin(labels))]\n",
274 | "w_preds[\"class\"] = w_preds.apply(lambda x: x['label'] if x['label'] != 'unknown' else x['predicted_class'], axis=1)\n",
275 | "w_preds['hypothetical'] = w_preds['word'].apply(lambda x: \"hypo.clst.\" in x)\n",
276 | "grp = w_preds.groupby(['class', 'hypothetical']).agg({'word': pd.Series.nunique, 'word_count': sum}).reset_index()\n",
277 | "\n",
278 | "grp['word_count_log'] = np.log10(grp['word_count'])\n",
279 | "grp['word_log'] = np.log10(grp['word'])\n",
280 | "\n",
281 | "grp_hypo = grp[grp['hypothetical'] == True]\n",
282 | "grp_known = dict(grp[grp['hypothetical'] != True][[\"class\", \"word_log\"]].values)\n",
283 | "grp_known['Other'] = 3.5\n",
284 | "grp_hypo['n'] = grp_hypo['class'].apply(lambda x: grp_known[x])\n",
285 | "\n",
286 | "df_sorted = grp_hypo.sort_values(\"word_count\", ascending=False)\n",
287 | "\n",
288 | "# Values for the x axis\n",
289 | "ANGLES = np.linspace(0.05, 2 * np.pi - 0.05, len(df_sorted), endpoint=False)\n",
290 | "LENGTHS = df_sorted[\"word_count_log\"].values\n",
291 | "MEAN_GAIN = df_sorted[\"word_log\"].values\n",
292 | "REGION = df_sorted[\"class\"].values\n",
293 | "TRACKS_N = df_sorted[\"n\"].values\n",
294 | "\n",
295 | "GREY12 = \"#1f1f1f\"\n",
296 | "COLORS = [\"#6C5B7B\", \"#C06C84\", \"#F67280\", \"#F8B195\"]\n",
297 | "cmap = mpl.colors.LinearSegmentedColormap.from_list(\"my color\", COLORS, N=256)\n",
298 | "norm = mpl.colors.Normalize(vmin=TRACKS_N.min(), vmax=TRACKS_N.max())\n",
299 | "\n",
300 | "COLORS = cmap(norm(TRACKS_N))\n",
301 | "\n",
302 | "fig, ax = plt.subplots(figsize=(9, 12.6), subplot_kw={\"projection\": \"polar\"})\n",
303 | "\n",
304 | "fig.patch.set_facecolor(\"white\")\n",
305 | "ax.set_facecolor(\"white\")\n",
306 | "\n",
307 | "ax.set_theta_offset(1.2 * np.pi / 2)\n",
308 | "ax.set_ylim(-2, 8)\n",
309 | "\n",
310 | "ax.bar(ANGLES, LENGTHS, color=COLORS, alpha=0.9, width=0.52, zorder=10)\n",
311 | "ax.vlines(ANGLES, 0, 8, color=GREY12, ls=(0, (4, 4)), zorder=11)\n",
312 | "\n",
313 | "ax.scatter(ANGLES, MEAN_GAIN, s=60, color=GREY12, zorder=11)\n",
314 | "\n",
315 | "\n",
316 | "REGION = [\"\\n\".join(wrap(r, 5, break_long_words=False)) for r in REGION]\n",
317 | "# Set the labels\n",
318 | "ax.set_xticks(ANGLES)\n",
319 | "ax.set_xticklabels(REGION, size=12)\n",
320 | "\n",
321 | "\n",
322 | "cbaxes = inset_axes(\n",
323 | " ax,\n",
324 | " width=\"100%\",\n",
325 | " height=\"100%\",\n",
326 | " loc=\"center\",\n",
327 | " bbox_to_anchor=(0.325, 0.1, 0.35, 0.01),\n",
328 | " bbox_transform=fig.transFigure # Note it uses the figure.\n",
329 | ")\n",
330 | "\n",
331 | "# Create a new norm, which is discrete\n",
332 | "bounds = [1, 150, 400, 1000, 3000]\n",
333 | "norm = mpl.colors.BoundaryNorm(bounds, cmap.N)\n",
334 | "\n",
335 | "# Create the colorbar\n",
336 | "cb = fig.colorbar(\n",
337 | " ScalarMappable(norm=norm, cmap=cmap),\n",
338 | " cax=cbaxes, # Use the inset_axes created above\n",
339 | " orientation=\"horizontal\",\n",
340 | " ticks=[150, 400, 1000, 3000]\n",
341 | ")\n",
342 | "\n",
343 | "cb.outline.set_visible(False)\n",
344 | "cb.ax.xaxis.set_tick_params(size=0)\n",
345 | "cb.set_label(\"Words in training set\", size=16, labelpad=-40)\n",
346 | "\n",
347 | "plt.savefig(\"predictions_cbar.png\", format='png', dpi=350)\n",
348 | "\n"
349 | ]
350 | },
351 | {
352 | "cell_type": "code",
353 | "execution_count": null,
354 | "metadata": {
355 | "pycharm": {
356 | "name": "#%%\n"
357 | }
358 | },
359 | "outputs": [],
360 | "source": [
361 | "# Figure 4c\n",
362 | "def show_values_on_bars(axs):\n",
363 | " def _show_on_single_plot(ax): \n",
364 | " for p in ax.patches:\n",
365 | " _x = p.get_x() + p.get_width() / 2\n",
366 | " _y = p.get_y() + p.get_height() \n",
367 | " value = '{}'.format(int(p.get_height()))\n",
368 | " ax.text(_x, _y, value, ha=\"center\", color='#94979C', fontsize=16) \n",
369 | "\n",
370 | " if isinstance(axs, np.ndarray):\n",
371 | " for idx, ax in np.ndenumerate(axs):\n",
372 | " _show_on_single_plot(ax)\n",
373 | " else:\n",
374 | " _show_on_single_plot(axs)\n",
375 | "\n",
376 | "\n",
377 | "sns.set_context('poster')\n",
378 | "fig, ax = plt.subplots(figsize=(14,6))\n",
379 | "plt_data = known_unknown[~known_unknown['predicted_class'].isin(['Other'])]\n",
380 | "\n",
381 | "sns.barplot(x='predicted_class', y='count', hue='has_annotation', data=plt_data, ax=ax, palette=['#FB6542','#375E97'])\n",
382 | "plt.yscale('log')\n",
383 | "_ = plt.xticks(rotation=90)\n",
384 | "ax.legend(bbox_to_anchor=[1,0.86])\n",
385 | "show_values_on_bars(ax)\n",
386 | "\n",
387 | "plt.savefig(\"figure4c.pdf\", bbox_inches='tight')"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": null,
393 | "metadata": {
394 | "pycharm": {
395 | "name": "#%%\n"
396 | }
397 | },
398 | "outputs": [],
399 | "source": [
400 | "# Figure 4d\n",
401 | "f, ax = plt.subplots(figsize=(14, 6))\n",
402 | "\n",
403 | "plt_data = known_system_table[(known_system_table['system'] != 'unknown')]\n",
404 | "plt_data[\"class\"] = plt_data[\"predicted_class\"].apply(lambda x: x if \"Proka\" in x else \"Other classes\")\n",
405 | "\n",
406 | "\n",
407 | "sns.barplot(x='system', y='per', data=plt_data, hue='class', palette=['#EB2B4C', '#DFDBD9'],alpha=0.75, ax=ax)\n",
408 | "_ = plt.xticks(rotation=60)\n",
409 | "plt.legend(bbox_to_anchor=[1,1])\n",
410 | "plt.ylabel('% Predictions')\n",
411 | "\n",
412 | "def change_width(ax, new_value) :\n",
413 | " for patch in ax.patches :\n",
414 | " current_width = patch.get_width()\n",
415 | " diff = current_width - new_value\n",
416 | "\n",
417 | " # we change the bar width\n",
418 | " patch.set_width(new_value)\n",
419 | "\n",
420 | " # we recenter the bar\n",
421 | " patch.set_x(patch.get_x() + diff * .5)\n",
422 | "\n",
423 | "change_width(ax, .45)\n",
424 | "# sns.despine()\n",
425 | "\n",
426 | "plt.savefig(\"predictions_bar.pdf\", format='pdf', bbox_inches=\"tight\")\n"
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": null,
432 | "metadata": {
433 | "pycharm": {
434 | "name": "#%%\n"
435 | }
436 | },
437 | "outputs": [],
438 | "source": [
439 | "# Figure 4e\n",
440 | "boots = [pd.read_pickle(f) for f in glob.glob(rarefaction_path)]\n",
441 | "boots[1]['n_genes'] = boots[0]['n_genes']\n",
442 | "boots[2]['n_genes'] = boots[0]['n_genes']\n",
443 | "\n",
444 | "df = pd.concat(boots)\n",
445 | "\n",
446 | "fig, ax = plt.subplots(figsize=(14, 6))\n",
447 | "colors = ['limegreen', 'darkorange', 'cornflowerblue', 'gold', 'olive', 'tomato', 'deeppink', 'pink', 'turquoise']\n",
448 | "colors = ['deeppink', '#3F681C', '#9B0D7F', 'lightseagreen', 'aqua','gold', 'lightgreen', '#FB6542', 'dodgerblue']\n",
449 | "for c, cl in zip(colors, df.sort_values(by='function')[\"function\"].unique()):\n",
450 | " class_data = df[df[\"function\"] == cl]\n",
451 | " ax.plot(class_data['n_genes'], class_data['uniq_genes_mean'], color=c,\n",
452 | " label=cl, lw=3, alpha=.8)\n",
453 | " ax.fill_between(class_data['n_genes'], class_data['lower_q'], class_data['upper_q'], color=c, alpha=.1)\n",
454 | "\n",
455 | "ax.grid(True)\n",
456 | "plt.legend(bbox_to_anchor=(1.01, 1))\n",
457 | "plt.xlabel(\"Number of genes in sample\")\n",
458 | "plt.ylabel(\"Number of genes\")\n",
459 | "plt.xlim(1000, df['n_genes'].max())\n",
460 | "\n",
461 | "plt.savefig(\"rarefaction.pdf\", bbox_inches='tight')\n"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": null,
467 | "metadata": {
468 | "pycharm": {
469 | "name": "#%%\n"
470 | }
471 | },
472 | "outputs": [],
473 | "source": []
474 | }
475 | ],
476 | "metadata": {
477 | "kernelspec": {
478 | "display_name": "Python 3 (ipykernel)",
479 | "language": "python",
480 | "name": "python3"
481 | },
482 | "language_info": {
483 | "codemirror_mode": {
484 | "name": "ipython",
485 | "version": 3
486 | },
487 | "file_extension": ".py",
488 | "mimetype": "text/x-python",
489 | "name": "python",
490 | "nbconvert_exporter": "python",
491 | "pygments_lexer": "ipython3",
492 | "version": "3.9.12"
493 | }
494 | },
495 | "nbformat": 4,
496 | "nbformat_minor": 4
497 | }
498 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "genomic_embeddings"
7 | version = "0.0.1"
8 | authors = [
9 | { name="Danielle Miller Sayag", email="danimillers10@gmail.com" },
10 | ]
11 | description = "Package supporting the paper Deciphering microbial gene function using natural language processing"
12 | readme = "README.md"
13 | license = { file="LICENSE" }
14 | requires-python = ">=3.7"
15 | classifiers = [
16 | "Programming Language :: Python :: 3",
17 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
18 | "Operating System :: OS Independent",
19 | ]
20 |
21 | [project.urls]
22 | "GitHub" = "https://github.com/burstein-lab/genomic-nlp"
23 |
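24 | # Illustrative local install from the repository root (assumes a standard pip + hatchling setup):
25 | #   pip install -e .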
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.3.0
2 | astunparse==1.6.3
3 | bcbio-gff==0.6.6
4 | bio==1.3.3
5 | biopython==1.79
6 | biothings-client==0.2.6
7 | cachetools==5.2.0
8 | certifi==2022.9.24
9 | charset-normalizer==2.1.1
10 | contourpy==1.0.5
11 | cycler==0.11.0
12 | flatbuffers==22.9.24
13 | fonttools==4.37.4
14 | gast==0.4.0
15 | gensim==3.8.3
16 | google-auth==2.13.0
17 | google-auth-oauthlib==0.4.6
18 | google-pasta==0.2.0
19 | grpcio==1.49.1
20 | h5py==3.7.0
21 | idna==3.4
22 | importlib-metadata==5.0.0
23 | joblib==1.2.0
24 | keras==2.10.0
25 | Keras-Applications==1.0.8
26 | Keras-Preprocessing==1.1.2
27 | kiwisolver==1.4.4
28 | libclang==14.0.6
29 | Markdown==3.4.1
30 | matplotlib
31 | mygene==3.2.2
32 | numcodecs==0.7.3
33 | numpy==1.21.0
34 | oauthlib==3.2.2
35 | opt-einsum==3.3.0
36 | packaging==21.3
37 | pandas==1.3.5
38 | Pillow==9.2.0
39 | protobuf==3.19.6
40 | pyasn1==0.4.8
41 | pyasn1-modules==0.2.8
42 | pyparsing==3.0.9
43 | python-dateutil==2.8.2
44 | pytz==2022.4
45 | PyYAML==6.0
46 | requests==2.28.1
47 | requests-oauthlib==1.3.1
48 | rsa==4.9
49 | scikit-learn==0.24.1
50 | scipy==1.7.3
51 | seaborn==0.10.0
52 | six==1.16.0
53 | smart-open==6.2.0
54 | tensorboard==2.10.1
55 | tensorboard-data-server==0.6.1
56 | tensorboard-plugin-wit==1.8.1
57 | tensorflow==2.10.0
58 | tensorflow-estimator==2.10.0
59 | tensorflow-io-gcs-filesystem==0.27.0
60 | termcolor==2.0.1
61 | threadpoolctl==3.1.0
62 | tqdm==4.43.0
63 | typing_extensions==4.4.0
64 | urllib3==1.26.12
65 | Werkzeug==2.2.2
66 | wrapt==1.14.1
67 | xgboost==1.3.3
68 | zipp==3.9.0
--------------------------------------------------------------------------------
/scripts/MultiRunner.py:
--------------------------------------------------------------------------------
1 | from src.genomic_embeddings import Gff, corpus
2 | import os
3 | import argparse
4 |
5 |
6 | def main(args):
7 | gff_file = args.input
8 | hypothetical = args.hypothetical
9 |
10 | gff = Gff.Gff(gff_path=gff_file)
11 | gff.set_name()
12 | gff.parse_gff()
13 | gff.extract_hypothetical_and_prokka()
14 |
15 | if args.build_corpus:
16 | gff = Gff.Gff(gff_path=gff_file, hypothetical_folder=hypothetical)
17 | gff.set_name()
18 | gff_corpus = corpus.CorpusGenerator(gff=gff, by=args.annotation)
19 | gff_corpus.generate(os.path.join(args.output, f"{gff_corpus.gff.name}.txt"))
20 |
21 |
22 |
23 | if __name__ == "__main__":
24 | argparse = argparse.ArgumentParser()
25 | argparse.add_argument('--input', required=True, type=str, help='gff input file')
26 | argparse.add_argument('--output', default='/output/',
27 |                            type=str, help='the path to store output txt files')
28 | argparse.add_argument('--hypothetical',
29 | default='pkl_by_sample',
30 | type=str, help='hypothetical pkl per sample folder')
31 | argparse.add_argument('--alias', default='G2V', type=str, help='model running alias that will be used for model tracking')
32 | argparse.add_argument('--cluster', dest='cluster', action='store_true', help="run on cluster flag, default")
33 | argparse.add_argument('--local', dest='cluster', action='store_false', help="run locally flag")
34 | argparse.add_argument('--annotation', default='annotation', type=str, help='annotation level, can be annotation or annotation_extended [default: annotation]')
35 | argparse.add_argument('--build_corpus', dest='build_corpus', action='store_true', help="build corpus from parsed GFFs flag")
36 | argparse.set_defaults(cluster=True, build_corpus=False)
37 | params = argparse.parse_args()
38 |
39 | main(params)
40 |
41 |
42 |
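43 | # Example invocation (illustrative; the input path and folder names below are placeholders):
44 | # python scripts/MultiRunner.py --input /samples/sample1.contigs.kg.05_21.gff --output /output/ \
45 | #     --hypothetical pkl_by_sample --annotation annotation --build_corpus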
--------------------------------------------------------------------------------
/scripts/classify.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pandas as pd
4 | from sklearn.model_selection import StratifiedKFold
5 | from src.genomic_embeddings.models import MLClf, MLClfFolds, NNClf, NNClfFolds
6 | from src.genomic_embeddings.data import Embedding
7 | from src.genomic_embeddings.plot import ModelPlots, FoldModelPlots
8 | import argparse
9 |
10 |
11 | argparse = argparse.ArgumentParser()
12 | argparse.add_argument('--model', required=True, type=str, help='model file')
13 | argparse.add_argument('--output',
14 | default='/predictions',
15 | type=str, help='predictions output dir')
16 | argparse.add_argument('--metadata',
17 | default='metadata.csv',
18 | type=str, help='metadata csv file path')
19 | params = argparse.parse_args()
20 |
21 | MODEL = params.model
22 | METADATA = params.metadata
23 | OUTPUT_DIR = params.output
24 |
25 | # configure logger
26 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
27 | filename=os.path.join(OUTPUT_DIR, f"Validation.log"), level=logging.INFO)
28 |
29 | # top predicted LABEL
30 | top_labels = ['Amino sugar and nucleotide sugar metabolism', 'Benzoate degradation', 'Cell growth', 'Energy metabolism',
31 | 'Methane metabolism', 'Oxidative phosphorylation', 'Prokaryotic defense system', 'Ribosome',
32 | 'Secretion system', 'Transporters', 'Glycosyltransferases']
33 |
34 | curated_labels = ['Amino sugar and nucleotide sugar metabolism',
35 | 'Benzoate degradation',
36 | 'Energy metabolism',
37 | 'Oxidative phosphorylation',
38 | 'Porphyrin and chlorophyll metabolism',
39 | 'Prokaryotic defense system',
40 | 'Purine metabolism',
41 | 'Ribosome',
42 | 'Secretion system',
43 | 'Transporters',
44 | 'Two-component system']
45 |
46 | curated_labels_no_pumps = ['Amino sugar and nucleotide sugar metabolism',
47 | 'Benzoate degradation',
48 | 'Energy metabolism',
49 | 'Oxidative phosphorylation',
50 | 'Porphyrin and chlorophyll metabolism',
51 | 'Prokaryotic defense system',
52 | 'Ribosome',
53 | 'Secretion system',
54 | 'Two-component system']
55 |
56 | labels = [(top_labels, 'TOPLABELS'), (curated_labels, 'CURATED-LABELS'),
57 | (curated_labels_no_pumps, 'NO-PUMPS-CURATED-LABELS')]
58 | LABEL = 'label'
59 | q = ''
60 |
61 | for label, label_alias in labels:
62 | alias = label_alias
63 | logging.info(f"=== Extract embedding for label = {LABEL}, Q= {q}")
64 | emb = Embedding(mdl=MODEL, metadata=METADATA, labels=label)
65 | emb.process_data_pipeline(label=LABEL, q=q, add_other=True)
66 | logging.info(f"Number of effective words: {emb.effective_words.shape[0]}\n")
67 |
68 |
69 | data = emb.data
70 | X, y = data.drop(columns=[LABEL]).values, data[LABEL].values
71 | cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
72 | logging.info(f"Matrix shape is: {X.shape}, 20% used for testing in 5 fold CV\n"
73 | f"Train size: {X.shape[0] * 0.8}, Test size: {X.shape[0] * 0.2}\n"
74 | f"Number of unique classes: {pd.Series(y).nunique()}")
75 |
76 | MDLS = [(MLClf(X=X, y=y, out_dir=OUTPUT_DIR), alias),
77 | (NNClf(X=X, y=y, out_dir=OUTPUT_DIR),alias),
78 | (MLClfFolds(X=X, y=y, cv=cv, out_dir=OUTPUT_DIR),'FOLD_' + alias),
79 | (NNClfFolds(X=X, y=y, cv=cv, out_dir=OUTPUT_DIR), 'FOLD_' + alias)]
80 |
81 |
82 | for mdl, name in MDLS:
83 | mdl.classification_pipeline(LABEL, alias=name)
84 | if 'FOLD' in name:
85 | plotter = FoldModelPlots(mdl=mdl)
86 | plotter.plot_single_aupr_with_ci()
87 | plotter.plot_single_roc_with_ci()
88 | else:
89 | plotter = ModelPlots(mdl=mdl)
90 | plotter.plot_precision_recall()
91 | plotter.plot_roc()
92 |
93 |
94 |
95 |
96 |
97 |
98 |
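99 | # Example invocation (illustrative; model and metadata paths are placeholders):
100 | # python scripts/classify.py --model /models/gene2vec.mdl --metadata metadata.csv --output /predictions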
--------------------------------------------------------------------------------
/scripts/create_cross_validation_data.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import shutil
3 | import os
4 | import codecs
5 | import argparse
6 | import pickle
7 |
8 | import numpy as np
9 | from sklearn.model_selection import KFold
10 | from tqdm import tqdm
11 |
12 | def cp_files(list_of_files, dest):
13 | for f in list_of_files:
14 | if os.path.isfile(f):
15 | shutil.copy(f, dest)
16 |
17 | def extract_known_words(list_of_files, unknown='hypo.clst'):
18 | corpus_raw = u""
19 | for f in tqdm(list_of_files):
20 | with codecs.open(f, "r", "utf-8") as book_file:
21 | corpus_raw += book_file.read()
22 | raw_sentences = corpus_raw.split('. ')
23 | words = []
24 | for raw_sentence in tqdm(raw_sentences):
25 | if len(raw_sentence) > 0:
26 | words.extend([w for w in raw_sentence.split() if unknown not in w])
27 | return set(words)
28 |
29 | class CorpusCV():
30 | def __init__(self, corpus_dir, output_dir, folds, name, folds_mapping=None):
31 | self.corpus_dir = corpus_dir
32 | self.output_dir = output_dir
33 | self.nfolds = folds
34 | self.name = name
35 | self.folds_mapping = folds_mapping
36 |
37 | def create_output_dir(self):
38 | os.makedirs(os.path.join(self.output_dir, self.name),exist_ok=True)
39 | self.output_dir = os.path.join(self.output_dir, self.name)
40 |
41 | def corpus_kfold(self):
42 | corpus_files = np.array(glob.glob(self.corpus_dir))
43 | cv = KFold(n_splits=self.nfolds)
44 | fold = 1
45 | for train_idx, test_idx in cv.split(corpus_files):
46 | train_files = corpus_files[train_idx]
47 |             test_files = corpus_files[test_idx]
48 |
49 | test_words = extract_known_words(test_files)
50 |
51 | # create the fold directory
52 | fold_dir = os.path.join(self.output_dir, f'fold_{fold}')
53 | train_dir = os.path.join(fold_dir, 'corpus')
54 | os.makedirs(fold_dir, exist_ok=True)
55 | os.makedirs(train_dir, exist_ok=True)
56 |
57 | cp_files(train_files, train_dir)
58 | np.save(os.path.join(fold_dir, 'test_words.npy'), test_words)
59 |
60 | fold += 1
61 |
62 | def corpusLOPOCV(self):
63 | with open(self.folds_mapping, 'rb') as handle:
64 | lopocv_mapper = pickle.load(handle)
65 |
66 | for phylum in lopocv_mapper:
67 | train_files = lopocv_mapper[phylum]['train_files']
68 | test_files = lopocv_mapper[phylum]['test_files']
69 |
70 | test_words = extract_known_words(test_files)
71 |
72 | # create the fold directory
73 | fold_dir = os.path.join(self.output_dir, f'fold_{phylum}')
74 | train_dir = os.path.join(fold_dir, 'corpus')
75 | os.makedirs(fold_dir, exist_ok=True)
76 | os.makedirs(train_dir, exist_ok=True)
77 |
78 | cp_files(train_files, train_dir)
79 | np.save(os.path.join(fold_dir, 'test_words.npy'), test_words)
80 |
81 | if __name__ == '__main__':
82 | argparse = argparse.ArgumentParser()
83 | argparse.add_argument('--input_dir', required=True, type=str, help='input directory with the full corpus files')
84 | argparse.add_argument('--output_dir', required=True, type=str, help='output directory for cross validation')
85 | argparse.add_argument('--lopocv', default=None, type=str, help='folds mapping of corpus files')
86 | argparse.add_argument('--name', default='5foldcv', type=str, help='cv identifier')
87 | argparse.add_argument('--folds', default=5, type=int, help='number of folds in cv')
88 |
89 | params = argparse.parse_args()
90 |
91 | corpus_cv = CorpusCV(params.input_dir, params.output_dir, params.folds, params.name, params.lopocv)
92 | corpus_cv.create_output_dir()
93 | if corpus_cv.folds_mapping is None:
94 | corpus_cv.corpus_kfold()
95 | else:
96 | corpus_cv.corpusLOPOCV()
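97 |
98 | # Example invocation (illustrative; the input glob and output directory are placeholders):
99 | # python scripts/create_cross_validation_data.py --input_dir '/corpus/*.txt' --output_dir /cv_data --name 5foldcv --folds 5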
--------------------------------------------------------------------------------
/scripts/cross_validate.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | from src.genomic_embeddings.models import NNClfCVFolds, RFClfCVFolds, XGBClfCVFolds, SVMClfCVFolds
5 | from src.genomic_embeddings.plot import FoldModelPlots
6 | import argparse
7 | import pickle
8 |
9 |
10 | argparse = argparse.ArgumentParser()
11 | argparse.add_argument('--cv', default='LOPOCV', type=str, help='name of the cross validation')
12 | argparse.add_argument('--output',
13 | default='/predictions',
14 | type=str, help='predictions output dir')
15 | argparse.add_argument('--metadata',
16 | default='metadata.csv',
17 | type=str, help='metadata csv file path')
18 | argparse.add_argument('--fold2data',
19 | default='fold2data.pkl',
20 | type=str, help='mapping of LOPOCV train and test')
21 |
22 | params = argparse.parse_args()
23 |
24 | CV = params.cv
25 | METADATA = params.metadata
26 | OUTPUT_DIR = params.output
27 | FOLD2DATA = params.fold2data
28 |
29 | # configure logger
30 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
31 | filename=os.path.join(OUTPUT_DIR, f"Validation.log"), level=logging.INFO)
32 |
33 | # top predicted LABEL
34 | curated_labels_no_pumps = ['Amino sugar and nucleotide sugar metabolism',
35 | 'Benzoate degradation',
36 | 'Energy metabolism',
37 | 'Oxidative phosphorylation',
38 | 'Porphyrin and chlorophyll metabolism',
39 | 'Prokaryotic defense system',
40 | 'Ribosome',
41 | 'Secretion system',
42 | 'Two-component system']
43 |
44 | labels = [(curated_labels_no_pumps, 'NO-PUMPS-CURATED-LABELS')]
45 | LABEL = 'label'
46 | q = ''
47 |
48 | with open(FOLD2DATA, 'rb') as o:
49 | fold2data = pickle.load(o)
50 |
51 | for label, label_alias in labels:
52 | alias = label_alias
53 | MDLS = [(NNClfCVFolds(X=1, y=1, cv=5, out_dir=OUTPUT_DIR, fold2data=fold2data[CV], fold_type=CV), 'CVFOLD_' + alias),
54 | (XGBClfCVFolds(X=1, y=1, cv=5, out_dir=OUTPUT_DIR, fold2data=fold2data[CV], fold_type=CV), 'CVFOLD_' + alias),
55 | (RFClfCVFolds(X=1, y=1, cv=5, out_dir=OUTPUT_DIR, fold2data=fold2data[CV], fold_type=CV), 'CVFOLD_' + alias),
56 | (SVMClfCVFolds(X=1, y=1, cv=5, out_dir=OUTPUT_DIR, fold2data=fold2data[CV], fold_type=CV), 'CVFOLD_' + alias)]
57 |
58 |
59 | for mdl, name in MDLS:
60 | mdl.classification_pipeline(LABEL, alias=name)
61 | plotter = FoldModelPlots(mdl=mdl)
62 | plotter.plot_single_aupr_with_ci()
63 | plotter.plot_single_roc_with_ci()
64 | plotter.plot_precision_recall()
65 | plotter.plot_roc()
66 | plotter.plot_precision_recall_by_fold()
67 | plotter.plot_roc_by_fold()
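68 |
69 | # Example invocation (illustrative; paths are placeholders):
70 | # python scripts/cross_validate.py --cv LOPOCV --metadata metadata.csv --fold2data fold2data.pkl --output /predictions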
--------------------------------------------------------------------------------
/scripts/optimize_ML.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import xgboost as xgb
3 | from sklearn.model_selection import StratifiedKFold
4 | from sklearn.ensemble import RandomForestClassifier
5 | from sklearn.svm import SVC
6 | from sklearn.model_selection import GridSearchCV
7 | import pandas as pd
8 | import os
9 |
10 | import argparse
11 |
12 |
13 | def optimize(clf, X, y, parameters, alias):
14 | cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
15 | scores = ["precision", "recall"]
16 |
17 | results = []
18 |
19 | for score in scores:
20 | print("# Tuning hyper-parameters for %s" % score)
21 | print()
22 |
23 | mdl = GridSearchCV(estimator=clf, param_grid=parameters, cv=cv, scoring="%s_weighted" % score)
24 | mdl.fit(X, y)
25 |
26 | res = pd.DataFrame(mdl.cv_results_)
27 | res['score_function'] = score
28 | results.append(res)
29 |
30 |     return pd.concat(results)
31 |
32 |
33 |
34 | def main(args):
35 |
36 | with open(args.fold2data, 'rb') as o:
37 | fold2data = pickle.load(o)
38 |
39 | name = args.clf_name
40 | fold = args.fold_name
41 | lopocv = fold2data[args.cv_name]
42 |
43 | X_train, y_train = lopocv[fold]['X_train'], lopocv[fold]['y_train']
44 |
45 | names_mappings = { 'SVC' : {'CLF':SVC(),
46 | 'params': {'kernel':('linear', 'rbf'), 'C':[0.01, 1, 10]}},
47 | 'RF': {'CLF': RandomForestClassifier(n_jobs=10),
48 | 'params': {'max_depth': [6, 10, 50, 100, None], 'max_features': ['auto', 'sqrt'],
49 | 'min_samples_leaf': [1, 2, 4], 'min_samples_split': [2, 5, 10],
50 | 'n_estimators': [100, 500, 1000]}},
51 | 'XGB' : {'CLF':xgb.XGBClassifier(n_jobs=10),
52 | 'params': {'n_estimators': [100, 500, 800, 1000], 'max_depth': [6, 10, 50, 100, None],
53 | 'learning_rate': [0.001, 0.05, None]}}
54 | }
55 |
56 | clf, param_grid = names_mappings[name]['CLF'], names_mappings[name]['params']
57 | results = optimize(clf, X_train, y_train, param_grid, name)
58 | results['fold_name'] = fold
59 | results['model'] = name
60 |
61 | results.to_csv(os.path.join(args.outdir, f'opt_{name}_{fold}'), index=False)
62 |
63 |
64 | if __name__ == '__main__':
65 | argparse = argparse.ArgumentParser()
66 | argparse.add_argument('--fold_name', required=True, type=str, help='name of the fold use for train')
67 | argparse.add_argument('--cv_name', default='LOPOCV', type=str, help='name of the CV use for train')
68 | argparse.add_argument('--clf_name', default='SVC', type=str, help='name of the CV use for train')
69 | argparse.add_argument('--outdir', default='./',
70 | type=str, help='output dir to save files')
71 | argparse.add_argument('--fold2data',
72 | default='fold2data.pkl',
73 | type=str, help='mapping of LOPOCV train and test')
74 |
75 | args = argparse.parse_args()
76 |
77 | main(args)
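78 |
79 | # Example invocation (illustrative; the fold name and file paths are placeholders):
80 | # python scripts/optimize_ML.py --fold_name fold_1 --cv_name LOPOCV --clf_name RF --fold2data fold2data.pkl --outdir ./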
--------------------------------------------------------------------------------
/scripts/predict_hypothetical.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import re
3 | import numpy as np
4 | import os
5 | import glob
6 |
7 | from src.genomic_embeddings.models import NNClf
8 | from src.genomic_embeddings.data import Embedding
9 | import argparse
10 | import pickle
11 |
12 | argparse = argparse.ArgumentParser()
13 | argparse.add_argument('--model', required=True, type=str, help='model file')
14 | argparse.add_argument('--output',
15 | default='predictions',
16 | type=str, help='predictions output dir')
17 | argparse.add_argument('--metadata',
18 | default='metadata.csv',
19 | type=str, help='metadata csv file path')
20 | params = argparse.parse_args()
21 |
22 | MODEL = params.model
23 | METADATA = params.metadata
24 | OUTPUT_DIR = params.output
25 |
26 | label='label'
27 |
28 |
29 | with open(glob.glob(os.path.join(os.path.dirname(MODEL), "predictions/label/FOLD_NO-PUMPS-CURATED-LABELS/*.pkl"))[0], 'rb') as o:
30 | class_2_aupr = pickle.load(o)
31 | class_2_aupr = {k:v for k, v in class_2_aupr.items() if k != 'ALL'}
32 | curated_labels = [k for k in class_2_aupr if k != 'ALL' and k != 'Other']
33 |
34 | emb = Embedding(mdl=MODEL, metadata=METADATA, labels=curated_labels)
35 | emb.process_data_pipeline(label=label, q=0, add_other=True)
36 | data = emb.data
37 | X, y = data.drop(columns=[label]).values, data[label].values
38 |
39 | meta = emb.metadata
40 | meta['label'] = meta['label'].apply(lambda x: re.split(r'(.)\[|\(|,', x)[0].strip())
41 | test_embeddings_idx = {word:emb.mdl.wv.vocab[word].index for word in emb.mdl.wv.vocab
42 | if word not in emb.train_words["word"] and "hypo.clst" in word}
43 | X_test = emb.embedding[[*test_embeddings_idx.values()]]
44 |
45 | trainer = NNClf(X=X, y=y, out_dir=None)
46 | predicted, predicted_prob = trainer.model_fit(X, X_test, y)
47 |
48 | test_df = pd.DataFrame.from_dict(test_embeddings_idx, orient="index").reset_index().rename(
49 | columns={0:"index", 'index':"word"})
50 | dic_y_mapping = {n: label for n, label in enumerate(np.unique(y))}
51 |
52 | for key, value in dic_y_mapping.items():
53 | test_df[value] = predicted_prob[:,key]
54 |
55 | test_df['predicted_class'] = predicted
56 | test_df['predicted_class_score'] = predicted_prob.max(axis=1)
57 |
58 | test_df["weighted_total_score"] = test_df.apply(lambda row: sum([row[k]*class_2_aupr[k] for k in class_2_aupr]), axis=1)
59 | test_df["weighted_prediction_score"] = test_df.apply(lambda row:
60 | class_2_aupr[row["predicted_class"]] *
61 | row["predicted_class_score"], axis=1)
62 | test_df = test_df.sort_values(by="weighted_total_score", ascending=False)
63 |
64 | test_df.to_pickle(os.path.join(OUTPUT_DIR, "hypothetical_predictions.pkl"))
65 |
66 |
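67 | # Example invocation (illustrative; model and metadata paths are placeholders):
68 | # python scripts/predict_hypothetical.py --model /models/gene2vec.mdl --metadata metadata.csv --output predictions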
--------------------------------------------------------------------------------
/scripts/rarefaction_analysis.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import os
4 | import argparse
5 |
6 | # word 2 vec
7 | from gensim.models import word2vec as w2v
8 |
9 | class EmpiricalRarefaction(object):
10 | def __init__(self, mdl, function):
11 | self.mdl = w2v.Word2Vec.load(mdl)
12 | self.preds = pd.read_pickle(os.path.join(os.path.dirname(mdl),
13 | 'predictions/hypothetical_predictions.pkl'))
14 | self.function = function
15 | self.t_weighted = 0.9
16 | self.t_unweighted = 0.99
17 | self.out_dir = os.path.join(os.path.dirname(mdl),'predictions')
18 |
19 | def set_preds_to_function(self):
20 | preds = self.preds
21 | preds = preds[(preds['predicted_class'] == self.function) &
22 | ((preds['weighted_prediction_score'] > self.t_weighted) |
23 | (preds['predicted_class_score'] > self.t_unweighted))]
24 | self.preds = preds
25 |
26 | def bootstrap_samples(self, n_bootstraps, n_genes_min, n_genes_max, step, alpha=0.05):
27 | preds = self.preds
28 | mdl = self.mdl
29 | preds['word_count'] = preds['word'].apply(lambda w: mdl.wv.vocab[w].count)
30 | preds = preds[['word', 'word_count']]
31 | preds['words_by_count'] = preds.apply(lambda row: [row["word"]] * row['word_count'], axis=1)
32 |
33 | gene_words = preds['words_by_count'].explode().values
34 |
35 | res = []
36 | for n_genes in np.arange(n_genes_min, n_genes_max, step):
37 | n_genes = min(n_genes, len(gene_words))
38 | X = np.random.choice(gene_words, size=[n_bootstraps, n_genes])
39 | X_sorted = np.sort(X, axis=1)
40 | uniq_genes_dist = (X_sorted[:,1:] != X_sorted[:,:-1]).sum(axis=1)+1
41 | uniq_genes_mean = np.mean(uniq_genes_dist)
42 | upper_std = uniq_genes_mean + np.std(uniq_genes_dist)
43 | lower_std = uniq_genes_mean - np.std(uniq_genes_dist)
44 | upper_q = 2*uniq_genes_mean - np.quantile(uniq_genes_dist, alpha/2)
45 | lower_q = 2*uniq_genes_mean - np.quantile(uniq_genes_dist, 1 - alpha / 2)
46 | res.append((uniq_genes_mean, upper_std, lower_std, upper_q, lower_q, self.function, n_genes))
47 | df = pd.DataFrame(res, columns=['uniq_genes_mean', 'upper_std', 'lower_std',
48 | 'upper_q', 'lower_q', 'function', 'n_genes'])
49 |
50 | df.to_pickle(os.path.join(self.out_dir, f'{self.function}_bootstrap.pkl'))
51 | return df
52 |
53 |
54 |
55 |
56 | if __name__ == "__main__":
57 | argparse = argparse.ArgumentParser()
58 | argparse.add_argument('--model', required=True, type=str, help='model file path')
59 | argparse.add_argument('--function', required=True, type=str, help='functional group')
60 | argparse.add_argument('--min_genes', default=100, type=int, help='min number of genes to sample')
61 | argparse.add_argument('--max_genes', default=500000, type=int, help='max number of genes to sample')
62 | argparse.add_argument('--bootstrap', default=10000, type=int, help='number of bootstraps')
63 | argparse.add_argument('--step', default=1000, type=int, help='step size')
64 | params = argparse.parse_args()
65 |
66 | rarefaction = EmpiricalRarefaction(mdl=params.model, function=params.function)
67 | rarefaction.set_preds_to_function()
68 | rarefaction.bootstrap_samples(n_bootstraps=params.bootstrap, n_genes_min=params.min_genes, n_genes_max=params.max_genes, step=params.step, alpha=0.05)
69 |
70 |
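71 | # Example invocation (illustrative; the model path is a placeholder):
72 | # python scripts/rarefaction_analysis.py --model /models/gene2vec.mdl --function 'Prokaryotic defense system' --bootstrap 10000 --step 1000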
--------------------------------------------------------------------------------
/scripts/utils.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | import statsmodels.stats.multitest as multi
3 | from matplotlib.backends.backend_pdf import PdfPages
4 | import os
5 |
6 | import pandas as pd
7 | import umap
8 | import glob
9 | import pickle
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 | from mpl_toolkits.mplot3d import Axes3D
13 | import seaborn as sns
14 | from scipy.stats import stats, entropy
15 | import gensim
16 | from gensim.models import word2vec as w2v
17 | import re
18 | from tqdm import tqdm
19 |
20 | # taken from https://umap-learn.readthedocs.io/en/latest/parameters.html
21 | def draw_umap(data, n_neighbors=15, min_dist=0.1, n_components=2, metric='euclidean', title=''):
22 | fit = umap.UMAP(
23 | n_neighbors=n_neighbors,
24 | min_dist=min_dist,
25 | n_components=n_components,
26 | metric=metric
27 | )
28 | u = fit.fit_transform(data)
29 | fig = plt.figure()
30 | if n_components == 1:
31 | ax = fig.add_subplot(111)
32 | ax.scatter(u[:,0], range(len(u)), c=data)
33 | if n_components == 2:
34 | ax = fig.add_subplot(111)
35 | ax.scatter(u[:,0], u[:,1], c=data)
36 | if n_components == 3:
37 | ax = fig.add_subplot(111, projection='3d')
38 | ax.scatter(u[:,0], u[:,1], u[:,2], c=data, s=100)
39 | plt.title(title, fontsize=18)
40 |
41 | def reducer(matrix, mdl, n_neighbors=20, min_dist=0.1, n_components=2, metric='euclidean'):
42 | fit = umap.UMAP(
43 | n_neighbors=n_neighbors,
44 | min_dist=min_dist,
45 | n_components=n_components,
46 | metric=metric
47 | )
48 | u = fit.fit_transform(matrix)
49 |
50 | points = pd.DataFrame([tuple([word] + [coord for coord in coords])
51 | for word, coords in [(word, u[mdl.wv.vocab[word].index])
52 | for word in mdl.wv.vocab]],
53 | columns=["word"] + [str(i) for i in range(n_components)])
54 | return points
55 |
56 | def add_metadata(metadata_path, mdl_folder):
57 | """
58 | add metadata for 2D words
59 | :param metadata_path: a path to a ko table with annotations per KO
60 | :param mdl_folder: a folder containing g2v model outputs
61 | :return: a merged dataframe
62 | """
63 | ko_table = pd.read_table(metadata_path)
64 | # read files from input folder
65 | cur_files = glob.glob(f"{mdl_folder}/*")
66 | with open([c for c in cur_files if "tsne" in c][0], 'rb') as o:
67 | words = pickle.load(o)
68 | words["hypothetical"] = words["word"].apply(lambda x: "YES" if "Cluster" in x else "NO")
69 | words["KO"] = words["word"]
70 | # merge files and save as pickle to reduce space
71 | merged = words.merge(ko_table, on=["KO"], how="left").fillna("unkown")
72 | return merged
73 |
74 | def cluster_data(merged, cluster_obj):
75 | """
76 | cluster words data using a clustering algorithm
77 |     :param merged: words dataframe having "x","y" columns
78 |     :param cluster_obj: object for clustering, must have a fit_predict method
79 | :return: 2d clustered dataframe
80 | """
81 |
82 | fp = getattr(cluster_obj, "fit_predict", None)
83 | if not callable(fp):
84 |         raise Exception("cluster object provided does not have a fit_predict function")
85 |
86 | cluster_labels = cluster_obj.fit_predict(merged[["x","y"]])
87 | merged["cluster"] = cluster_labels
88 | merged = merged.sort_values(by="cluster")
89 | merged["cluster"] = merged["cluster"].astype(str)
90 | return merged
91 |
92 |
93 | def get_kegg_enrichments(merged):
94 | """
95 | calculate all naive enrichments of kegg ko_lvl_3 annotations using a fisher exact test
96 | :param merged: a data frame of words, must be merged with KO data
97 | :return: a pair of dataframes- (merged df with enrichments, only enrichments)
98 | """
99 | res = {}
100 | dfs = []
101 | for cluster in merged["cluster"].unique():
102 | clust_df = merged[merged["cluster"] == cluster]
103 | split_by = "KO_lvl_3"
104 | splitted = clust_df[split_by].apply(lambda x: x.split(';')).explode()
105 | annotations = [s.strip() for s in splitted]
106 | res[cluster] = Counter(annotations)
107 | d = pd.DataFrame.from_dict(res[cluster], orient='index').reset_index().rename(
108 | columns={'index': 'annotation', 0: 'count'})
109 | d["cluster"] = cluster
110 | dfs.append(d)
111 | df = pd.concat(dfs)
112 |
113 | annot_enrich = []
114 | for cluster in merged["cluster"].unique():
115 | clust_df = df[df["cluster"] == cluster]
116 | not_clust_df = df[df["cluster"] != cluster]
117 |
118 | annotations = clust_df[clust_df["count"] > 1]["annotation"].unique()
119 | for annot in annotations:
120 | cluster_annot = clust_df[clust_df["annotation"] == annot]["count"].sum()
121 | not_cluster_annot = not_clust_df[not_clust_df["annotation"] == annot]["count"].sum()
122 | cluster_not_annot = clust_df[clust_df["annotation"] != annot]["count"].sum()
123 | not_cluster_not_annot = not_clust_df[not_clust_df["annotation"] != annot]["count"].sum()
124 |
125 | oddsratio, pvalue = stats.fisher_exact(
126 | [[cluster_annot, cluster_not_annot], [not_cluster_annot, not_cluster_not_annot]])
127 | annot_enrich.append((oddsratio, pvalue, cluster, annot))
128 | data = pd.DataFrame(annot_enrich, columns=["odds", "pvalue", "cluster", "annotation"])
129 | data['corrected_pvalue'] = multi.fdrcorrection(data['pvalue'])[1]
130 | data["enriched"] = data["corrected_pvalue"].apply(lambda x: "yes" if x < 0.05 else "no")
131 |
132 | enriched_df = data[(data["enriched"] == "yes") & (data["annotation"] != "unkown")]
133 | enriched_df = enriched_df.sort_values(by=["cluster", "corrected_pvalue"])
134 | res_df = merged.merge(data, on="cluster", how="left")
135 |
136 | return res_df, enriched_df
137 |
138 |
139 | def cluster_entropy(merged, mode="collapsed"):
140 | """
141 | get the entropy of each cluster
142 | :param merged: a dataframe with words, kos and clusters
143 | :return: a dict - cluster: (score, cluster size, unknown size)
144 | """
145 | res = {}
146 | for cluster in merged["cluster"].unique():
147 | clust_df = merged[merged["cluster"] == cluster]
148 | split_by = "KO_lvl_3"
149 | if mode != "collapsed":
150 | splitted = clust_df[split_by].apply(lambda x: x.split(';')).explode()
151 | annotations = [s.strip() for s in splitted]
152 | else:
153 | annotations = clust_df[split_by]
154 | res[cluster] = Counter(annotations)
155 |
156 | cluster_scores = {}
157 | for cluster in res:
158 | vals = [v for key, v in res[cluster].items() if key!= "unkown"]
159 | n = sum(vals)
160 | score = entropy(vals, base=n)
161 | n_unknown = res[cluster]["unkown"]
162 | size = sum(res[cluster].values())
163 | cluster_scores[cluster] = (score, size, n_unknown)
164 | return cluster_scores
165 |
166 |
167 | def process_word_statistics(mdl, outdir, hypo_word = 'hypo.clst'):
168 | with PdfPages(os.path.join(outdir, f'HypoDistribution.pdf')) as pdf:
169 | w2c_known = dict()
170 | w2c_unknown = dict()
171 | for item in mdl.wv.vocab:
172 | if hypo_word in item:
173 | w2c_unknown[item] = mdl.wv.vocab[item].count
174 | else:
175 | w2c_known[item] = mdl.wv.vocab[item].count
176 |
177 | fig, ax = plt.subplots(1, 2, figsize=(14, 4))
178 |
179 | ax[0].hist(w2c_known.values(), color='#7FB7E5', bins=20)
180 | ax[0].set_title(
181 | f"Known\nAVG: {round(np.mean(list(w2c_known.values())), 2)}, MED: {round(np.median(list(w2c_known.values())), 2)} MAX: {max(w2c_known.values())}")
182 | ax[0].set_yscale("log")
183 | ax[0].grid(True)
184 |
185 | ax[1].hist(w2c_unknown.values(), color='#DC3D13', alpha=0.8)
186 | ax[1].set_title(
187 | f"Unknown\nAVG: {round(np.mean(list(w2c_unknown.values())), 2)}, MED: {round(np.median(list(w2c_unknown.values())), 2)} MAX: {max(w2c_unknown.values())}")
188 | ax[1].set_yscale("log")
189 | ax[1].grid(True)
190 |
191 | pdf.savefig(transparent=True, bbox_inches="tight")
192 | plt.close()
193 |
194 | def summeraize_mdls(w2v_mdl, word2metadata="words2metadata.pkl"):
195 |
196 | with open(word2metadata, 'rb') as o:
197 | words = pickle.load(o)
198 |
199 | res = []
200 | for mdl in tqdm(w2v_mdl):
201 | g2v = w2v.Word2Vec.load(mdl)
202 | corpus_type = "annotation" if "extended" not in mdl else "annotation_extended"
203 | batch_type = mdl.split("/")[6]
204 | mintf = int(re.findall(r"tf(\d*)_annotation", mdl)[-1])
205 | vocab_size = sum([v.count for k,v in g2v.wv.vocab.items()])
206 | unique_tokens = len(g2v.wv.vocab)
207 | unique_hypo = len([k for k in g2v.wv.vocab if 'hypo.clst.' in k])
208 | unique_kegg = unique_tokens - unique_hypo
209 | hypo_count = sum([v.count for k,v in g2v.wv.vocab.items() if 'hypo.clst.' in k])
210 | diamond_hypo = sum([words[k][0] for k in g2v.wv.vocab if 'hypo.clst.' in k and k in words])
211 | diamond_known_hypo = len([words[k][0] for k in g2v.wv.vocab if 'hypo.clst.' in k and k in words]) - diamond_hypo
212 | diamond_nf = unique_hypo - diamond_hypo - diamond_known_hypo
213 |
214 | res.append((corpus_type, batch_type, mintf, vocab_size, unique_tokens, unique_kegg, unique_hypo, hypo_count, diamond_hypo, diamond_known_hypo, diamond_nf))
215 |
216 |
217 | df = pd.DataFrame(res, columns=['corpus_type', 'batch_type', 'mintf', 'vocab_size', 'unique_tokens', 'unique_kegg', 'unique_hypo', 'hypo_count', 'diamond_hypo','diamond_known_hypo', 'diamond_not_found'])
218 | df["kegg_count"] = df["vocab_size"] - df["hypo_count"]
219 | df["per_kegg"] = df["kegg_count"] / df["vocab_size"]
220 | df["per_hypo"] = df["hypo_count"] / df["vocab_size"]
221 | df["per_unique_tokens"] = df["unique_tokens"] / df["vocab_size"]
222 | df["per_unique_kegg"] = df["unique_kegg"] / df["unique_tokens"]
223 | df["per_unique_hypo"] = df["unique_hypo"] / df["unique_tokens"]
224 |
225 | df["per_diamond_hypo"] = df["diamond_hypo"] / df["unique_hypo"]
226 | df["per_diamond_known_hypo"] = df["diamond_known_hypo"] / df["unique_hypo"]
227 | df["per_diamond_not_found"] = df["diamond_not_found"] / df["unique_hypo"]
228 | df = df.sort_values(by=["corpus_type", "mintf"])
229 | return df
230 |
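231 | # Illustrative workflow for the clustering/enrichment helpers above (assumed usage; paths,
232 | # column names and cluster settings are placeholders and may need adjusting to the actual files):
233 | # from sklearn.cluster import KMeans
234 | # merged = add_metadata("ko_metadata.tsv", "/models/g2v_run")
235 | # merged = cluster_data(merged, KMeans(n_clusters=50))
236 | # res_df, enriched_df = get_kegg_enrichments(merged)
237 | # scores = cluster_entropy(merged)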
--------------------------------------------------------------------------------
/src/genomic_embeddings/Embeddings.py:
--------------------------------------------------------------------------------
1 | from gensim.models import word2vec as w2v
2 | import pickle
3 | import pandas as pd
4 |
5 |
6 | def load_embeddings(embedding_mdl):
7 | """
8 | load the existing embeddings from trained model
9 | :param embedding_mdl: the path to a trained w2v model
10 | :return: w2v object, with embeddings for each word in the corpus
11 | """
12 | return w2v.Word2Vec.load(embedding_mdl)
13 |
14 | def get_2d_mapping(embedding_2d_path):
15 | """
16 | get the 2d coordinates as obtained by umap for each gene in the vocabulary.
17 | :param embedding_2d_path: a path to pickle file containing the 2d coordinates for each gene
18 | :return: a dataframe with a word and coordinates
19 | """
20 | with open(embedding_2d_path, "rb") as handle:
21 | embedding_2d = pickle.load(handle)
22 | return embedding_2d
23 |
24 | def get_functional_prediction(predicted_hypo_path):
25 | """
26 |     get a table with all predictions made by the functional model
27 | :param predicted_hypo_path: a path to the pickle file with the hypothetical proteins
28 | :return: data frame with predictions for every hypothetical word
29 | """
30 |     return pd.read_pickle(predicted_hypo_path)
31 |
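32 | # Illustrative usage (file paths are placeholders):
33 | # mdl = load_embeddings("/models/gene2vec.mdl")
34 | # coords_2d = get_2d_mapping("/models/embedding_2d.pkl")
35 | # predictions = get_functional_prediction("/models/hypothetical_predictions.pkl")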
--------------------------------------------------------------------------------
/src/genomic_embeddings/Gff.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import numpy as np
4 | import pandas as pd
5 | import glob
6 | from Bio import SeqIO
7 | import pickle
8 | from BCBio import GFF
9 |
10 | class Gff(object):
11 | def __init__(self, gff_path=None, hypothetical_folder='/hypothetical_mapping/',
12 | ko_path="metadata.csv"):
13 | self.gff = gff_path
14 | self.ko_path = ko_path
15 | self.hypothetical = None
16 | self.keggRun = False
17 | self.hypotheticalRun = False
18 | self.clusterRun = False
19 | self.name = None
20 | self.gff_table = None
21 |
22 | self.fasta = self.gff.replace('.kg.05_21.gff', '.fa')
23 | self.proteins_fasta = self.gff.replace('.kg.05_21.gff', '.proteins.faa')
24 | self.hypothetical_path = os.path.join(hypothetical_folder, f"{os.path.basename(gff_path).split('.contig')[0]}.pkl")
25 |
26 | def __repr__(self):
27 | return self.name
28 |
29 |
30 | def set_name(self):
31 | # get the name of the file as the name of the gff object (includes some previous folder hierarchy)
32 | protein_fasta_path = self.proteins_fasta
33 | self.name = os.path.basename(protein_fasta_path).split(".contig")[0]
34 |
35 | def set_hypothetical(self):
36 | hypothetical_path = self.hypothetical_path
37 | if not os.path.exists(hypothetical_path):
38 | raise FileNotFoundError
39 |         with open(hypothetical_path, 'rb') as handle:
40 |             hypothetical_mapping = pickle.load(handle)
41 | self.hypothetical = hypothetical_mapping
42 |
43 |
44 | def set_gff_table(self):
45 | gff_table_path = f"{self.gff}.parsed.tsv"
46 | if not os.path.exists(gff_table_path):
47 | raise FileNotFoundError
48 | gff_table = pd.read_table(gff_table_path)
49 | self.gff_table = gff_table
50 |
51 |
52 | def parse_gff(self):
53 | f = self.gff
54 | if not os.path.exists(f):
55 | raise Exception("No GFF is currently available, need to run_kegg_qprokka first")
56 | if os.path.exists(f.replace(".gff", ".gff.parsed.tsv")):
57 | return
58 |
59 | records = []
60 | in_handle = open(f)
61 | for rec in tqdm(GFF.parse(in_handle)):
62 | for feature in rec.features:
63 | strand = feature.strand
64 | start = int(feature.location.start)
65 | end = int(feature.location.end)
66 | if feature.id == '' and 'locus_tag' in feature.qualifiers:
67 | feature.id = feature.qualifiers['locus_tag'][0]#tmp addition due to curr gff file formats
68 | if "inference" not in feature.qualifiers:
69 | feature.qualifiers["inference"] = "no inference record"
70 |
71 | if "product" in feature.qualifiers:
72 | records.append((feature.id, feature.qualifiers['product'][0], feature.qualifiers['inference'][-1], feature.type, strand, start, end))
73 | else:
74 | records.append((feature.id, feature.type, feature.type, feature.type, strand, start, end))
75 | in_handle.close()
76 | df = pd.DataFrame(records, columns=["contig_id", "product", "inference", "type", "strand", "start", "end"])
77 | df["annotation"] = df.apply(lambda row: _annotate(row["product"], row["inference"], row["type"]), axis=1)
78 | df["annotation_extended"] = df.apply(lambda row: _annotate_extended(row["product"], row["inference"],
79 | row["type"]), axis=1)
80 |
81 | output_path = f.replace(".gff", ".gff.parsed.tsv")
82 | df.to_csv(output_path, index=False, sep='\t')
83 |
84 | def extract_hypothetical(self):
85 | """ extract all hypothetical proteins for sequence-based clustering """
86 | protein_fasta_path = self.proteins_fasta
87 | hypothetical_fasta_path = protein_fasta_path.replace(".faa", ".hypothetical.faa")
88 | gff_table_path = self.gff.replace(".gff", ".gff.parsed.tsv")
89 |
90 | if self.hypotheticalRun or os.path.exists(hypothetical_fasta_path):
91 | self.hypotheticalRun = True
92 | return
93 |
94 | if not os.path.exists(gff_table_path):
95 | raise Exception("no parsed gff table was found - please run parse_gff to obtain the table")
96 |
97 | annotations = pd.read_table(gff_table_path)
98 | #adjust to lower case:
99 | annotations["annotation"] = annotations["annotation"].apply(lambda x: x.lower())
100 | hypothetical_ids = annotations[annotations["annotation"].isin(["hypothetical protein", "putative protein"])][
101 | "contig_id"].values
102 | hypothetical_records = []
103 | for rec in SeqIO.parse(protein_fasta_path, 'fasta'):
104 | if rec.name in hypothetical_ids:
105 | hypothetical_records.append(rec)
106 | with open(hypothetical_fasta_path, 'w') as handle:
107 | SeqIO.write(hypothetical_records, handle, "fasta")
108 |
109 | self.hypotheticalRun = True
110 |
111 | def extract_hypothetical_and_prokka(self):
112 | """ extract all hypothetical proteins for sequence-based clustering """
113 | protein_fasta_path = self.proteins_fasta
114 | hypothetical_fasta_path = protein_fasta_path.replace(".faa", ".hypothetical.prokka.faa")
115 | gff_table_path = self.gff.replace(".gff", ".gff.parsed.tsv")
116 | if os.path.exists(hypothetical_fasta_path):
117 | self.hypotheticalRun = True
118 | return
119 |
120 | if not os.path.exists(gff_table_path):
121 | raise Exception("no parsed gff table was found - please run parse_gff to obtain the table")
122 |
123 | annotations = pd.read_table(gff_table_path)
124 | ko_table = pd.read_table(self.ko_path)
125 | filter_ids = \
126 | annotations[(~annotations["annotation"].isin(ko_table['KO'])) & (annotations["type"] == 'CDS')][
127 | "contig_id"].values.tolist()
128 |
129 | hypothetical_and_prokka_records = []
130 | for rec in SeqIO.parse(protein_fasta_path, 'fasta'):
131 | if rec.name in filter_ids:
132 | hypothetical_and_prokka_records.append(rec)
133 | with open(hypothetical_fasta_path, 'w') as handle:
134 | SeqIO.write(hypothetical_and_prokka_records, handle, "fasta")
135 | self.hypotheticalRun = True
136 |     # for in-house use
137 | def cluster_hypothetical(self, queue="dudulight", threads=4):
138 | """ run mmseq cluster to generate a table of clustered """
139 | hypothetical_path = self.proteins_fasta.replace(".faa", ".hypothetical.faa")
140 | os.system(f"python /davidb/daniellemiller/bioutils/scripts/mmseq_cluster_runner.py --query {hypothetical_path} "
141 | f"--queue {queue} --threads {threads}")
142 | self.hypotheticalRun = True
143 |
144 | def cluster_hypothetical_and_prokka(self, queue="dudulight", threads=4):
145 | """ run mmseq cluster to generate a table of clustered """
146 | hypothetical_path = self.proteins_fasta.replace(".faa", ".hypothetical.prokka.faa")
147 | os.system(f"python /davidb/daniellemiller/bioutils/scripts/mmseq_cluster_runner.py --query {hypothetical_path} "
148 | f"--queue {queue} --threads {threads}")
149 | self.hypotheticalRun = True
150 |
151 | def assign_clusters(self):
152 | clustering_path = self.hypothetical_path
153 | if not os.path.exists(clustering_path):
154 | print("Clustering tsv do not exists, check whether cluster_hypothetical was previously run")
155 | return
156 | output_path = clustering_path.replace(".tsv", ".assigned.tsv")
157 | if os.path.exists(output_path):
158 | return pd.read_table(output_path)
159 | table = pd.read_table(clustering_path, names=["id1","id2"]) #id1 is the cluster representative, id2 is the matched orf
160 | table["cluster_id"] = pd.factorize(table["id1"])[0]
161 | table["cluster_id"] = table["cluster_id"].apply(
162 | lambda x: f"{os.path.basename(clustering_path).split('.contig')[0].replace('.', '_')}_Cluster_{x}")
163 | table.to_csv(output_path, index=False, sep='\t')
164 | return table
165 |
166 |
167 | def _annotate(product, inference, gff_type):
168 | if "kg.05_21.ren4prok" in inference:
169 | return inference.split(":")[-1].split('.')[0]
170 | elif gff_type != "CDS":
171 | return gff_type
172 | else:
173 | return product
174 |
175 | def _annotate_extended(product, inference, gff_type):
176 | if "kg.05_21.ren4prok" in inference:
177 | return inference.split(":")[-1]
178 | elif gff_type != "CDS":
179 | return gff_type
180 | else:
181 | return product
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
--------------------------------------------------------------------------------
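
Usage note: a Gff object wraps a single KEGG/Prokka-annotated GFF file and its companion fasta files, and its methods are meant to be run in order (parse_gff before extract_hypothetical, clustering before assign_clusters). A minimal sketch with a hypothetical input path following the expected '<sample>.contig*.kg.05_21.gff' naming scheme:

    from genomic_embeddings.Gff import Gff

    # hypothetical sample; the paths are assumptions, not files shipped with the repository
    gff = Gff(gff_path="data/sampleA.contigs.kg.05_21.gff",
              hypothetical_folder="hypothetical_mapping/",
              ko_path="metadata.csv")
    gff.set_name()
    gff.parse_gff()             # writes data/sampleA.contigs.kg.05_21.gff.parsed.tsv
    gff.extract_hypothetical()  # writes *.proteins.hypothetical.faa for sequence clustering
    # cluster_hypothetical() targets an in-house queue; assign_clusters() expects its tsv output
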
/src/genomic_embeddings/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/burstein-lab/genomic-nlp/301c621aa663bfc08a2e159e4bac2df87ba520b2/src/genomic_embeddings/__init__.py
--------------------------------------------------------------------------------
/src/genomic_embeddings/corpus.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import numpy as np
4 | import pandas as pd
5 | import pickle
6 | import glob
7 | from Bio import SeqIO
8 | import socket
9 | import subprocess
10 | import hashlib
11 | from BCBio import GFF
12 |
13 | class CorpusGenerator(object):
14 | def __init__(self, gff, by="annotation", include='include.txt',
15 | word_mapping='word_mapping.pkl'):
16 | self.gff = gff
17 | self.name = gff.name
18 | self.include = include
19 | self.annotation = by
20 | self.text_df = None
21 | self.word_mapping = word_mapping
22 |
23 | def __repr__(self):
24 | return self.name
25 |
26 | def get_gff(self):
27 | return self.gff
28 |
29 | def validate(self):
30 | """ validate params before calling to make sentences"""
31 | gff = self.get_gff()
32 | gff.set_name()
33 | try:
34 | gff.set_gff_table()
35 |         except Exception:
36 | print("Cannot load KEGG table")
37 | return
38 | try:
39 | gff.set_hypothetical()
40 |         except Exception:
41 | print("Cannot load hypothetical table")
42 | return
43 | return gff
44 |
45 |
46 | def make_sentences_df(self):
47 | gff = self.validate()
48 | if gff is None:
49 | print(f"Validation failed for GFF instance {self.gff.name}.")
50 | return
51 | hypo2rep = gff.hypothetical
52 | sample = gff.gff_table
53 | annotation = self.annotation
54 |
55 | with open(self.include, 'r') as o:
56 | includes = [l.replace('\n', '') for l in o.readlines()]
57 |
58 | sample["word"] = sample.apply(
59 | lambda row: hypo2rep[row["contig_id"]] if row["contig_id"] in hypo2rep else row[annotation],
60 | axis=1)
61 | sample['ctg'] = sample['contig_id'].apply(lambda x: x.rsplit('_', 1)[0])
62 | sample = sample[sample['ctg'].isin(includes)]
63 | sample['orf'] = sample['contig_id'].apply(lambda x: int(x.rsplit('_', 1)[-1]))
64 |
65 | data_to_text = sample.sort_values(["ctg", "orf"]).groupby(["ctg"])["word"].apply(
66 | list).reset_index()
67 |
68 | self.text_df = data_to_text
69 |
70 | return data_to_text
71 |
72 | def compile_text(self, path):
73 | """ create a text file contains the """
74 | text_df = self.text_df
75 | if text_df is None:
76 | print("no text df found, Exiting....")
77 | return
78 |
79 | text_df["text"] = text_df["word"].apply(lambda x: ' '.join(x))
80 | text = '. '.join(text_df["text"].values)
81 | with open(path, "w") as handle:
82 | handle.write(text)
83 | return text
84 |
85 | def generate(self, path):
86 | text_df = self.make_sentences_df()
87 | if text_df is None:
88 | return
89 | self.compile_text(path)
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
--------------------------------------------------------------------------------
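
Usage note: CorpusGenerator converts a validated Gff into one "sentence" per contig, replacing hypothetical ORFs with their cluster representatives and keeping only contigs listed in the include file. A short sketch continuing the Gff example above (file paths remain hypothetical):

    from genomic_embeddings.corpus import CorpusGenerator

    gen = CorpusGenerator(gff, by="annotation",
                          include="include.txt",            # one contig id per line
                          word_mapping="word_mapping.pkl")
    gen.generate("corpus/sampleA.txt")                       # words joined by spaces, contigs by '. '
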
/src/genomic_embeddings/data.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import re
4 |
5 | # word 2 vec
6 | from gensim.models import word2vec as w2v
7 |
8 |
9 | class Embedding(object):
10 | def __init__(self, mdl, metadata, labels=None):
11 | self.mdl = w2v.Word2Vec.load(mdl)
12 | self.metadata = pd.read_csv(metadata)
13 | self.labels = labels
14 | self.embedding = self.mdl.wv.vectors.astype('float64')
15 | self.known_embeddings = None
16 | self.word2index = None
17 | self.effective_words = None
18 | self.train_words = None
19 | self.data = None
20 | self.unknown_embeddings = None
21 | self.unknown_word2index = None
22 | self.data_with_words = None
23 |
24 | def extract_known_words(self, unknown="hypo.clst"):
25 | idxs = [self.mdl.wv.vocab[word].index for word in self.mdl.wv.vocab if unknown not in word]
26 | known_mat = self.embedding[idxs]
27 | known_word2index = {self.mdl.wv.index2word[word]: i for i, word in enumerate(idxs)}
28 |
29 | self.known_embeddings = known_mat
30 | self.word2index = known_word2index
31 |
32 | def extract_effective_words(self, label='label'):
33 | metadata = self.metadata
34 | metadata[label] = metadata[label].apply(lambda x: re.split('(.)\[|\(|,', x)[0].strip()) #remove redundant
35 | eff_words = pd.DataFrame(self.word2index.items(), columns=["word","index"])
36 | eff_words["KO"] = eff_words["word"].apply(lambda x: x.rsplit(".")[0])
37 | eff_words = eff_words.merge(metadata, on=["KO"], how='left')[["word","index",label]].dropna()
38 | self.effective_words = eff_words
39 |
40 | def filter_effective_words(self, q=0.96, label='label'):
41 | eff_words = self.effective_words
42 | labels = self.labels
43 | if labels is None:
44 | labels_count = eff_words.groupby(label).size().reset_index(name="size").\
45 | sort_values(by="size", ascending=False)
46 | labels_to_keep = labels_count[labels_count["size"] >= np.quantile(labels_count["size"], q)]
47 | labels_to_keep = labels_to_keep[
48 | ~labels_to_keep[label].isin(["Function unknown [99997]", "Enzymes with EC numbers [99980]"])]
49 | labels = labels_to_keep[label].values
50 |
51 | eff_words = eff_words[eff_words[label].isin(labels)]
52 | self.train_words = eff_words
53 |
54 | if self.labels is None:
55 | self.labels = eff_words[label].unique()
56 |
57 | def add_other_class(self, label='label', sample_size=12, min_points=30):
58 | eff_words = self.effective_words
59 | eff_words[label] = eff_words[label].apply(lambda x: re.split('(.)\[|\(|,', x)[0].strip()) # remove redundant
60 |
61 | label_sizes = eff_words.groupby('label').size().reset_index(name='size')
62 | labels_to_keep = self.labels
63 |
64 | sample_from = label_sizes[~label_sizes["label"].isin(labels_to_keep)].sort_values(by='size', ascending=False)
65 | labels_to_sample_from = sample_from[sample_from['size'] > min_points]['label']
66 |
67 | other_class = eff_words[eff_words['label'].isin(labels_to_sample_from)].groupby("label").sample(n=sample_size,
68 | random_state=42)
69 | other_class['label'] = 'Other'
70 | data_with_other = pd.concat([self.train_words, other_class])
71 | self.train_words = data_with_other
72 |
73 |
74 | def cleanup_train_data(self, add_other=False):
75 | df = pd.DataFrame(self.known_embeddings)
76 | if add_other:
77 | self.add_other_class()
78 | df = df.reset_index().merge(self.train_words, on="index", how="right")
79 | self.data_with_words = df
80 | self.data = df.drop(columns=["index", "word"])
81 |
82 |
83 | def process_unknown_words(self, labels2filter, label='label'):
84 | meta = self.metadata
85 | meta['label'] = meta['label'].apply(lambda x: re.split('(.)\[|\(|,', x)[0].strip())
86 | if labels2filter is None:
87 | labels2filter = meta[label].unique()
88 | train_words = meta[meta[label].isin(labels2filter)]['KO'].values
89 | test_embeddings_idx = {word: self.mdl.wv.vocab[word].index for word in self.mdl.wv.vocab if
90 | word not in train_words}
91 | unknown_embs = self.embedding[[*test_embeddings_idx.values()]]
92 |
93 | self.unknown_embeddings = unknown_embs
94 | self.unknown_word2index = test_embeddings_idx
95 |
96 | def process_data_pipeline(self, label, q, add_other=False):
97 | self.extract_known_words()
98 | self.extract_effective_words(label=label)
99 | self.filter_effective_words(q=q, label=label)
100 | self.cleanup_train_data(add_other=add_other)
101 |
102 |
103 |
104 |
105 |
106 |
--------------------------------------------------------------------------------
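
Usage note: Embedding assembles the supervised training matrix: known (non-"hypo.clst") vectors are pulled from the word2vec model, joined to KEGG labels through the metadata table, filtered to the most frequent classes above quantile q, and optionally padded with a down-sampled "Other" class. A minimal sketch, assuming the hypothetical model path and a metadata csv with 'KO' and 'label' columns:

    from genomic_embeddings.data import Embedding

    emb = Embedding(mdl="outputs/G2V/G2V_2021-05-01.w2v", metadata="metadata.csv")
    emb.process_data_pipeline(label="label", q=0.96, add_other=True)

    # emb.data holds the embedding columns plus the label column
    X = emb.data.drop(columns=["label"]).values
    y = emb.data["label"].values
    print(X.shape, len(set(y)))
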
/src/genomic_embeddings/gene2vec.py:
--------------------------------------------------------------------------------
1 | # imports
2 | import codecs
3 | import glob
4 | import logging
5 | import multiprocessing
6 | from tqdm import tqdm
7 | import argparse
8 | import os
9 | import pickle
10 | import sys
11 | import sklearn.manifold
12 | import pandas as pd
13 | from datetime import datetime
14 | import umap
15 |
16 | #word 2 vec
17 | from gensim.models import word2vec as w2v
18 |
19 |
20 | class Corpus(object):
21 | def __init__(self, dir_path):
22 | self.dirs = dir_path
23 | self.len = None
24 | self.corpus = None
25 | self.sentences = None
26 | self.token_count = None
27 |
28 | def load_corpus(self):
29 | # initialize rawunicode , all text goes here
30 | corpus_raw = u""
31 | files = glob.glob(self.dirs)
32 | files.sort()
33 | print(f"Number of files in corpus: {len(files)}")
34 | for f in tqdm(files):
35 | with codecs.open(f, "r", "utf-8") as book_file:
36 | corpus_raw += book_file.read()
37 |
38 | # set current corpus
39 | self.corpus = corpus_raw
40 |
41 | def make_sentences(self, delim=". "):
42 | # create sentences from corpus
43 |         if self.corpus is None:
44 | print("Error: no corpus object found, use load_corpus function to generate corpus object")
45 | return
46 | raw_sentences = self.corpus.split(delim)
47 | sentences = []
48 | for raw_sentence in tqdm(raw_sentences):
49 | if len(raw_sentence) > 0:
50 | sentences.append(raw_sentence.split())
51 | self.sentences = sentences
52 |
53 | # update number of tokens in corpus
54 | self.token_count = sum([len(sentence) for sentence in sentences])
55 |
56 | def main(args):
57 | # configure logger -
58 | out_dir = os.path.join(args.output, args.alias)
59 | if not os.path.exists(out_dir):
60 | os.makedirs(out_dir)
61 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
62 | filename=os.path.join(out_dir, f"{args.alias}.log"), level=logging.INFO)
63 |
64 | corpus = Corpus(args.input)
65 | corpus.load_corpus()
66 | corpus.make_sentences()
67 | # Seed for the RNG, to make the results reproducible.
68 | seed = 1
69 |     if args.workers is None:
70 | args.workers = multiprocessing.cpu_count()
71 |
72 | # build model
73 | gene2vec = w2v.Word2Vec(
74 | sg=1,
75 | seed=seed,
76 | workers=args.workers,
77 | size=args.size,
78 | min_count=args.minTF,
79 | window=args.window,
80 | sample=args.sample
81 | )
82 | gene2vec.build_vocab(corpus.sentences)
83 | print("Gene2Vec vocabulary length:", len(gene2vec.wv.vocab))
84 | gene2vec.train(corpus.sentences,
85 | total_examples=gene2vec.corpus_count, epochs=args.epochs)
86 | # save model
87 | gene2vec.save(os.path.join(out_dir, f"{args.alias}_{datetime.today().strftime('%Y-%m-%d')}.w2v"))
88 |
89 | mapper = umap.UMAP(n_neighbors=15,min_dist=0.0, n_components=2)
90 | #train umap
91 | all_word_vectors_matrix_2d = mapper.fit_transform(gene2vec.wv.vectors.astype(
92 | 'float64'))
93 | points = pd.DataFrame([(word, coords[0], coords[1])
94 | for word, coords in [(word, all_word_vectors_matrix_2d[gene2vec.wv.vocab[word].index])
95 | for word in gene2vec.wv.vocab]],
96 | columns=["word", "x", "y"])
97 | with open(os.path.join(out_dir, f"words_umap_{datetime.today().strftime('%Y-%m-%d')}"), 'wb') as o:
98 | pickle.dump(points, o)
99 |
100 |
101 | if __name__ == "__main__":
102 |     parser = argparse.ArgumentParser()
103 |     parser.add_argument('--window', default=5, type=int, help='window size')
104 |     parser.add_argument('--size', default=300, type=int, help='vector size')
105 |     parser.add_argument('--workers', required=False, type=int, help='number of processes')
106 |     parser.add_argument('--epochs', default=5, type=int, help='number of epochs')
107 |     parser.add_argument('--minTF', default=4, type=int, help='minimum term frequency')
108 |     parser.add_argument('--sample', default=1e-3, type=float, help='down sampling setting for frequent words')
109 |     parser.add_argument('--model', required=False, type=str, help='model file if exists')
110 |     parser.add_argument('--input', default='../data/*', type=str, help='dir to learn from, as a regex for file generation')
111 |     parser.add_argument('--output', default='outputs/', type=str, help='output folder for results')
112 |     parser.add_argument('--alias', default='G2V', type=str, help='model running alias that will be used for model tracking')
113 |     params = parser.parse_args()
114 |
115 | main(params)
116 |
--------------------------------------------------------------------------------
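
Usage note: gene2vec.py is the training entry point. It concatenates every corpus file matched by --input, splits the text into sentences on '. ', trains a skip-gram word2vec model, and saves both the model and a 2-d UMAP projection of its vocabulary under --output/--alias. A sketch of a typical invocation (the corpus directory is a placeholder):

    import subprocess

    # flags mirror the argparse defaults above; the glob is expanded by the script itself
    subprocess.run([
        "python", "src/genomic_embeddings/gene2vec.py",
        "--input", "corpus/*",
        "--output", "outputs/",
        "--alias", "G2V",
        "--size", "300", "--window", "5", "--minTF", "4", "--epochs", "5",
    ], check=True)
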
/src/genomic_embeddings/models.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import os
4 | import pickle
5 |
6 | # ML packages
7 | from sklearn import metrics
8 | from sklearn import model_selection
9 | import xgboost as xgb
10 | from sklearn.model_selection import StratifiedKFold
11 | from sklearn.ensemble import RandomForestClassifier
12 | from sklearn.svm import SVC
13 | # DL packages
14 | import tensorflow as tf
15 |
16 |
17 | ##### Models interface ######
18 |
19 | class Model(object):
20 | def __init__(self, X, y, out_dir, clf=None):
21 | self.X = X
22 | self.y = y
23 | self.clf = clf
24 | self.out_dir = out_dir
25 | self.name = "MDL"
26 | self.report = None
27 | self.confusion_matrix = None
28 | self.pr = None
29 | self.roc = None
30 | self.auc = None
31 | self.ap = None
32 | self.history=None
33 |
34 |
35 | def set_alias(self, alias='_TOPLABELS'):
36 | self.name = self.name + alias
37 |
38 | def model_fit(self, X_train, X_test, y_train):
39 | clf = self.clf
40 | clf.fit(X_train, y_train)
41 | predicted = clf.predict(X_test)
42 | predicted_prob = clf.predict_proba(X_test)
43 | return predicted, predicted_prob
44 |
45 |
46 | def summarize_accuracy(self, y_test, predicted, predicted_prob, fold):
47 | report = metrics.classification_report(y_test, predicted, output_dict=True)
48 | classes = np.unique(y_test)
49 | y_test_array = pd.get_dummies(y_test, drop_first=False).values
50 |
51 | pr_dfs = []
52 | roc_dfs = []
53 | for i in range(len(classes)):
54 | precision, recall, thresholds = metrics.precision_recall_curve(
55 | y_test_array[:, i], predicted_prob[:, i])
56 | report[classes[i]]["aupr"] = metrics.auc(recall, precision)
57 |
58 | fpr, tpr, thresholds = metrics.roc_curve(y_test_array[:, i],
59 | predicted_prob[:, i])
60 | report[classes[i]]["auc"] = metrics.auc(fpr, tpr)
61 | cur_df_pr = pd.DataFrame({"precision": precision, "recall": recall, "fold": fold, "class": classes[i]})
62 | cur_df_roc = pd.DataFrame({"fpr": fpr, "tpr": tpr, "fold": fold, "class": classes[i]})
63 | pr_dfs.append(cur_df_pr)
64 | roc_dfs.append(cur_df_roc)
65 | try:
66 | precision, recall, _ = metrics.precision_recall_curve(y_test_array.ravel(),
67 | predicted_prob.ravel())
68 | fpr, tpr, thresholds = metrics.roc_curve(y_test_array.ravel(), predicted_prob.ravel())
69 | micro_pr = pd.DataFrame({"precision": precision, "recall": recall, "fold": "micro", "class": "ALL"})
70 | micro_roc = pd.DataFrame({"fpr": fpr, "tpr": tpr, "fold": "micro", "class": "ALL"})
71 | pr_dfs.append(micro_pr)
72 | roc_dfs.append(micro_roc)
73 | auc = round(metrics.roc_auc_score(y_test, predicted_prob,
74 | multi_class="ovr", average='weighted'), 2)
75 | ap = round(metrics.average_precision_score(y_test_array, predicted_prob,average="micro"), 2)
76 | self.auc = auc
77 | self.ap = ap
78 |         except Exception:
79 | print("Cannot calculate accuracy metrics")
80 |
81 | report_df = pd.DataFrame(report).T
82 | report_df["fold"] = fold
83 |
84 | pr = pd.concat(pr_dfs)
85 | roc = pd.concat(roc_dfs)
86 |
87 | return report_df, pr, roc
88 |
89 |
90 | def split_and_classify(self):
91 | X = self.X
92 | y = self.y
93 | X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2, random_state=42,
94 | stratify=y)
95 | predicted, predicted_prob = self.model_fit(X_train, X_test, y_train)
96 | report, precision_report, roc_report = self.summarize_accuracy(y_test, predicted, predicted_prob, fold='NO-CV')
97 |
98 | cm = metrics.confusion_matrix(y_test, predicted)
99 | classes = np.unique(y_test)
100 | cm = pd.DataFrame(cm, columns=classes, index=classes)
101 |
102 | self.pr = precision_report
103 | self.roc = roc_report
104 | self.report = report
105 | self.confusion_matrix = cm
106 |
107 | return report, precision_report, roc_report
108 |
109 | def wrap_up(self, label, alias):
110 | label_dir = os.path.join(self.out_dir, label)
111 | q_dir = os.path.join(label_dir, alias)
112 | if not os.path.exists(label_dir):
113 | os.makedirs(label_dir)
114 | if not os.path.exists(q_dir):
115 | os.makedirs(q_dir)
116 |
117 | report_path = os.path.join(q_dir, f"{self.name}_report.csv")
118 | cm_path = os.path.join(q_dir, f"{self.name}_confusion_matrix.csv")
119 | history_path = os.path.join(q_dir, f"{self.name}_history.pickle")
120 |
121 | self.out_dir = q_dir
122 | report = self.report.reset_index().rename(columns={"index": label})
123 | report.to_csv(report_path, index=False)
124 | cm = self.confusion_matrix
125 | history = self.history
126 | if cm is not None:
127 | cm.to_csv(cm_path)
128 | if history is not None:
129 | with open(history_path, 'wb') as handle:
130 | pickle.dump(history.history, handle)
131 |
132 | def classification_pipeline(self, label, alias='_TOPLABELS'):
133 | self.set_alias(alias)
134 | self.split_and_classify()
135 | self.wrap_up(label, alias)
136 |
137 | class FoldsModel(Model):
138 | def __init__(self, cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=42), **kwargs):
139 | super().__init__(**kwargs)
140 | self.name = self.name + 'Folds'
141 | self.cv = cv
142 | self.folds_pr = None
143 | self.folds_roc = None
144 | self.mean_pr = None
145 |
146 | def calc_ovelall_pr_by_folds(self, y_test_list, predicted_prob_list):
147 | res = []
148 | for fold in range(len(y_test_list)):
149 | classes = np.unique(y_test_list[fold])
150 | y_test_array = pd.get_dummies(y_test_list[fold], drop_first=False).values
151 | predicted_prob = predicted_prob_list[fold]
152 |
153 | for i in range(len(classes)):
154 | fold_by_class_pr = pd.DataFrame({'class':classes[i],'y_test':y_test_array[:, i],
155 | 'predicted_prob':predicted_prob[:, i]})
156 | res.append(fold_by_class_pr)
157 | df = pd.concat(res)
158 |
159 | all_classes = []
160 | for cl in df['class'].unique():
161 | cl_df = df[df['class'] == cl]
162 | precision, recall, thresh = metrics.precision_recall_curve(cl_df['y_test'], cl_df['predicted_prob'])
163 | res = pd.DataFrame({'class':cl, 'precision':precision, 'recall':recall, 'thresh': np.insert(thresh, 0,0)})
164 | res = res[~((res['precision'] == 0) & (res['recall'] == 0))].sort_values(by='thresh')
165 | all_classes.append(res)
166 | self.mean_pr = pd.concat(all_classes)
167 |
168 | def split_and_classify(self):
169 | reports, prs, rocs = [], [], []
170 | fold = 1
171 | X = self.X
172 | y = self.y
173 |
174 | y_real = []
175 | y_proba = []
176 |
177 | for train_index, test_index in self.cv.split(X, y):
178 | X_train, X_test = X[train_index], X[test_index]
179 | y_train, y_test = y[train_index], y[test_index]
180 |
181 | predicted, predicted_prob = self.model_fit(X_train, X_test, y_train)
182 | fold_report, fold_pr, fold_roc = self.summarize_accuracy(y_test, predicted, predicted_prob, fold)
183 | reports.append(fold_report)
184 | prs.append(fold_pr)
185 | rocs.append(fold_roc)
186 | y_real.append(y_test)
187 | y_proba.append(predicted_prob)
188 | fold += 1
189 |
190 | self.pr = pd.concat(prs)
191 | self.roc = pd.concat(rocs)
192 | self.report = pd.concat(reports)
193 |
194 | self.calc_ovelall_pr_by_folds(y_real, y_proba)
195 |
196 | return self.report, self.pr, self.roc
197 |
198 | def merge_folds(self):
199 | pr = self.pr
200 | roc = self.roc
201 | overall_by_cl = self.mean_pr
202 |
203 | mean_precision = np.linspace(0, 1, 100)
204 | mean_fpr = np.linspace(0, 1, 100)
205 |
206 | pr_grp = pr.groupby(["class", "fold"]).agg({'precision': list, 'recall': list}).reset_index()
207 | pr_grp["interp"] = pr_grp.apply(lambda row: np.interp(np.linspace(0,1,max(100, overall_by_cl[overall_by_cl['class'] == row['class']].shape[0])), row['precision'], row['recall']), axis=1)
208 | pr_grp["auc"] = pr_grp.apply(lambda row: metrics.auc(sorted(row["precision"]),
209 | sorted(row["recall"], reverse=True)), axis=1)
210 | pr_data = pr_grp.groupby("class").agg({"interp": list, "auc": list}).reset_index()
211 |
212 | roc_grp = roc.groupby(["class", "fold"]).agg({'fpr': list, 'tpr': list}).reset_index()
213 | roc_grp["interp"] = roc_grp.apply(lambda row: np.interp(mean_fpr, row['fpr'], row['tpr']), axis=1)
214 | roc_grp["auc"] = roc_grp.apply(lambda row: metrics.auc(sorted(row["fpr"]),
215 | sorted(row["tpr"], reverse=True)), axis=1)
216 | roc_data = roc_grp.groupby("class").agg({"interp": list, "auc": list}).reset_index()
217 |
218 | self.folds_roc = roc_data
219 | self.folds_pr = pr_data
220 |
221 |
222 | def merge_folds_reports(self):
223 | res = self.report
224 | res = res[res['fold'] != 'micro']
225 | avg = res.drop(columns=['fold']).groupby(res.index).mean()
226 | avg["fold"] = 'AVG'
227 | res = pd.concat([res, avg], axis=0)
228 | self.report = res
229 | return res
230 |
231 | def classification_pipeline(self, label, alias='_TOPLABELS'):
232 | self.set_alias(alias)
233 | self.split_and_classify()
234 | self.merge_folds_reports()
235 | self.merge_folds()
236 | self.wrap_up(label, alias)
237 |
238 | class CVFoldsModel(FoldsModel):
239 | def __init__(self, fold2data, fold_type, **kwargs):
240 | super().__init__(**kwargs)
241 | self.name = self.name + fold_type + 'Folds'
242 | self.fold2data = fold2data
243 |
244 | def split_and_classify(self):
245 | reports, prs, rocs = [], [], []
246 | fold2data = self.fold2data
247 |
248 | y_real = []
249 | y_proba = []
250 |
251 | for fold in fold2data:
252 | X_train, X_test = fold2data[fold]['X_train'], fold2data[fold]['X_test']
253 | y_train, y_test = fold2data[fold]['y_train'], fold2data[fold]['y_test']
254 |
255 | predicted, predicted_prob = self.model_fit(X_train, X_test, y_train)
256 | fold_report, fold_pr, fold_roc = self.summarize_accuracy(y_test, predicted, predicted_prob, fold)
257 | reports.append(fold_report)
258 | prs.append(fold_pr)
259 | rocs.append(fold_roc)
260 | y_real.append(y_test)
261 | y_proba.append(predicted_prob)
262 |
263 | self.pr = pd.concat(prs)
264 | self.roc = pd.concat(rocs)
265 | self.report = pd.concat(reports)
266 |
267 | self.calc_ovelall_pr_by_folds(y_real, y_proba)
268 |
269 | return self.report, self.pr, self.roc
270 |
271 |
272 |
273 | ##### Models ######
274 | class MLClf(Model):
275 | def __init__(self, **kwargs):
276 | super().__init__(**kwargs)
277 | self.clf = xgb.XGBClassifier(n_estimators=100, max_depth=5)
278 | self.name = "XGB"
279 |
280 | class NNClf(Model):
281 | def __init__(self, **kwargs):
282 | super().__init__(**kwargs)
283 | self.name = "DNN"
284 |
285 |
286 | def set_clf(self, n):
287 | model = tf.keras.models.Sequential([tf.keras.layers.Dense(256, activation=tf.nn.relu),
288 | tf.keras.layers.Dropout(0.2),
289 | tf.keras.layers.Dense(128, activation=tf.nn.relu),
290 | tf.keras.layers.Dropout(0.2),
291 | tf.keras.layers.Dense(64, activation=tf.nn.relu),
292 | tf.keras.layers.Dense(n, activation=tf.nn.softmax)])
293 | model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
294 | self.clf = model
295 |
296 | def model_fit(self, X_train, X_test, y_train):
297 | self.set_clf(pd.Series(y_train).nunique())
298 | clf = self.clf
299 | dic_y_mapping = {n: label for n, label in
300 | enumerate(np.unique(y_train))}
301 | inverse_dic = {v: k for k, v in dic_y_mapping.items()}
302 | y_train_tag = np.array([inverse_dic[y] for y in y_train])
303 |
304 | history = clf.fit(x=X_train, y=y_train_tag, batch_size=256,
305 | epochs=20, shuffle=True, verbose=0)
306 | self.history = history
307 |
308 | predicted_prob = self.clf.predict(X_test, workers=5)
309 | predicted = [dic_y_mapping[np.argmax(pred)] for pred in
310 | predicted_prob]
311 | return predicted, predicted_prob
312 |
313 |
314 | class MLClfFolds(FoldsModel):
315 | def __init__(self, **kwargs):
316 | super().__init__(**kwargs)
317 | self.clf = xgb.XGBClassifier(n_estimators=100, max_depth=5)
318 | self.name = "XGBFolds"
319 |
320 |
321 | class NNClfFolds(FoldsModel):
322 | def __init__(self, **kwargs):
323 | super().__init__(**kwargs)
324 | self.name = "DNNFolds"
325 |
326 | def set_clf(self, n):
327 | model = tf.keras.models.Sequential([tf.keras.layers.Dense(256, activation=tf.nn.relu),
328 | tf.keras.layers.Dropout(0.2),
329 | tf.keras.layers.Dense(128, activation=tf.nn.relu),
330 | tf.keras.layers.Dropout(0.2),
331 | tf.keras.layers.Dense(64, activation=tf.nn.relu),
332 | tf.keras.layers.Dense(n, activation=tf.nn.softmax)])
333 | model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
334 | self.clf = model
335 |
336 | def model_fit(self, X_train, X_test, y_train):
337 | self.set_clf(pd.Series(y_train).nunique())
338 | clf = self.clf
339 | dic_y_mapping = {n: label for n, label in
340 | enumerate(np.unique(y_train))}
341 | inverse_dic = {v: k for k, v in dic_y_mapping.items()}
342 | y_train_tag = np.array([inverse_dic[y] for y in y_train])
343 |
344 | history = clf.fit(x=X_train, y=y_train_tag, batch_size=256,
345 | epochs=20, shuffle=True, verbose=0)
346 |
347 | self.history = history
348 |
349 | predicted_prob = self.clf.predict(X_test)
350 | predicted = [dic_y_mapping[np.argmax(pred)] for pred in
351 | predicted_prob]
352 | return predicted, predicted_prob
353 |
354 | class NNClfCVFolds(CVFoldsModel):
355 | def __init__(self, **kwargs):
356 | super().__init__(**kwargs)
357 | self.name = "DNN" + self.name
358 |
359 | def set_clf(self, n):
360 | model = tf.keras.models.Sequential([tf.keras.layers.Dense(256, activation=tf.nn.relu),
361 | tf.keras.layers.Dropout(0.2),
362 | tf.keras.layers.Dense(128, activation=tf.nn.relu),
363 | tf.keras.layers.Dropout(0.2),
364 | tf.keras.layers.Dense(64, activation=tf.nn.relu),
365 | tf.keras.layers.Dense(n, activation=tf.nn.softmax)])
366 | model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
367 | self.clf = model
368 |
369 | def model_fit(self, X_train, X_test, y_train):
370 | self.set_clf(pd.Series(y_train).nunique())
371 | clf = self.clf
372 | dic_y_mapping = {n: label for n, label in
373 | enumerate(np.unique(y_train))}
374 | inverse_dic = {v: k for k, v in dic_y_mapping.items()}
375 | y_train_tag = np.array([inverse_dic[y] for y in y_train])
376 |
377 | history = clf.fit(x=X_train, y=y_train_tag, batch_size=256,
378 | epochs=20, shuffle=True, verbose=0)
379 |
380 | self.history = history
381 |
382 | predicted_prob = self.clf.predict(X_test)
383 | predicted = [dic_y_mapping[np.argmax(pred)] for pred in
384 | predicted_prob]
385 | return predicted, predicted_prob
386 |
387 | class XGBClfCVFolds(CVFoldsModel):
388 | def __init__(self, **kwargs):
389 | super().__init__(**kwargs)
390 | self.clf = xgb.XGBClassifier()
391 | self.name = "XGB" + self.name
392 |
393 | params = {"learning_rate":0.05, "max_depth":6, "n_estimators":800}
394 | self.clf.set_params(**params)
395 |
396 | class RFClfCVFolds(CVFoldsModel):
397 | def __init__(self, **kwargs):
398 | super().__init__(**kwargs)
399 | self.clf = RandomForestClassifier(max_depth=50, min_samples_split=2, min_samples_leaf=1, n_estimators=1000)
400 | self.name = "RF" + self.name
401 |
402 | class SVMClfCVFolds(CVFoldsModel):
403 | def __init__(self, **kwargs):
404 | super().__init__(**kwargs)
405 | self.clf = SVC(kernel="rbf", C=1, gamma='auto', probability=True)
406 | self.name = "SVM" + self.name
--------------------------------------------------------------------------------
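
Usage note: the module layers one interface (Model) over a single stratified train/test split, stratified 5-fold cross validation (FoldsModel), and externally supplied folds (CVFoldsModel); each concrete subclass only swaps in its classifier. A minimal sketch running the XGBoost 5-fold variant on the matrix built by Embedding (paths are hypothetical, as before):

    from genomic_embeddings.data import Embedding
    from genomic_embeddings.models import MLClfFolds

    emb = Embedding(mdl="outputs/G2V/G2V_2021-05-01.w2v", metadata="metadata.csv")
    emb.process_data_pipeline(label="label", q=0.96, add_other=True)
    X = emb.data.drop(columns=["label"]).values
    y = emb.data["label"].values

    mdl = MLClfFolds(X=X, y=y, out_dir="outputs/classification")
    mdl.classification_pipeline(label="label", alias="q96")

    # per-class metrics averaged over the five folds
    print(mdl.report[mdl.report["fold"] == "AVG"][["precision", "recall", "f1-score"]])
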
/src/genomic_embeddings/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from matplotlib.backends.backend_pdf import PdfPages
3 | from itertools import cycle
4 | import math
5 | import os
6 | from sklearn import metrics
7 | import numpy as np
8 | import pickle
9 |
10 | import matplotlib
11 | matplotlib.rcParams['pdf.fonttype'] = 42
12 | matplotlib.rcParams['ps.fonttype'] = 42
13 |
14 | class ModelPlots(object):
15 | def __init__(self, mdl):
16 | self.mdl = mdl
17 | self.roc_data = mdl.roc
18 | self.precision_data = mdl.pr
19 | self.mdl_report = mdl.report
20 | self.out_dir = mdl.out_dir
21 | self.name = mdl.name
22 |
23 | def plot_roc(self):
24 | roc_df = self.roc_data
25 | res = self.mdl_report
26 |
27 | colors = ['pink', 'turquoise', 'darkorange', 'cornflowerblue', 'teal', 'gold', 'olive','tomato', 'deeppink']
28 |
29 | with PdfPages(os.path.join(self.out_dir, f'{self.name}_ROC.pdf')) as pdf:
30 | fig, ax = plt.subplots(figsize=(5, 4))
31 |
32 | for c, cl in zip(cycle(colors), roc_df["class"].unique()):
33 | if cl == "ALL":
34 | c = "k"
35 | score = metrics.auc(roc_df[roc_df["class"] == cl]["fpr"], roc_df[roc_df["class"] == cl]["tpr"])
36 | else:
37 | score = res.groupby(res.index)['auc'].mean().loc[cl]
38 |
39 | ax.plot(roc_df[roc_df["class"] == cl]["fpr"], roc_df[roc_df["class"] == cl]["tpr"], lw=3,
40 | label="class {0} ({1:0.2f})".format(cl, score),
41 | color=c)
42 | ax.plot([0, 1], [0, 1], color='grey', lw=3, linestyle='--', alpha=0.2)
43 | ax.set(xlim=[-0.05, 1.0], ylim=[0.0, 1.05], title=f"ROC",
44 | xlabel="False Positive Rate", ylabel="True Positive rate")
45 | ax.grid(True)
46 | ax.text(.7, 0.05, f"AUC:{self.mdl.auc}", fontsize=12)
47 | plt.legend(bbox_to_anchor=(1.01, 1))
48 | pdf.savefig(transparent=True, bbox_inches="tight")
49 | plt.close()
50 |
51 | def plot_precision_recall(self):
52 | pr_df = self.precision_data
53 | res = self.mdl_report
54 |
55 | colors = ['pink', 'turquoise', 'darkorange', 'cornflowerblue', 'teal', 'gold', 'olive','tomato', 'deeppink']
56 | with PdfPages(os.path.join(self.out_dir, f'{self.name}_AUPR.pdf')) as pdf:
57 | fig, ax = plt.subplots(figsize=(5, 4))
58 | for c, cl in zip(cycle(colors), pr_df["class"].unique()):
59 |
60 | if cl == "ALL":
61 | c = "k"
62 | score = metrics.auc(pr_df[pr_df["class"] == cl]["recall"], pr_df[pr_df["class"] == cl]["precision"])
63 | else:
64 | score = res.groupby(res.index)['aupr'].mean().loc[cl]
65 |
66 | ax.plot(pr_df[pr_df["class"] == cl]["recall"], pr_df[pr_df["class"] == cl]["precision"], lw=3,
67 | label="class {0} ({1:0.2f})".format(cl, score),
68 | color=c)
69 | ax.set(xlim=[-0.05, 1.0], ylim=[0.0, 1.05], title=f"AUPR", xlabel="Recall", ylabel="Precision")
70 | ax.grid(True)
71 | ax.text(0.01, 0.05, "F1:{0:0.2f}".format(res.groupby(res.index)['f1-score'].mean().loc['weighted avg']),
72 | fontsize=12)
73 | plt.legend(bbox_to_anchor=(1.01, 1))
74 | pdf.savefig(transparent=True, bbox_inches="tight")
75 | plt.close()
76 |
77 |
78 | class FoldModelPlots(object):
79 | def __init__(self, mdl):
80 | self.mdl = mdl
81 | self.roc_data = mdl.roc
82 | self.precision_data = mdl.pr
83 | self.mdl_report = mdl.report
84 | self.outdir = mdl.out_dir
85 | self.name = mdl.name
86 |
87 | def plot_roc(self):
88 | roc_df = self.roc_data
89 | res = self.mdl_report
90 |
91 | roc_df = roc_df[roc_df['class'] != 'ALL']
92 |
93 | n_classes = roc_df["class"].nunique()
94 | classes = list(roc_df["class"].unique())
95 | COLS = 4
96 | ROWS = math.ceil(n_classes / COLS)
97 | pages = math.ceil(ROWS / 4)
98 |
99 | colors = ['pink', 'turquoise', 'darkorange', 'cornflowerblue', 'teal', 'gold', 'olive','tomato', 'deeppink']
100 |
101 | with PdfPages(os.path.join(self.outdir, f'{self.name}_ROC.pdf')) as pdf:
102 | for page in range(pages):
103 | fig, ax = plt.subplots(COLS, COLS, figsize=(20,16))
104 |
105 | i, j = 0, 0
106 | for cl in classes[page*COLS*COLS: page*COLS*COLS +COLS*COLS]:
107 | d = roc_df[roc_df["class"] == cl]
108 |
109 | for c, fold in zip(colors, d["fold"].unique()):
110 | ax[i][j].plot(d[d["fold"] == fold]["fpr"], d[d["fold"] == fold]["tpr"],
111 | lw=3, label=f"fold: {fold}", color=c)
112 | ax[i][j].plot([0, 1], [0, 1], color='grey', lw=3, linestyle='--', alpha=0.2)
113 | ax[i][j].set(xlim=[-0.05, 1.0], ylim=[0.0, 1.05], title=f"{cl}")
114 | ax[i][j].grid(True)
115 | ax[i][j].text(.6, 0.05, "AUC:{0:0.2f}".format(res.groupby(res.index)['auc'].mean().loc[cl]),
116 | fontsize=12)
117 | if j == 0:
118 | ax[i][j].set(ylabel="True Positive Rate")
119 | if i == COLS - 1:
120 | ax[i][j].set(xlabel="False Positive Rate")
121 | if j == (COLS-1):
122 | j = 0
123 | i += 1
124 | else:
125 | j += 1
126 | plt.subplots_adjust(hspace=.3, wspace=.3)
127 | _ = [ax[k][q].axis("off") for k in range(COLS) for q in range(COLS) if not ax[k][q].lines]
128 | pdf.savefig(transparent=True, bbox_inches="tight")
129 | plt.close()
130 |
131 | def plot_roc_by_fold(self):
132 | roc_df = self.roc_data
133 | res = self.mdl_report
134 |
135 | roc_df = roc_df[roc_df['class'] != 'ALL']
136 | roc_df['fold'] = roc_df['fold'].apply(lambda x: x.split('_')[-1])
137 |
138 | n_folds = roc_df["fold"].nunique()
139 | folds = list(roc_df["fold"].unique())
140 | COLS = 3
141 | ROWS = math.ceil(n_folds / COLS)
142 | pages = math.ceil(ROWS / 3)
143 |
144 | class2color = {'Amino sugar and nucleotide sugar metabolism':'DarkOrchid',
145 | 'Benzoate degradation':'darkorange', 'Energy metabolism':'cornflowerblue', 'Other':'grey',
146 | 'Oxidative phosphorylation':'gold',
147 | 'Porphyrin and chlorophyll metabolism':'teal',
148 | 'Prokaryotic defense system':'tomato', 'Ribosome':'deeppink', 'Secretion system':'pink',
149 | 'Two-component system':'turquoise', 'ALL':'k'}
150 |
151 | with PdfPages(os.path.join(self.outdir, f'{self.name}_ROC_BY_FOLD.pdf')) as pdf:
152 | for page in range(pages):
153 | fig, ax = plt.subplots(COLS, COLS, figsize=(20,16))
154 |
155 | i, j = 0, 0
156 | for fold in folds[page*COLS*COLS: page*COLS*COLS +COLS*COLS]:
157 | d = roc_df[roc_df["fold"] == fold]
158 |
159 | for cl in d["class"].unique():
160 | ax[i][j].plot(d[d["class"] == cl]["fpr"], d[d["class"] == cl]["tpr"],
161 | lw=3, label=cl, color=class2color[cl])
162 | ax[i][j].plot([0, 1], [0, 1], color='grey', lw=3, linestyle='--', alpha=0.2)
163 | ax[i][j].set(xlim=[-0.05, 1.0], ylim=[0.0, 1.05], title=f"{fold}")
164 | ax[i][j].grid(True)
165 | ax[i][j].text(.6, 0.05, "AUC:{0:0.2f}".format(res.groupby('fold')['auc'].mean().loc[f'fold_{fold}']),
166 | fontsize=12)
167 | if j == 0:
168 | ax[i][j].set(ylabel="True Positive Rate")
169 | if i == COLS - 1:
170 | ax[i][j].set(xlabel="False Positive Rate")
171 | if j == (COLS-1):
172 | j = 0
173 | i += 1
174 | else:
175 | j += 1
176 | plt.subplots_adjust(hspace=.3, wspace=.3)
177 | _ = [ax[k][q].axis("off") for k in range(COLS) for q in range(COLS) if not ax[k][q].lines]
178 | pdf.savefig(transparent=True, bbox_inches="tight")
179 | plt.close()
180 |
181 |
182 | def plot_precision_recall(self):
183 | pr_df = self.precision_data
184 | res = self.mdl_report
185 |
186 | pr_df = pr_df[pr_df['class'] != 'ALL']
187 |
188 | n_classes = pr_df["class"].nunique()
189 | classes = list(pr_df["class"].unique())
190 | COLS = 4
191 | ROWS = math.ceil(n_classes / COLS)
192 | pages = math.ceil(ROWS / 4)
193 |
194 | colors = ['pink', 'turquoise', 'darkorange', 'cornflowerblue', 'teal', 'gold', 'olive','tomato', 'deeppink']
195 |
196 | with PdfPages(os.path.join(self.outdir, f'{self.name}_AUPR.pdf')) as pdf:
197 | for page in range(pages):
198 | fig, ax = plt.subplots(COLS, COLS, figsize=(20,16))
199 |
200 | i, j = 0, 0
201 | for cl in classes[page*COLS*COLS : COLS*COLS*(page +1)]:
202 |
203 | d = pr_df[pr_df["class"] == cl]
204 |
205 | for c, fold in zip(colors, d["fold"].unique()):
206 | ax[i][j].plot(d[d["fold"] == fold]["recall"], d[d["fold"] == fold]["precision"], lw=3,
207 | label=f"fold: {fold}", color=c
208 | )
209 | ax[i][j].set(xlim=[-0.05, 1.0], ylim=[0.0, 1.05], title=f"{cl}")
210 | ax[i][j].grid(True)
211 | ax[i][j].text(0.01, 0.05, "AUPR:{0:0.2f}".format(res.groupby(res.index)['aupr'].mean().loc[cl]),
212 | fontsize=12)
213 | if j == 0:
214 | ax[i][j].set(ylabel="Precision")
215 | if i == COLS - 1:
216 | ax[i][j].set(xlabel="Recall")
217 |
218 | if j == (COLS-1):
219 | j = 0
220 | i += 1
221 | else:
222 | j += 1
223 | plt.subplots_adjust(hspace=.3, wspace=.3)
224 | _ = [ax[k][q].axis("off") for k in range(COLS) for q in range(COLS) if not ax[k][q].lines]
225 | pdf.savefig(transparent=True, bbox_inches="tight")
226 | plt.close()
227 |
228 |
229 | def plot_precision_recall_by_fold(self):
230 | pr_df = self.precision_data
231 | res = self.mdl_report
232 |
233 | pr_df = pr_df[pr_df['class'] != 'ALL']
234 | pr_df['fold'] = pr_df['fold'].apply(lambda x: x.split('_')[-1])
235 |
236 | n_folds = pr_df["fold"].nunique()
237 | folds = list(pr_df["fold"].unique())
238 | COLS = 3
239 | ROWS = math.ceil(n_folds / COLS)
240 | pages = math.ceil(ROWS / 3)
241 |
242 | class2color = {'Amino sugar and nucleotide sugar metabolism':'DarkOrchid',
243 | 'Benzoate degradation':'darkorange', 'Energy metabolism':'cornflowerblue', 'Other':'grey',
244 | 'Oxidative phosphorylation':'gold',
245 | 'Porphyrin and chlorophyll metabolism':'teal',
246 | 'Prokaryotic defense system':'tomato', 'Ribosome':'deeppink', 'Secretion system':'pink',
247 | 'Two-component system':'turquoise', 'ALL':'k'}
248 |
249 |
250 | with PdfPages(os.path.join(self.outdir, f'{self.name}_AUPR_BY_FOLD.pdf')) as pdf:
251 | for page in range(pages):
252 | fig, ax = plt.subplots(COLS, COLS, figsize=(20,16))
253 |
254 | i, j = 0, 0
255 | for fold in folds[page*COLS*COLS : COLS*COLS*(page +1)]:
256 |
257 | d = pr_df[pr_df["fold"] == fold]
258 | d = d[~((d['precision'] == 0) & (d['recall'] == 0))]
259 |
260 | for cl in d["class"].unique():
261 | ax[i][j].plot(d[d["class"] == cl]["recall"], d[d["class"] == cl]["precision"], lw=3,
262 | label=cl, color=class2color[cl]
263 | )
264 | ax[i][j].set(xlim=[-0.05, 1.0], ylim=[0.0, 1.05], title=f"{fold}")
265 | ax[i][j].grid(True)
266 | ax[i][j].text(0.01, 0.05, "AUPR:{0:0.2f}".format(res.groupby('fold')['aupr'].mean().loc[f'fold_{fold}']),
267 | fontsize=12)
268 | if j == 0:
269 | ax[i][j].set(ylabel="Precision")
270 | if i == COLS - 1:
271 | ax[i][j].set(xlabel="Recall")
272 |
273 | if j == (COLS-1):
274 | j = 0
275 | i += 1
276 | else:
277 | j += 1
278 | plt.subplots_adjust(hspace=.3, wspace=.3)
279 | _ = [ax[k][q].axis("off") for k in range(COLS) for q in range(COLS) if not ax[k][q].lines]
280 | pdf.savefig(transparent=True, bbox_inches="tight")
281 | plt.close()
282 |
283 |
284 | def plot_single_aupr_with_ci(self):
285 | data = self.mdl.folds_pr
286 | overall_data = self.mdl.mean_pr
287 | overall_data = overall_data[~((overall_data['precision'] == 0) & (overall_data['recall'] == 0))]
288 | class2aupr = {}
289 | class2color = {'Amino sugar and nucleotide sugar metabolism':'DarkOrchid',
290 | 'Benzoate degradation':'darkorange', 'Energy metabolism':'cornflowerblue', 'Other':'grey',
291 | 'Oxidative phosphorylation':'gold',
292 | 'Porphyrin and chlorophyll metabolism':'teal',
293 | 'Prokaryotic defense system':'tomato', 'Ribosome':'deeppink', 'Secretion system':'pink',
294 | 'Two-component system':'turquoise', 'ALL':'k'}
295 |
296 |
297 | with PdfPages(os.path.join(self.outdir, f'{self.name}_AUPR_CI.pdf')) as pdf:
298 | fig, ax = plt.subplots(figsize=(5, 4))
299 | for cl in data["class"].unique():
300 |
301 | class_data = data[data["class"] == cl]
302 | tprs = class_data['interp'].tolist()[0]
303 | mean_precision = np.mean(tprs, axis=0)
304 | mean_recall = np.linspace(0, 1, mean_precision.shape[0])
305 | mean_auc = metrics.auc(mean_recall, mean_precision)
306 |
307 | if cl == 'ALL':
308 | ax.plot(mean_recall, mean_precision, color=class2color[cl], label="{0} ({1:0.2f})".format(cl, mean_auc), lw=3, alpha=.8)
309 | else:
310 | pr_data = overall_data[overall_data['class'] == cl]
311 | mean_auc = metrics.auc(pr_data['recall'], pr_data['precision'])
312 | ax.plot(pr_data['recall'], pr_data['precision'], color=class2color[cl],label="{0} ({1:0.2f})".format(cl, mean_auc), lw=3, alpha=.8)
313 |
314 |                     # CI band only for per-class curves ('ALL' has no matching rows in mean_pr)
315 |                     std_precision = np.std(tprs, axis=0)
316 |                     ax.fill_between(pr_data['recall'], pr_data['precision'] + std_precision, pr_data['precision'] - std_precision, color=class2color[cl], alpha=.1)
317 |
318 | class2aupr[cl] = mean_auc
319 | ax.set(xlim=[-0.05, 1.0], ylim=[0.0, 1.05], title=f"AUPR", xlabel="Recall", ylabel="Precision")
320 | ax.grid(True)
321 | plt.legend(bbox_to_anchor=(1.01, 1))
322 | pdf.savefig(transparent=True, bbox_inches="tight")
323 | plt.close()
324 | with open(os.path.join(self.outdir, f'{self.name}_AUPR_CI.pkl'), 'wb') as o:
325 | pickle.dump(class2aupr, o)
326 |
327 | def plot_single_roc_with_ci(self):
328 | data = self.mdl.folds_roc
329 |
330 | class2color = {'Amino sugar and nucleotide sugar metabolism':'DarkOrchid',
331 | 'Benzoate degradation':'darkorange', 'Energy metabolism':'cornflowerblue', 'Other':'grey',
332 | 'Oxidative phosphorylation':'gold',
333 | 'Porphyrin and chlorophyll metabolism':'teal',
334 | 'Prokaryotic defense system':'tomato', 'Ribosome':'deeppink', 'Secretion system':'pink',
335 | 'Two-component system':'turquoise', 'ALL':'k'}
336 |
337 | with PdfPages(os.path.join(self.outdir, f'{self.name}_ROC_CI.pdf')) as pdf:
338 | fig, ax = plt.subplots(figsize=(5, 4))
339 | for cl in data["class"].unique():
340 |
341 | class_data = data[data["class"] == cl]
342 | tprs = class_data['interp'].tolist()[0]
343 | mean_tpr = np.mean(tprs, axis=0)
344 | mean_fpr = np.linspace(0, 1, 100)
345 | mean_auc = metrics.auc(mean_fpr, mean_tpr)
346 |
347 | ax.plot(mean_fpr, mean_tpr, color=class2color[cl],
348 | label="{0} ({1:0.2f})".format(cl, mean_auc), lw=3, alpha=.8)
349 |
350 | std_tpr = np.std(tprs, axis=0)
351 | tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
352 | tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
353 | ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color=class2color[cl], alpha=.1)
354 |
355 |             ax.set(xlim=[-0.05, 1.0], ylim=[0.0, 1.05], title="ROC", xlabel="False Positive Rate",
356 |                    ylabel="True Positive Rate")
357 | ax.grid(True)
358 | plt.legend(bbox_to_anchor=(1.01, 1))
359 | pdf.savefig(transparent=True, bbox_inches="tight")
360 | plt.close()
--------------------------------------------------------------------------------
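
Usage note: ModelPlots and FoldModelPlots read only the attributes a fitted model already carries (pr, roc, report, out_dir), so they can be chained straight after classification_pipeline; the *_with_ci plots additionally need folds_pr/folds_roc/mean_pr from merge_folds and assume the class names match the colour map hard-coded above. A minimal sketch continuing the cross-validation example (mdl is the fitted MLClfFolds from the models.py sketch):

    from genomic_embeddings.plot import FoldModelPlots

    plots = FoldModelPlots(mdl)
    plots.plot_roc()                   # one ROC panel per class, one curve per fold
    plots.plot_precision_recall()      # matching precision-recall panels
    plots.plot_single_roc_with_ci()    # mean ROC per class with a +/- std band
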