├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── MODELCARD.md ├── README.md ├── config.json ├── config.yaml ├── generate_reconstructions.ipynb ├── huggingface_mae.py ├── loss.py ├── mae_modules.py ├── mae_utils.py ├── masking.py ├── normalizer.py ├── pyproject.toml ├── sample ├── AA41_s1_1.jp2 ├── AA41_s1_2.jp2 ├── AA41_s1_3.jp2 ├── AA41_s1_4.jp2 ├── AA41_s1_5.jp2 └── AA41_s1_6.jp2 ├── test_huggingface_mae.py ├── vit.py └── vit_encoder.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # model artifacts 30 | *.pickle 31 | *.ckpt 32 | *.safetensors -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "flake8.args": [ 3 | "--max-line-length=120" 4 | ] 5 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | CONFIDENTIAL 2 | 3 | 4 | Recursion Pharmaceuticals, Inc. 5 | Non-Commercial End User License Agreement 6 | 7 | 8 | 1. INTRODUCTION. 9 | This Non-Commercial End User License Agreement (as may be revised from time to time, this 10 | “Agreement”) is a binding agreement between You (as defined below) and Recursion Pharmaceuticals, 11 | Inc., a Delaware corporation with offices located at 41 S. Rio Grande St., Salt Lake City, UT 84101 (“We,” 12 | “Us,” or “Our”). This Agreement grants You a license to Use (as defined below) certain Licensed 13 | Materials (as defined below) subject to Your acceptance of all terms contained in this Agreement. While 14 | this Agreement is not a Creative Commons license, it incorporates certain core principles thereof, 15 | including attribution, non-commercial, and ShareAlike (similar to CC BY-NC-SA). 16 | Please read the terms of this Agreement carefully before Using any of the Licensed Materials. By 17 | Using any of the Licensed Materials or by clicking to accept or agree to the terms of this Agreement, You 18 | agree that You have read and understand the terms of this Agreement, and further agree to accept and 19 | agree to comply with the terms of this Agreement. You represent that You are at least 18 years of age, 20 | and if You are accessing or using the Licensed Materials on behalf of an entity, that You have the legal 21 | authority to enter into this Agreement on that entity’s behalf. If You do not agree to the terms of this 22 | Agreement, then You must not Use any Licensed Materials and You should click to reject or not agree to 23 | the terms of this Agreement. 24 | We may revise this Agreement from time to time, for any reason. Any change to this Agreement 25 | will be effective immediately upon posting unless We state otherwise. You should check this Agreement 26 | on the Site regularly. Your continued Use of the Licensed Materials after any changes to this Agreement 27 | constitutes Your binding acceptance of this Agreement as revised, including such changes. 28 | 2. DEFINITIONS. 
29 | “Derivative Technology” means any product or technology generated, conceived, developed, or 30 | reduced to practice through Your Use of, or derived from or based on, any Licensed Material. For clarity, 31 | “Derivative Technology” includes any modified Recursion Dataset. 32 | “Intellectual Property Rights” means all intellectual property and proprietary rights of any kind, 33 | however denominated, throughout the world, including all rights in patents, patent applications, 34 | copyrights, trademarks, trade secrets, designs, inventions, works of authorship, software (including 35 | source code and object code), documentation, know-how, methods, processes, algorithms, data and 36 | databases, and all updates, upgrades, new versions, and enhancements of any and all of the foregoing, 37 | and all registrations and applications for any and all of the foregoing. 38 | “Licensed Intellectual Property Rights” means copyrights and similar rights closely related to 39 | copyrights, including rights in software, data, and databases, (a) owned or otherwise controlled by Us 40 | and (b) necessary for You to exercise Your rights under, and in strict accordance with the terms of, this 41 | Agreement. “Licensed Intellectual Property Rights” does not include any other Intellectual Property 42 | Rights, including patent rights, trademark rights, moral rights, or publicity, privacy, or other similar 43 | personality rights. 44 | “Licensed Materials” means the Recursion Software and/or Recursion Dataset, as applicable, to 45 | which We apply this Agreement. For clarity, references to the “Licensed Materials” in this Agreement 46 | include any portion thereof. 47 | 48 | 49 | 50 | “Permitted Purpose” means non-commercial research, academic, and educational purposes 51 | only. For the purposes of this definition, “non-commercial research” means research not primarily 52 | intended for or directed towards commercial advantage or monetary compensation. 53 | “Recursion Dataset” means the data and datasets (which may include, without limitation, 54 | phenomics maps, images, image data, embeddings, genetic information, and other metadata), in each 55 | case, made available to You through the Site, including through or otherwise in connection with the 56 | Recursion Software, but excluding, for clarity, the RxRx19x dataset. 57 | “Recursion Software” means Recursion’s proprietary software (including, without limitation, 58 | Recursion’s proprietary AI Models) made available to You through the Site, including any updates or 59 | upgrades thereto and any written documentation or other media related thereto made available to You 60 | through the Site, which may include, without limitation, the MolRec™ application. The Recursion 61 | Software will not be provided in source code format. 62 | “Site” means www.rxrx.ai, together with its subdomains. 63 | “Use” (and its correlatives) means (a) use, download, and access, and (b) solely with respect to 64 | the Recursion Dataset, copy, analyze, modify, adapt, aggregate, share, and use to produce Derivative 65 | Technology. 66 | “You” (and its correlatives) means the individual(s) or entity(ies) that Use the Licensed Materials 67 | under this Agreement. If you are Using the Licensed Materials in your individual capacity, all references 68 | to “You” reference you as an individual person. If you are Using the Licensed Materials on behalf of a 69 | company or other entity, all references to “You” reference both you as an individual person and that 70 | company or entity. 
71 | 3. LICENSE GRANT. 72 | Subject to Your compliance with the terms of this Agreement, We grant to You a personal, 73 | limited, non-exclusive, non-transferable, non-sublicensable, royalty-free, irrevocable (except as set forth 74 | below) license under the Licensed Intellectual Property Rights to Use the Licensed Materials solely for 75 | the Permitted Purpose. 76 | For clarity, and without limiting the generality of the foregoing, You may not, in any and all fields: 77 | (a) sell, lease, rent, lend, license, sublicense, assign, distribute, share, publish, transfer, or 78 | otherwise make available the Licensed Materials or Derivative Technology to any individual or entity for 79 | monetary compensation; 80 | (b) Use the Licensed Materials or Derivative Technology, in each case, to initiate or conduct, 81 | either for Yourself, Your affiliates, or a third party, a program directed to the research, development, 82 | manufacture, commercialization, or exploitation of any product (including any pharmaceutical, biologic, 83 | or diagnostic product) or service that is, or if successful ultimately would be, intended for commercial 84 | sale, distribution, or offering, including validating a biological target in connection with the foregoing 85 | activities (collectively, a “Commercial Program”); 86 | (c) Use the Licensed Materials or Derivative Technology, in each case, to directly or 87 | indirectly research, develop, commercialize, or exploit any software, model, algorithm, platform, or 88 | artificial intelligence (collectively, “AI Models”) that is, or if successful ultimately would be, intended for 89 | commercial sale, distribution, or offering; 90 | 91 | 92 | 93 | 94 | 95 | (d) deploy any AI Model trained on the Licensed Materials or Derivative Technology, in each 96 | case, for the purpose of initiating or conducting, either for Yourself, Your affiliates, or a third party, any 97 | Commercial Program; 98 | (e) Use the Licensed Materials or Derivative Technology (including any AI Model trained on 99 | Licensed Materials), in each case, for the sale, offer for sale, or performance of commercial services; 100 | (f) engage in, or advise in the engaging of, any trading of securities using or based on the 101 | Licensed Materials or Derivative Technology; or 102 | (g) publish any article or other document, or deliver any presentation for monetary 103 | compensation that is based on Your Use of the Licensed Materials or Derivative Technology (for clarity, 104 | this sub-clause (g) will not prohibit You from publishing or presenting any article, document, or 105 | presentation that You author or present Yourself in any medium or format so long as You do not directly 106 | or indirectly receive any monetary compensation for such publication or presentation). 107 | If You wish to Use the Licensed Materials or Derivative Technology for any purpose not permitted 108 | by this Agreement, please contact Us to discuss such Use – a commercial license may be available. Any 109 | such commercial Use by You (to the extent approved by Us) will be subject to separate commercial 110 | licensing terms, and We will retain sole discretion whether or not to agree to any such Use and grant 111 | such license (including the applicable terms thereof). 112 | 4. ATTRIBUTION REQUIREMENTS. 
113 | You must include an attribution to Us in the applicable form set forth below when citing any 114 | Recursion Dataset or any Recursion Software constituting an AI Model: 115 | For any Recursion Dataset: “We used the [insert the name of the dataset (e.g., RxRx3)] dataset, 116 | available from Recursion Pharmaceuticals at www.rxrx.ai, pursuant to Recursion 117 | Pharmaceutical’s licensing terms at [insert hyperlink to this Agreement]. Under this license, 118 | Recursion Pharmaceuticals disclaims all representations and warranties with respect to such 119 | dataset.” 120 | For any Recursion Software constituting an AI Model: “We used the [insert the name of AI 121 | Model] AI model, available from Recursion Pharmaceuticals at www.rxrx.ai, pursuant to 122 | Recursion Pharmaceutical’s licensing terms at [insert hyperlink to this Agreement]. Under this 123 | license, Recursion Pharmaceuticals disclaims all representations and warranties with respect to 124 | such AI model.” 125 | You should insert the information specified in brackets above, and delete such brackets, when including 126 | such attribution. 127 | In addition, You must indicate whether You modified the applicable Licensed Material, or 128 | otherwise used any Licensed Material to create any Derivative Technology, and if so, indicate that such 129 | Derivative Technology was created using such Licensed Material, and retain any indication of the 130 | foregoing previously made by other individuals or entities. 131 | If We request, You must remove any of the information required above to the extent reasonably 132 | practicable. Nothing in this Agreement constitutes or may be construed as permission to assert or imply 133 | that You are, or that Your Use of the Licensed Materials or Derivative Technology is, connected with, or 134 | sponsored, endorsed, or granted official status by, Us. 135 | 5. 
ACCEPTABLE USE TERMS 136 | You will not, and will not permit or encourage any other individual or entity to: 137 | 138 | 139 | 140 | (a) reverse engineer, disassemble, decompile, decode, adapt, or otherwise attempt to 141 | derive, recreate, or gain access to the source code of the Recursion Software, in whole or in part; 142 | (b) except as expressly permitted by Section 3 (License Grant), modify, adapt, or create 143 | derivative works or improvements of the Licensed Materials; 144 | (c) except as expressly permitted by Section 3 (License Grant), sell, lease, rent, lend, license, 145 | sublicense, assign, distribute, share, publish, transfer, or otherwise make available the Licensed Materials 146 | to any individual or entity; 147 | (d) Use the Licensed Materials in any manner or for any purpose that infringes, 148 | misappropriates, or otherwise violates any Intellectual Property Right of any individual or entity; 149 | (e) remove, delete, alter, or obscure any trademarks or any copyright, trademark, patent, or 150 | other Intellectual Property Right notices from the Licensed Materials, including any copy thereof; 151 | (f) Use the Licensed Materials to violate any national or international law, statute, decree, 152 | rule, or regulation; 153 | (g) attempt to interfere with the proper working of the Recursion Software, or remove, 154 | disable, circumvent, or otherwise create or implement any workaround to any security or technological 155 | measures for the Licensed Materials, including any measures that control access to the Licensed 156 | Materials; 157 | (h) disrupt or interfere with the Recursion Software or Our systems, servers, or networks, or 158 | fail to comply with any requirements, procedures, policies, or regulations of networks connected to the 159 | Recursion Software, or transmit any viruses, worms, defects, Trojan horses, spyware, malware, 160 | ransomware, or any items of a destructive nature through Your Use of the Recursion Software; or 161 | (i) Use the Licensed Materials in any abusive or illegal way, as determined in Our sole 162 | discretion. 163 | 6. INTELLECTUAL PROPERTY RIGHTS. 164 | You acknowledge and agree that the Licensed Materials are provided under license, and not 165 | sold, to You. You acknowledge and agree that the Licensed Intellectual Property Rights are proprietary to 166 | Us, and the Licensed Materials are protected under copyright and other Intellectual Property Rights 167 | owned or controlled by Us. We own and retain ownership of all Our Intellectual Property Rights, 168 | including all rights, title, and interests in and to the Licensed Materials (including any portion thereof 169 | that may be incorporated into any Derivative Technology). Under applicable law, Your separate 170 | contribution to any Derivative Technology may be subject to Intellectual Property Rights owned or 171 | controlled by You (“Arising Intellectual Property Rights”). 172 | All rights not expressly granted to You herein are reserved for Us. Except for the limited license 173 | granted to You herein, this Agreement does not grant You any ownership or other rights or interests in or 174 | to the Licensed Materials or Licensed Intellectual Property Rights, whether by implication, estoppel, or 175 | otherwise. 176 | 7. SHARING LICENSED MATERIALS. 
177 | Every individual or entity with whom You share the Recursion Dataset (including any portion of 178 | the Recursion Dataset incorporated into any Derivative Technology) automatically receives an offer from 179 | Us to Use such Recursion Dataset or portion thereof, as applicable, under the terms of this Agreement. 180 | If You share any Derivative Technology with any individual(s) or entity(ies), then the license You 181 | apply to Your Arising Intellectual Property Rights in such Derivative Technology must be essentially the 182 | 183 | 184 | 185 | equivalent of this Agreement, and for the avoidance of doubt, must not permit any Use of such 186 | Derivative Technology for any purpose other than a Permitted Purpose. 187 | If You share any Recursion Dataset or Derivative Technology, You may not offer or impose on any 188 | recipient of the Recursion Dataset or Derivative Technology any additional or different terms or 189 | conditions, or apply any technological measures to, the recipient’s use of the Recursion Dataset or 190 | Derivative Technology if doing so restricts such recipient from Using the Recursion Dataset or Derivative 191 | Technology to the same extent as is permitted under this Agreement. 192 | 8. UPDATES. 193 | We will have no obligation to provide upgrades or updates to the Licensed Materials. You 194 | acknowledge that You may be required on a periodic or as-needed basis to apply updates to or 195 | re-download and re-install the Recursion Software to address security, interoperability, or performance 196 | issues, or to incorporate new features. You will promptly apply such updates to, or download and install, 197 | as applicable, all such updates or upgrades, and acknowledge and agree that the Licensed Materials or 198 | portions thereof may not properly operate should You fail to do so. We may also modify or delete in 199 | their entirety certain features and functionality of the Licensed Materials, and You agree that We have 200 | no obligation to continue to provide the Licensed Materials or enable any particular features or 201 | functionality thereof. 202 | 9. THIRD-PARTY MATERIALS. 203 | The Licensed Materials may display, include, or make available third-party content and 204 | functionality (including data, information, applications, and other products, services, or materials), or 205 | provide links to third-party websites or services (“Third-Party Materials”). You acknowledge and agree 206 | that We are not responsible for Third-Party Materials, including their accuracy, completeness, timeliness, 207 | validity, copyright compliance, legality, decency, quality, or any other aspect thereof. We do not assume 208 | and will not have any liability or responsibility to You or any other individual or entity for any Third-Party 209 | Materials. Third-Party Materials and links thereto are provided solely as a convenience to You, and You 210 | will access and use them entirely at Your own risk and subject to such third party’s terms and conditions. 211 | 10. PRIVACY POLICY. 212 | You acknowledge that when You Use any of the Licensed Materials, We may use automatic 213 | means (including, for example, cookies and web beacons) to collect information about Your electronic 214 | device and about Your use of the Licensed Materials. You also may be required to provide certain 215 | information about Yourself as a condition to Using the Licensed Materials, or certain of their features or 216 | functionality. 
All information We collect through or in connection with the Licensed Materials is subject 217 | to Our Privacy Policy at https://www.recursion.com/privacy-notice (the “Privacy Policy”), which is 218 | incorporated herein by reference. By Using the Licensed Materials, You consent to all actions taken by Us 219 | with respect to Your information in compliance with the Privacy Policy. 220 | 11. TERM AND TERMINATION. 221 | The term of this Agreement (“Term”) commences when You download the Recursion Software or 222 | otherwise Use any Licensed Materials, and continues for the term of the Licensed Intellectual Property 223 | Rights unless otherwise earlier terminated. 224 | Your rights under this Agreement terminate automatically if You fail to comply with this 225 | Agreement. Where Your right to Use the Licensed Materials has terminated as provided in the 226 | immediately preceding sentence, Your right reinstates (a) automatically as of the date the violation is 227 | cured, provided it is cured within 30 days of Your discovery of the violation, or (b) upon express 228 | 229 | 230 | 231 | 232 | reinstatement by Us. However, this paragraph does not affect any right that We may have to seek 233 | remedies for Your violation of this Agreement. 234 | For the avoidance of doubt, We may also offer the Licensed Material under separate terms or 235 | conditions, or stop distributing or making the Licensed Materials available at any time; however, doing so 236 | will not terminate this Agreement. 237 | Upon termination of this Agreement: (i) all licenses and other rights granted to You under this 238 | Agreement will terminate; (ii) You will immediately cease all use of the Licensed Materials, and will 239 | delete or otherwise destroy, at Your cost, all Licensed Materials (including, for clarity, all copies thereof), 240 | provided that You may continue practicing Your Arising Intellectual Property Rights in any Derivative 241 | Technology so long as You do not Use the Licensed Materials (including any portion thereof incorporated 242 | into the Derivative Technology); and (iii) the provisions of this Agreement which by their nature must 243 | survive termination of this Agreement will continue in force upon any termination, including, but not 244 | limited to, Your obligations relating to Intellectual Property Rights, disclaimer of warranties, limitation of 245 | liability, effects of termination, and the general provisions. 246 | 12. DISCLAIMER OF WARRANTIES. 247 | THE LICENSED MATERIALS, INCLUDING ANY THIRD-PARTY MATERIALS PROVIDED THEREIN, ARE 248 | BEING PROVIDED “AS IS,” WITH ALL FAULTS AND DEFECTS, AND WITHOUT REPRESENTATIONS OR 249 | WARRANTIES OF ANY KIND. TO THE MAXIMUM EXTENT PERMITTED UNDER APPLICABLE LAW, WE, ON 250 | OUR OWN BEHALF AND ON BEHALF OF OUR AFFILIATES AND OUR AND THEIR RESPECTIVE DIRECTORS, 251 | OFFICERS, EMPLOYEES, PARTNERS, LICENSORS, AGENTS, SUCCESSORS, AND ASSIGNS, EXPRESSLY 252 | DISCLAIM ALL REPRESENTATIONS AND WARRANTIES, WHETHER EXPRESS, IMPLIED, STATUTORY, OR 253 | OTHERWISE, WITH RESPECT TO THE LICENSED MATERIALS, INCLUDING ALL IMPLIED WARRANTIES OF 254 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, AND NON-INFRINGEMENT, AND 255 | WARRANTIES THAT MAY ARISE OUT OF COURSE OF DEALING, COURSE OF PERFORMANCE, USAGE, OR 256 | TRADE PRACTICE. 
WITHOUT LIMITATION TO THE FOREGOING, WE PROVIDE NO WARRANTY OR 257 | UNDERTAKING, AND MAKE NO REPRESENTATION OF ANY KIND THAT THE LICENSED MATERIALS WILL 258 | MEET YOUR REQUIREMENTS, ACHIEVE ANY INTENDED RESULTS, BE COMPATIBLE, OR WORK WITH ANY 259 | OTHER SOFTWARE, APPLICATIONS, SYSTEMS, OR SERVICES, OPERATE WITHOUT INTERRUPTION, MEET 260 | ANY PERFORMANCE OR RELIABILITY STANDARDS, OR BE ERROR-FREE, OR THAT ANY ERRORS OR DEFECTS 261 | CAN OR WILL BE CORRECTED. WE DO NOT ENDORSE OR REPRESENT OR GUARANTEE THE 262 | TRUTHFULNESS, ACCURACY, OR RELIABILITY OF ANY LICENSED MATERIALS. YOU ACCEPT THE ENTIRE 263 | RISK OF THE ACCURACY, RELIABILITY, SECURITY, OR OTHER PERFORMANCE WITH RESPECT TO YOUR USE 264 | OF THE LICENSED MATERIALS OR OTHER EXERCISE OF YOUR RIGHTS UNDER THIS AGREEMENT, 265 | INCLUDING YOUR DEVELOPMENT OR USE OF ANY DERIVATIVE TECHNOLOGY. THIS DISCLAIMER OF 266 | WARRANTIES WILL BE INTERPRETED IN A MANNER THAT, TO THE FULLEST EXTENT PERMITTED BY 267 | APPLICABLE LAW, MOST CLOSELY APPROXIMATES AN ABSOLUTE DISCLAIMER OF WARRANTIES. 268 | 13. LIMITATION OF LIABILITY. 269 | TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT WILL WE OR ANY OF 270 | OUR AFFILIATES, OR ANY OF OUR OR THEIR RESPECTIVE DIRECTORS, OFFICERS, EMPLOYEES, PARTNERS, 271 | LICENSORS, AGENTS, SUCCESSORS, OR ASSIGNS, HAVE ANY LIABILITY FOR ANY DIRECT, SPECIAL, 272 | INDIRECT, INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, COSTS, EXPENSES, 273 | OR DAMAGES ARISING OUT OF THIS AGREEMENT OR YOUR USE OF THE LICENSED MATERIALS OR OTHER 274 | EXERCISE OF YOUR RIGHTS UNDER THIS AGREEMENT, INCLUDING YOUR DEVELOPMENT OF ANY 275 | DERIVATIVE TECHNOLOGY. THE FOREGOING LIMITATION WILL APPLY WHETHER SUCH LOSSES, COSTS, 276 | EXPENSES, OR DAMAGES ARISE OUT OF BREACH OF CONTRACT, TORT (INCLUDING NEGLIGENCE), OR 277 | 278 | 279 | 280 | 281 | OTHERWISE AND REGARDLESS OF WHETHER SUCH DAMAGES WERE FORESEEABLE OR WE WERE 282 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THIS LIMITATION OF LIABILITY WILL BE INTERPRETED 283 | IN A MANNER THAT, TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, MOST CLOSELY 284 | APPROXIMATES AN ABSOLUTE WAIVER OF ALL LIABILITY. 285 | 14. INDEMNIFICATION. 286 | You agree to indemnify, defend, and hold harmless Us, Our affiliates, and Our and their 287 | respective officers, directors, employees, partners, licensors, agents, successors, and assigns from and 288 | against any and all losses, damages, liabilities, deficiencies, claims, actions, judgments, settlements, 289 | interest, awards, penalties, fines, costs, or expenses of whatever kind, including reasonable attorneys’ 290 | fees, arising from or relating to Your use of the Licensed Materials or other exercise of Your rights under 291 | this Agreement (including Your development of any Derivative Technology), Your access to or use of any 292 | Third-Party Material, Your breach of any term of this Agreement, or Your violation of any law or right of a 293 | third party (including any Intellectual Property Rights of a third party). 294 | 15. GENERAL PROVISIONS. 295 | US Government Rights. The Licensed Materials include commercial computer software, as such 296 | term is defined in 48 C.F.R. §2.101. Accordingly, if You are an agency of the US Government or any 297 | contractor therefore, You receive only those rights with respect to the Licensed Materials as are granted 298 | to all other end users under license, in accordance with (a) 48 C.F.R. §227.7201 through 48 C.F.R. 
299 | §227.7204, with respect to the Department of Defense and their contractors, or (b) 48 C.F.R. §12.212, 300 | with respect to all other US Government licensees and their contractors. 301 | Export Regulation. The Licensed Material or Derivative Technology may be subject to US export 302 | control laws, including the Export Control Reform Act and its associated regulations. You will not, 303 | directly or indirectly, export, re-export, or release the Licensed Material or any Derivative Technology to, 304 | or make the Licensed Material or any Derivative Technology accessible from, any jurisdiction or country 305 | to which export, re-export, or release is prohibited by law, rule, or regulation. You will comply with all 306 | applicable federal laws, regulations, and rules, and complete all required undertakings (including 307 | obtaining any necessary export license or other governmental approval), prior to exporting, re-exporting, 308 | releasing, or otherwise making the Licensed Material or any Derivative Technology available outside the 309 | United States. 310 | Assignment. You may not assign this Agreement or any of your rights or obligations hereunder 311 | without Our prior written consent and any attempt to do so without such consent will cause this 312 | Agreement and any of Your rights hereunder to be null and void. We may assign this Agreement or any 313 | of Our rights or obligations hereunder without Your consent. 314 | Governing Law; Venue. This Agreement will be governed by and construed in accordance with 315 | the laws of the State of Utah, United States, without giving effect to any choice of law provision or rule 316 | that would cause the application of laws of any other jurisdiction and without regard to the United 317 | Nations Convention on Contracts for the International Sale of Goods. You irrevocably agree that the 318 | state and federal courts in the County of Salt Lake, Utah, United States, will have exclusive jurisdiction to 319 | settle any dispute or claim arising out of or in connection with this Agreement, submit to the jurisdiction 320 | of such courts, and consent to venue in such forum with respect to any action or proceeding that relates 321 | to this Agreement. If We are the prevailing party in any action to enforce this Agreement, then We will 322 | be entitled to recover Our reasonable costs and expenses in connection with such action, including 323 | reasonable attorneys’ fees. 324 | 325 | 326 | 327 | 328 | 329 | Equitable Relief. You acknowledge and agree that the restrictions set forth in this Agreement 330 | are reasonable and necessary to protect Our legitimate interests, and that We would not have entered 331 | into this Agreement in the absence of such restrictions, and that any breach or threatened breach by You 332 | of any provision of this Agreement will result in irreparable injury to Us, for which there will be no 333 | adequate remedy at law. In the event of any breach or threatened breach by You of any provision of this 334 | Agreement, We will be authorized and entitled to obtain from any court of competent jurisdiction 335 | injunctive relief, whether preliminary or permanent, specific performance, and an equitable accounting 336 | of all earnings, profits, and other benefits arising from such breach, which rights will be cumulative and 337 | in addition to any other rights or remedies to which We may be entitled at law or in equity. 
You waive 338 | any requirement that We post a bond or other security as a condition for obtaining any such relief, or 339 | show irreparable harm, balancing of harms, consideration of the public interest, or inadequacy of 340 | monetary damages as a remedy. 341 | Section Titles. The section titles and headers are for convenience or reference only and in no 342 | way define, limit, or affect the scope or substance of any section of this Agreement. 343 | Entire Agreement. Other than the Privacy Policy and any commercial agreement that You have 344 | executed with Us in relation to the Licensed Materials, this Agreement constitutes the entire agreement 345 | between You and Us with respect to the Licensed Materials. 346 | Severability. If any provision of this Agreement is held to be unenforceable for any reason, then 347 | such provision will be reformed only to the extent necessary to make it enforceable, and such holding 348 | will not impair the validity, legality, or enforceability of the remaining provisions. 349 | Waiver. No delay or omission by Us in exercising any right under this Agreement will operate as 350 | a waiver of that or any other right. A waiver or consent given by Us on any one occasion will be effective 351 | only in that instance and will not be construed as a bar or waiver of any right on any other occasion. 352 | English Language. This Agreement is in the English language only, which language will be 353 | controlling and any revision of this Agreement in any other language will not be binding. 354 | Questions, Comments, and Concerns. All requests for technical support, and other 355 | communications relating to the Licensed Materials or the subject matter of this Agreement, including 356 | questions, inquiries, and concerns, should be directed to info@rxrx.ai. 357 | 358 | 359 | 360 | -------------------------------------------------------------------------------- /MODELCARD.md: -------------------------------------------------------------------------------- 1 | --- 2 | library_name: transformers 3 | tags: [] 4 | --- 5 | 6 | # Model Card for Phenom CA-MAE-S/16 7 | 8 | Channel-agnostic image encoding model designed for microscopy image featurization. 9 | The model uses a vision transformer backbone with channelwise cross-attention over patch tokens to create contextualized representations separately for each channel. 10 | 11 | 12 | ## Model Details 13 | 14 | ### Model Description 15 | 16 | This model is a [channel-agnostic masked autoencoder](https://openaccess.thecvf.com/content/CVPR2024/html/Kraus_Masked_Autoencoders_for_Microscopy_are_Scalable_Learners_of_Cellular_Biology_CVPR_2024_paper.html) trained to reconstruct microscopy images over three datasets: 17 | 1. RxRx3 18 | 2. JUMP-CP overexpression 19 | 3. 
JUMP-CP gene-knockouts 20 | 21 | - **Developed, funded, and shared by:** Recursion 22 | - **Model type:** Vision transformer CA-MAE 23 | - **Image modality:** Optimized for microscopy images from the CellPainting assay 24 | - **License:** 25 | 26 | 27 | ### Model Sources 28 | 29 | - **Repository:** [https://github.com/recursionpharma/maes_microscopy](https://github.com/recursionpharma/maes_microscopy) 30 | - **Paper:** [Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology](https://openaccess.thecvf.com/content/CVPR2024/html/Kraus_Masked_Autoencoders_for_Microscopy_are_Scalable_Learners_of_Cellular_Biology_CVPR_2024_paper.html) 31 | 32 | 33 | ## Uses 34 | 35 | NOTE: model embeddings tend to become useful features only after standard batch-correction post-processing. **We recommend**, at a *minimum*, applying the standard `PCA-CenterScale` pattern to the embeddings after running inference over your images, or better yet Typical Variation Normalization: 36 | 37 | 1. Fit a PCA kernel on all the *control images* (or all images if no controls) from across all experimental batches (e.g. the plates of wells from your assay), 38 | 2. Transform all the embeddings with that PCA kernel, 39 | 3. For each experimental batch, fit a separate StandardScaler on the transformed embeddings of the controls from step 2, then transform the rest of the embeddings from that batch with that StandardScaler. 40 | 41 | ### Direct Use 42 | 43 | - Create biologically useful embeddings of microscopy images 44 | - Create contextualized embeddings of each channel of a microscopy image (set `return_channelwise_embeddings=True`) 45 | - Leverage the full MAE encoder + decoder to predict new channels / stains for images without all 6 CellPainting channels 46 | 47 | ### Downstream Use 48 | 49 | - A determined ML expert could fine-tune the encoder for downstream tasks such as classification 50 | 51 | ### Out-of-Scope Use 52 | 53 | - Unlikely to be especially performant on brightfield microscopy images 54 | - Out-of-domain medical images, such as H&E (though it might be a decent baseline) 55 | 56 | ## Bias, Risks, and Limitations 57 | 58 | - The primary limitation is that the embeddings tend to be more useful at scale. For example, if you only have 1 plate of microscopy images, the embeddings might underperform compared to a supervised bespoke model. 59 | 60 | ## How to Get Started with the Model 61 | 62 | You should be able to successfully run the tests below, which demonstrate how to use the model at inference time. 63 | 64 | ```python 65 | import pytest 66 | import torch 67 | 68 | from huggingface_mae import MAEModel 69 | 70 | huggingface_phenombeta_model_dir = "." 71 | # huggingface_modelpath = "recursionpharma/test-pb-model" 72 | 73 | 74 | @pytest.fixture 75 | def huggingface_model(): 76 | # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/test-pb-model to this directory 77 | # huggingface-cli download recursionpharma/test-pb-model --local-dir=.
78 | huggingface_model = MAEModel.from_pretrained(huggingface_phenombeta_model_dir) 79 | huggingface_model.eval() 80 | return huggingface_model 81 | 82 | 83 | @pytest.mark.parametrize("C", [1, 4, 6, 11]) 84 | @pytest.mark.parametrize("return_channelwise_embeddings", [True, False]) 85 | def test_model_predict(huggingface_model, C, return_channelwise_embeddings): 86 | example_input_array = torch.randint( 87 | low=0, 88 | high=255, 89 | size=(2, C, 256, 256), 90 | dtype=torch.uint8, 91 | device=huggingface_model.device, 92 | ) 93 | huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings 94 | embeddings = huggingface_model.predict(example_input_array) 95 | expected_output_dim = 384 * C if return_channelwise_embeddings else 384 96 | assert embeddings.shape == (2, expected_output_dim) 97 | ``` 98 | 99 | 100 | ## Training, evaluation and testing details 101 | 102 | See paper linked above for details on model training and evaluation. Primary hyperparameters are included in the repo linked above. 103 | 104 | 105 | ## Environmental Impact 106 | 107 | - **Hardware Type:** Nvidia H100 Hopper nodes 108 | - **Hours used:** 400 109 | - **Cloud Provider:** private cloud 110 | - **Carbon Emitted:** 138.24 kg co2 (roughly the equivalent of one car driving from Toronto to Montreal) 111 | 112 | **BibTeX:** 113 | 114 | ```TeX 115 | @inproceedings{kraus2024masked, 116 | title={Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology}, 117 | author={Kraus, Oren and Kenyon-Dean, Kian and Saberian, Saber and Fallah, Maryam and McLean, Peter and Leung, Jess and Sharma, Vasudev and Khan, Ayla and Balakrishnan, Jia and Celik, Safiye and others}, 118 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 119 | pages={11757--11768}, 120 | year={2024} 121 | } 122 | ``` 123 | 124 | ## Model Card Contact 125 | 126 | - Kian Kenyon-Dean: kian.kd@recursion.com 127 | - Oren Kraus: oren.kraus@recursion.com 128 | - Or, email: info@rxrx.ai 129 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![scorecard-score](https://github.com/recursionpharma/octo-guard-badges/blob/trunk/badges/repo/maes_microscopy/maturity_score.svg?raw=true)](https://infosec-docs.prod.rxrx.io/octoguard/scorecards/maes_microscopy) 2 | [![scorecard-status](https://github.com/recursionpharma/octo-guard-badges/blob/trunk/badges/repo/maes_microscopy/scorecard_status.svg?raw=true)](https://infosec-docs.prod.rxrx.io/octoguard/scorecards/maes_microscopy) 3 | # Masked Autoencoders are Scalable Learners of Cellular Morphology 4 | Official repo for Recursion's two recently accepted papers: 5 | - Spotlight full-length paper at [CVPR 2024](https://cvpr.thecvf.com/Conferences/2024/AcceptedPapers) -- Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology 6 | - Paper: https://arxiv.org/abs/2404.10242 7 | - CVPR poster page with video: https://cvpr.thecvf.com/virtual/2024/poster/31565 8 | - Spotlight workshop paper at [NeurIPS 2023 Generative AI & Biology workshop](https://openreview.net/group?id=NeurIPS.cc/2023/Workshop/GenBio) 9 | - Paper: https://arxiv.org/abs/2309.16064 10 | 11 | ![vit_diff_mask_ratios](https://github.com/recursionpharma/maes_microscopy/assets/109550980/c15f46b1-cdb9-41a7-a4af-bdc9684a971d) 12 | 13 | 14 | ## Provided code 15 | See the repo for ingredients required for defining our MAEs. 
Users seeking to re-implement training will need to stitch together the Encoder and Decoder modules according to their use case. 16 | 17 | Furthermore, the baseline Vision Transformer architecture backbone used in this work can be built with the following code snippet from Timm: 18 | ```python 19 | import timm.models.vision_transformer as vit 20 | 21 | def vit_base_patch16_256(**kwargs): 22 | default_kwargs = dict( 23 | img_size=256, 24 | in_chans=6, 25 | num_classes=0, 26 | fc_norm=None, 27 | class_token=True, 28 | drop_path_rate=0.1, 29 | init_values=0.0001, 30 | block_fn=vit.ParallelScalingBlock, 31 | qkv_bias=False, 32 | qk_norm=True, 33 | ) 34 | for k, v in kwargs.items(): 35 | default_kwargs[k] = v 36 | return vit.vit_base_patch16_224(**default_kwargs) 37 | ``` 38 | 39 | ## Provided models 40 | A publicly available model for research that handles inference and auto-scaling can be found at: https://www.rxrx.ai/phenom 41 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_attn_implementation_autoset": true, 3 | "apply_loss_unmasked": false, 4 | "architectures": [ 5 | "MAEModel" 6 | ], 7 | "crop_size": -1, 8 | "decoder": { 9 | "_target_": "mae_modules.CAMAEDecoder", 10 | "depth": 8, 11 | "embed_dim": 512, 12 | "mlp_ratio": 4, 13 | "norm_layer": { 14 | "_partial_": true, 15 | "_target_": "torch.nn.LayerNorm", 16 | "eps": 1e-06 17 | }, 18 | "num_heads": 16, 19 | "num_modalities": 6, 20 | "qkv_bias": true, 21 | "tokens_per_modality": 256 22 | }, 23 | "encoder": { 24 | "_target_": "mae_modules.MAEEncoder", 25 | "channel_agnostic": true, 26 | "max_in_chans": 11, 27 | "vit_backbone": { 28 | "_target_": "vit.sincos_positional_encoding_vit", 29 | "vit_backbone": { 30 | "_target_": "vit.vit_small_patch16_256", 31 | "global_pool": "avg" 32 | } 33 | } 34 | }, 35 | "fourier_loss": { 36 | "_target_": "loss.FourierLoss", 37 | "num_multimodal_modalities": 6 38 | }, 39 | "fourier_loss_weight": 0.0, 40 | "input_norm": { 41 | "_args_": [ 42 | { 43 | "_target_": "normalizer.Normalizer" 44 | }, 45 | { 46 | "_target_": "torch.nn.InstanceNorm2d", 47 | "affine": false, 48 | "num_features": null, 49 | "track_running_stats": false 50 | } 51 | ], 52 | "_target_": "torch.nn.Sequential" 53 | }, 54 | "layernorm_unfreeze": true, 55 | "loss": { 56 | "_target_": "torch.nn.MSELoss", 57 | "reduction": "none" 58 | }, 59 | "lr_scheduler": { 60 | "_partial_": true, 61 | "_target_": "torch.optim.lr_scheduler.OneCycleLR", 62 | "anneal_strategy": "cos", 63 | "max_lr": 0.0001, 64 | "pct_start": 0.1 65 | }, 66 | "mask_fourier_loss": true, 67 | "mask_ratio": 0.0, 68 | "model_type": "MAE", 69 | "norm_pix_loss": false, 70 | "num_blocks_to_freeze": 0, 71 | "optimizer": { 72 | "_partial_": true, 73 | "_target_": "timm.optim.lion.Lion", 74 | "betas": [ 75 | 0.9, 76 | 0.95 77 | ], 78 | "lr": 0.0001, 79 | "weight_decay": 0.05 80 | }, 81 | "torch_dtype": "float32", 82 | "transformers_version": "4.46.1", 83 | "trim_encoder_blocks": null, 84 | "use_MAE_weight_init": false 85 | } 86 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | # © Recursion Pharmaceuticals 2024 2 | loss: 3 | _target_: torch.nn.MSELoss # combine with fourier loss weighted at 0.01 mixing factor for best results 4 | reduction: none 5 | optimizer: 6 | _target_: timm.optim.lion.Lion 7 | _partial_: true 8 | lr: &lr 1e-4 # 1e-4 for <= ViT-B, and 3e-5 for ViT-L
9 | weight_decay: 0.05 10 | betas: [0.9, 0.95] 11 | lr_scheduler: 12 | _target_: torch.optim.lr_scheduler.OneCycleLR 13 | _partial_: true 14 | max_lr: *lr 15 | pct_start: 0.1 16 | anneal_strategy: cos -------------------------------------------------------------------------------- /huggingface_mae.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from transformers import PretrainedConfig, PreTrainedModel 7 | 8 | from loss import FourierLoss 9 | from normalizer import Normalizer 10 | from mae_modules import CAMAEDecoder, MAEDecoder, MAEEncoder 11 | from mae_utils import flatten_images 12 | from vit import ( 13 | generate_2d_sincos_pos_embeddings, 14 | sincos_positional_encoding_vit, 15 | vit_small_patch16_256, 16 | ) 17 | 18 | TensorDict = Dict[str, torch.Tensor] 19 | 20 | 21 | class MAEConfig(PretrainedConfig): 22 | model_type = "MAE" 23 | 24 | def __init__( 25 | self, 26 | mask_ratio=0.75, 27 | encoder=None, 28 | decoder=None, 29 | loss=None, 30 | optimizer=None, 31 | input_norm=None, 32 | fourier_loss=None, 33 | fourier_loss_weight=0.0, 34 | lr_scheduler=None, 35 | use_MAE_weight_init=False, 36 | crop_size=-1, 37 | mask_fourier_loss=True, 38 | return_channelwise_embeddings=False, 39 | **kwargs, 40 | ): 41 | super().__init__(**kwargs) 42 | self.mask_ratio = mask_ratio 43 | self.encoder = encoder 44 | self.decoder = decoder 45 | self.loss = loss 46 | self.optimizer = optimizer 47 | self.input_norm = input_norm 48 | self.fourier_loss = fourier_loss 49 | self.fourier_loss_weight = fourier_loss_weight 50 | self.lr_scheduler = lr_scheduler 51 | self.use_MAE_weight_init = use_MAE_weight_init 52 | self.crop_size = crop_size 53 | self.mask_fourier_loss = mask_fourier_loss 54 | self.return_channelwise_embeddings = return_channelwise_embeddings 55 | 56 | 57 | class MAEModel(PreTrainedModel): 58 | config_class = MAEConfig 59 | 60 | # Loss metrics 61 | TOTAL_LOSS = "loss" 62 | RECON_LOSS = "reconstruction_loss" 63 | FOURIER_LOSS = "fourier_loss" 64 | 65 | def __init__(self, config: MAEConfig): 66 | super().__init__(config) 67 | 68 | self.mask_ratio = config.mask_ratio 69 | 70 | # Could use Hydra to instantiate instead 71 | self.encoder = MAEEncoder( 72 | vit_backbone=sincos_positional_encoding_vit( 73 | vit_backbone=vit_small_patch16_256(global_pool="avg") 74 | ), 75 | max_in_chans=11, # upper limit on number of input channels 76 | channel_agnostic=True, 77 | ) 78 | self.decoder = CAMAEDecoder( 79 | depth=8, 80 | embed_dim=512, 81 | mlp_ratio=4, 82 | norm_layer=nn.LayerNorm, 83 | num_heads=16, 84 | num_modalities=6, 85 | qkv_bias=True, 86 | tokens_per_modality=256, 87 | ) 88 | self.input_norm = torch.nn.Sequential( 89 | Normalizer(), 90 | nn.InstanceNorm2d(None, affine=False, track_running_stats=False), 91 | ) 92 | 93 | self.fourier_loss_weight = config.fourier_loss_weight 94 | self.mask_fourier_loss = config.mask_fourier_loss 95 | self.return_channelwise_embeddings = config.return_channelwise_embeddings 96 | self.tokens_per_channel = 256 # hardcode the number of tokens per channel since we are patch16 crop 256 97 | 98 | # loss stuff 99 | self.loss = torch.nn.MSELoss(reduction="none") 100 | 101 | self.fourier_loss = FourierLoss(num_multimodal_modalities=6) 102 | if self.fourier_loss_weight > 0 and self.fourier_loss is None: 103 | raise ValueError( 104 | "FourierLoss weight is activated but no fourier_loss was defined in
constructor" 105 | ) 106 | elif self.fourier_loss_weight >= 1: 107 | raise ValueError( 108 | "FourierLoss weight is too large to do mixing factor, weight should be < 1" 109 | ) 110 | 111 | self.patch_size = int(self.encoder.vit_backbone.patch_embed.patch_size[0]) 112 | 113 | # projection layer between the encoder and decoder 114 | self.encoder_decoder_proj = nn.Linear( 115 | self.encoder.embed_dim, self.decoder.embed_dim, bias=True 116 | ) 117 | 118 | self.decoder_pred = nn.Linear( 119 | self.decoder.embed_dim, 120 | self.patch_size**2 121 | * (1 if self.encoder.channel_agnostic else self.in_chans), 122 | bias=True, 123 | ) # linear layer from decoder embedding to input dims 124 | 125 | # overwrite decoder pos embeddings based on encoder params 126 | self.decoder.pos_embeddings = generate_2d_sincos_pos_embeddings( # type: ignore[assignment] 127 | self.decoder.embed_dim, 128 | length=self.encoder.vit_backbone.patch_embed.grid_size[0], 129 | use_class_token=self.encoder.vit_backbone.cls_token is not None, 130 | num_modality=( 131 | self.decoder.num_modalities if self.encoder.channel_agnostic else 1 132 | ), 133 | ) 134 | 135 | if config.use_MAE_weight_init: 136 | w = self.encoder.vit_backbone.patch_embed.proj.weight.data 137 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 138 | 139 | torch.nn.init.normal_(self.encoder.vit_backbone.cls_token, std=0.02) 140 | torch.nn.init.normal_(self.decoder.mask_token, std=0.02) 141 | 142 | self.apply(self._MAE_init_weights) 143 | 144 | def setup(self, stage: str) -> None: 145 | super().setup(stage) 146 | 147 | def _MAE_init_weights(self, m): 148 | if isinstance(m, nn.Linear): 149 | torch.nn.init.xavier_uniform_(m.weight) 150 | if isinstance(m, nn.Linear) and m.bias is not None: 151 | nn.init.constant_(m.bias, 0) 152 | elif isinstance(m, nn.LayerNorm): 153 | nn.init.constant_(m.bias, 0) 154 | nn.init.constant_(m.weight, 1.0) 155 | 156 | @staticmethod 157 | def decode_to_reconstruction( 158 | encoder_latent: torch.Tensor, 159 | ind_restore: torch.Tensor, 160 | proj: torch.nn.Module, 161 | decoder: MAEDecoder | CAMAEDecoder, 162 | pred: torch.nn.Module, 163 | ) -> torch.Tensor: 164 | """Feed forward the encoder latent through the decoders necessary projections and transformations.""" 165 | decoder_latent_projection = proj( 166 | encoder_latent 167 | ) # projection from encoder.embed_dim to decoder.embed_dim 168 | decoder_tokens = decoder.forward_masked( 169 | decoder_latent_projection, ind_restore 170 | ) # decoder.embed_dim output 171 | predicted_reconstruction = pred( 172 | decoder_tokens 173 | ) # linear projection to input dim 174 | return predicted_reconstruction[:, 1:, :] # drop class token 175 | 176 | def forward( 177 | self, imgs: torch.Tensor, constant_noise: Union[torch.Tensor, None] = None 178 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 179 | imgs = self.input_norm(imgs) 180 | latent, mask, ind_restore = self.encoder.forward_masked( 181 | imgs, self.mask_ratio, constant_noise 182 | ) # encoder blocks 183 | reconstruction = self.decode_to_reconstruction( 184 | latent, 185 | ind_restore, 186 | self.encoder_decoder_proj, 187 | self.decoder, 188 | self.decoder_pred, 189 | ) 190 | return latent, reconstruction, mask 191 | 192 | def compute_MAE_loss( 193 | self, 194 | reconstruction: torch.Tensor, 195 | img: torch.Tensor, 196 | mask: torch.Tensor, 197 | ) -> Tuple[torch.Tensor, Dict[str, float]]: 198 | """Computes final loss and returns specific values of component losses for metric reporting.""" 199 | loss_dict = {} 200 | img = 
self.input_norm(img) 201 | target_flattened = flatten_images( 202 | img, 203 | patch_size=self.patch_size, 204 | channel_agnostic=self.encoder.channel_agnostic, 205 | ) 206 | 207 | loss: torch.Tensor = self.loss( 208 | reconstruction, target_flattened 209 | ) # should be with MSE or MAE (L1) with reduction='none' 210 | loss = loss.mean( 211 | dim=-1 212 | ) # average over embedding dim -> mean loss per patch (N,L) 213 | loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches only 214 | loss_dict[self.RECON_LOSS] = loss.item() 215 | 216 | # compute fourier loss 217 | if self.fourier_loss_weight > 0: 218 | floss: torch.Tensor = self.fourier_loss(reconstruction, target_flattened) 219 | if not self.mask_fourier_loss: 220 | floss = floss.mean() 221 | else: 222 | floss = floss.mean(dim=-1) 223 | floss = (floss * mask).sum() / mask.sum() 224 | 225 | loss_dict[self.FOURIER_LOSS] = floss.item() 226 | 227 | # here we use a mixing factor to keep the loss magnitude appropriate with fourier 228 | if self.fourier_loss_weight > 0: 229 | loss = (1 - self.fourier_loss_weight) * loss + ( 230 | self.fourier_loss_weight * floss 231 | ) 232 | return loss, loss_dict 233 | 234 | def training_step(self, batch: TensorDict, batch_idx: int) -> TensorDict: 235 | img = batch["pixels"] 236 | latent, reconstruction, mask = self(img.clone()) 237 | full_loss, loss_dict = self.compute_MAE_loss(reconstruction, img.float(), mask) 238 | return { 239 | "loss": full_loss, 240 | **loss_dict, # type: ignore[dict-item] 241 | } 242 | 243 | def validation_step(self, batch: TensorDict, batch_idx: int) -> TensorDict: 244 | return self.training_step(batch, batch_idx) 245 | 246 | def update_metrics(self, outputs: TensorDict, batch: TensorDict) -> None: 247 | self.metrics["lr"].update(value=self.lr_scheduler.get_last_lr()) 248 | for key, value in outputs.items(): 249 | if key.endswith("loss"): 250 | self.metrics[key].update(value) 251 | 252 | def on_validation_batch_end( # type: ignore[override] 253 | self, 254 | outputs: TensorDict, 255 | batch: TensorDict, 256 | batch_idx: int, 257 | dataloader_idx: int = 0, 258 | ) -> None: 259 | super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx) 260 | 261 | def predict(self, imgs: torch.Tensor) -> torch.Tensor: 262 | imgs = self.input_norm(imgs) 263 | X = self.encoder.vit_backbone.forward_features( 264 | imgs 265 | ) # 3d tensor N x num_tokens x dim 266 | if self.return_channelwise_embeddings: 267 | N, _, d = X.shape 268 | num_channels = imgs.shape[1] 269 | X_reshaped = X[:, 1:, :].view(N, num_channels, self.tokens_per_channel, d) 270 | pooled_segments = X_reshaped.mean( 271 | dim=2 272 | ) # Resulting shape: (N, num_channels, d) 273 | latent = pooled_segments.view(N, num_channels * d).contiguous() 274 | else: 275 | latent = X[:, 1:, :].mean(dim=1) # 1 + 256 * C tokens 276 | return latent 277 | 278 | def save_pretrained(self, save_directory: str, **kwargs): 279 | filename = kwargs.pop("filename", "model.safetensors") 280 | modelpath = f"{save_directory}/{filename}" 281 | self.config.save_pretrained(save_directory) 282 | torch.save({"state_dict": self.state_dict()}, modelpath) 283 | 284 | @classmethod 285 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): 286 | filename = kwargs.pop("filename", "model.safetensors") 287 | 288 | modelpath = f"{pretrained_model_name_or_path}/{filename}" 289 | config = MAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 290 | state_dict = torch.load(modelpath, map_location="cpu") 291 | 
model = cls(config) 292 | model.load_state_dict(state_dict["state_dict"]) 293 | return model 294 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | # © Recursion Pharmaceuticals 2024 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class FourierLoss(nn.Module): 7 | def __init__( 8 | self, 9 | use_l1_loss: bool = True, 10 | num_multimodal_modalities: int = 1, # set to 1 for vanilla MAE, 6 for channel-agnostic MAE 11 | ) -> None: 12 | """ 13 | Fourier transform loss is only sound when using L1 or L2 loss to compare the frequency domains 14 | between the images / their radial histograms. 15 | 16 | We will always set `reduction="none"` and enforce that the computation of any reductions from the 17 | output of this loss be managed by the model in question. 18 | """ 19 | super().__init__() 20 | self.loss = ( 21 | nn.L1Loss(reduction="none") if use_l1_loss else nn.MSELoss(reduction="none") 22 | ) 23 | self.num_modalities = num_multimodal_modalities 24 | 25 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 26 | # input = reconstructed image, target = original image 27 | # flattened images from MAE are (B, H*W, C), so, here we convert to B x C x H x W (note we assume H == W) 28 | flattened_images = len(input.shape) == len(target.shape) == 3 29 | if flattened_images: 30 | B, H_W, C = input.shape 31 | H_W = H_W // self.num_modalities 32 | four_d_shape = (B, C * self.num_modalities, int(H_W**0.5), int(H_W**0.5)) 33 | input = input.view(*four_d_shape) 34 | target = target.view(*four_d_shape) 35 | else: 36 | B, C, h, w = input.shape 37 | H_W = h * w 38 | 39 | if len(input.shape) != len(target.shape) != 4: 40 | raise ValueError( 41 | f"Invalid input shape: got {input.shape} and {target.shape}." 42 | ) 43 | 44 | fft_reconstructed = torch.fft.fft2(input) 45 | fft_original = torch.fft.fft2(target) 46 | 47 | magnitude_reconstructed = torch.abs(fft_reconstructed) 48 | magnitude_original = torch.abs(fft_original) 49 | 50 | loss_tensor: torch.Tensor = self.loss( 51 | magnitude_reconstructed, magnitude_original 52 | ) 53 | 54 | if ( 55 | flattened_images 56 | ): # then output loss should be reshaped 57 | loss_tensor = loss_tensor.reshape(B, H_W * self.num_modalities, C) 58 | 59 | return loss_tensor 60 | -------------------------------------------------------------------------------- /mae_modules.py: -------------------------------------------------------------------------------- 1 | # © Recursion Pharmaceuticals 2024 2 | from functools import partial 3 | from typing import Tuple, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | from timm.models.helpers import checkpoint_seq 8 | from timm.models.vision_transformer import Block, Mlp, VisionTransformer 9 | 10 | from masking import transformer_random_masking 11 | from vit import channel_agnostic_vit 12 | 13 | # If interested in training new MAEs, combine an encoder and decoder into a new module, and you should 14 | # leverage the flattening and unflattening utilities as needed from mae_utils.py. 15 | # Be sure to use an encoder-decoder Linear projection layer to match encoder dims with decoder dimensions. 16 | # As described in the paper, images are self-standardized at the start.
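# ----------------------------------------------------------------------------------------------------
# Illustrative sketch only -- not part of the original module. It shows one way to stitch the encoder
# and decoders below into a single MAE, mirroring the wiring used by MAEModel in huggingface_mae.py:
# an encoder-decoder Linear projection, decoder positional embeddings rebuilt from the encoder's patch
# grid, and a Linear head that predicts flattened patches (compare against mae_utils.flatten_images).
# It assumes a channel-agnostic encoder paired with CAMAEDecoder, and that
# generate_2d_sincos_pos_embeddings is importable from vit.py (as huggingface_mae.py does); treat it
# as a starting point, not the training module used in the paper.
class ExampleMAE(nn.Module):
    def __init__(
        self,
        encoder: "MAEEncoder",
        decoder: nn.Module,  # e.g. CAMAEDecoder (defined below)
        mask_ratio: float = 0.75,
    ) -> None:
        super().__init__()
        from vit import generate_2d_sincos_pos_embeddings  # same helper huggingface_mae.py imports

        self.encoder = encoder
        self.decoder = decoder
        self.mask_ratio = mask_ratio
        self.input_norm = SelfStandardize()  # images are self-standardized at the start
        # encoder-decoder Linear projection layer to match encoder dims with decoder dims
        self.proj = nn.Linear(encoder.embed_dim, decoder.embed_dim, bias=True)
        # linear head: one patch_size**2 pixel patch per (channel-agnostic) token
        patch_size = int(encoder.vit_backbone.patch_embed.patch_size[0])
        self.pred = nn.Linear(decoder.embed_dim, patch_size**2, bias=True)
        # the decoders below expect pos_embeddings to be overwritten by the combining module
        self.decoder.pos_embeddings = generate_2d_sincos_pos_embeddings(
            decoder.embed_dim,
            length=encoder.vit_backbone.patch_embed.grid_size[0],
            use_class_token=encoder.vit_backbone.cls_token is not None,
            num_modality=getattr(decoder, "num_modalities", 1),
        )

    def forward(self, imgs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        imgs = self.input_norm(imgs)
        latent, mask, ind_restore = self.encoder.forward_masked(imgs, self.mask_ratio)
        tokens = self.decoder.forward_masked(self.proj(latent), ind_restore)
        reconstruction = self.pred(tokens)[:, 1:, :]  # drop class token
        # a masked reconstruction loss would compare `reconstruction` (weighted by `mask`) against
        # flatten_images(imgs, patch_size, channel_agnostic=True) from mae_utils.py
        return reconstruction, mask
# ----------------------------------------------------------------------------------------------------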
17 | 18 | 19 | class SelfStandardize(nn.Module): 20 | def __init__(self) -> None: 21 | super().__init__() 22 | self.self_standardize = nn.LazyInstanceNorm2d( 23 | affine=False, track_running_stats=False 24 | ) 25 | 26 | def forward(self, pixels: torch.Tensor) -> torch.Tensor: 27 | x = pixels.float() / 255.0 28 | return self.self_standardize(x) 29 | 30 | 31 | class MAEEncoder(nn.Module): 32 | def __init__( 33 | self, 34 | vit_backbone: VisionTransformer, 35 | max_in_chans: int = 6, 36 | channel_agnostic: bool = False, 37 | ) -> None: 38 | super().__init__() 39 | if channel_agnostic: 40 | self.vit_backbone = channel_agnostic_vit( 41 | vit_backbone, max_in_chans=max_in_chans 42 | ) 43 | else: 44 | self.vit_backbone = vit_backbone 45 | self.max_in_chans = max_in_chans 46 | self.channel_agnostic = channel_agnostic 47 | 48 | @property 49 | def embed_dim(self) -> int: 50 | return int(self.vit_backbone.embed_dim) 51 | 52 | def forward(self, x: torch.Tensor) -> torch.Tensor: 53 | x = self.vit_backbone.forward_features(x) 54 | x = self.vit_backbone.forward_head(x) 55 | return x # type: ignore[no-any-return] 56 | 57 | def forward_masked( 58 | self, 59 | x: torch.Tensor, 60 | mask_ratio: float, 61 | constant_noise: Union[torch.Tensor, None] = None, 62 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 63 | x = self.vit_backbone.patch_embed(x) 64 | x = self.vit_backbone._pos_embed(x) # adds class token 65 | x_ = x[:, 1:, :] # no class token 66 | x_, mask, ind_restore = transformer_random_masking( 67 | x_, mask_ratio, constant_noise 68 | ) 69 | x = torch.cat([x[:, :1, :], x_], dim=1) # add class token 70 | x = self.vit_backbone.norm_pre(x) 71 | 72 | if self.vit_backbone.grad_checkpointing and not torch.jit.is_scripting(): 73 | x = checkpoint_seq(self.vit_backbone.blocks, x) 74 | else: 75 | x = self.vit_backbone.blocks(x) 76 | x = self.vit_backbone.norm(x) 77 | return x, mask, ind_restore 78 | 79 | 80 | class MAEDecoder(nn.Module): 81 | def __init__( 82 | self, 83 | embed_dim: int = 512, 84 | depth: int = 8, 85 | num_heads: int = 16, 86 | mlp_ratio: float = 4, 87 | qkv_bias: bool = True, 88 | norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), # type: ignore[assignment] 89 | ) -> None: 90 | super().__init__() 91 | self.embed_dim = embed_dim 92 | self.pos_embeddings = None # to be overwritten by MAE class 93 | self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 94 | self.blocks = nn.Sequential( 95 | *[ 96 | Block( 97 | embed_dim, 98 | num_heads, 99 | mlp_ratio, 100 | qkv_bias=qkv_bias, 101 | norm_layer=norm_layer, 102 | ) 103 | for i in range(depth) 104 | ] 105 | ) 106 | self.norm = norm_layer(embed_dim) 107 | 108 | def forward(self, x: torch.Tensor) -> torch.Tensor: 109 | x = x + self.pos_embeddings 110 | x = self.blocks(x) 111 | x = self.norm(x) 112 | return x # type: ignore[no-any-return] 113 | 114 | def forward_masked( 115 | self, x: torch.Tensor, ind_restore: torch.Tensor 116 | ) -> torch.Tensor: 117 | mask_tokens = self.mask_token.repeat( 118 | x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1 119 | ) 120 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # remove class token 121 | x_ = torch.gather( 122 | x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]) 123 | ) # unshuffle 124 | x = torch.cat([x[:, :1, :], x_], dim=1) # add class token 125 | 126 | x = x + self.pos_embeddings 127 | x = self.blocks(x) 128 | x = self.norm(x) 129 | return x # type: ignore[no-any-return] 130 | 131 | 132 | class CrossAttention(nn.Module): 133 | def __init__( 134 | self, 
embed_dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0 135 | ): 136 | super().__init__() 137 | self.num_heads = num_heads 138 | head_dim = embed_dim // num_heads 139 | self.scale = head_dim**-0.5 140 | 141 | self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) 142 | self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias) 143 | 144 | self.attn_drop = nn.Dropout(attn_drop) 145 | self.proj = nn.Linear(embed_dim, embed_dim) 146 | self.proj_drop = nn.Dropout(proj_drop) 147 | 148 | def forward(self, x, context): 149 | B, N, C = x.shape 150 | _, M, _ = context.shape 151 | 152 | q = ( 153 | self.q(x) 154 | .reshape(B, N, self.num_heads, C // self.num_heads) 155 | .permute(0, 2, 1, 3) 156 | ) 157 | kv = ( 158 | self.kv(context) 159 | .reshape(B, M, 2, self.num_heads, C // self.num_heads) 160 | .permute(2, 0, 3, 1, 4) 161 | ) 162 | k, v = kv[0], kv[1] 163 | 164 | attn = (q @ k.transpose(-2, -1)) * self.scale 165 | attn = attn.softmax(dim=-1) 166 | attn = self.attn_drop(attn) 167 | 168 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 169 | x = self.proj(x) 170 | x = self.proj_drop(x) 171 | return x 172 | 173 | 174 | class CAMAEDecoder(nn.Module): 175 | def __init__( 176 | self, 177 | num_modalities: int = 6, 178 | tokens_per_modality: int = 256, 179 | embed_dim: int = 256, 180 | depth: int = 2, 181 | num_heads: int = 16, 182 | mlp_ratio: float = 4, 183 | qkv_bias: bool = True, 184 | norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), # type: ignore[assignment] 185 | ) -> None: 186 | super().__init__() 187 | self.num_modalities = num_modalities 188 | self.tokens_per_modality = tokens_per_modality 189 | self.embed_dim = embed_dim 190 | self.pos_embeddings = None # to be overwritten by MAE class 191 | self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 192 | self.placeholder = nn.Parameter( 193 | torch.zeros(1, 1, embed_dim), requires_grad=False 194 | ) 195 | self.modality_tokens = nn.ParameterList( 196 | [ 197 | nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 198 | for modality in range(self.num_modalities) 199 | ] 200 | ) 201 | 202 | self.cross_attention = CrossAttention(embed_dim=self.embed_dim) 203 | self.mlp = Mlp(self.embed_dim, hidden_features=int(self.embed_dim * mlp_ratio)) 204 | 205 | self.decoders = nn.ModuleList( 206 | [ 207 | nn.Sequential( 208 | *[ 209 | Block( 210 | embed_dim, 211 | num_heads, 212 | mlp_ratio, 213 | qkv_bias=qkv_bias, 214 | norm_layer=norm_layer, 215 | ) 216 | for i in range(depth) 217 | ] 218 | ) 219 | for modality in range(self.num_modalities) 220 | ] 221 | ) 222 | # self.norm = norm_layer(embed_dim) # we decided to drop the last layer norm 223 | self.context_norm = norm_layer(embed_dim) 224 | self.query_norm = norm_layer(embed_dim) 225 | self.out_norm = norm_layer(embed_dim) 226 | 227 | def forward(self, x: torch.Tensor) -> torch.Tensor: 228 | x_m_s = [] 229 | 230 | modality_tokens_concat = torch.cat( 231 | [ 232 | self.placeholder, 233 | ] # placeholder for class token 234 | + [ 235 | m_t.repeat(1, self.tokens_per_modality, 1) 236 | for m_t in self.modality_tokens 237 | ], 238 | dim=1, 239 | ) 240 | 241 | x = ( 242 | x + self.pos_embeddings + modality_tokens_concat 243 | ) # add pos and tiled modality tokens 244 | x_ = x[:, 1:, :] # no class token 245 | for m, decoder in enumerate( 246 | self.decoders 247 | ): # iterate through modalities and decoders 248 | x_m = x_[ 249 | :, m * self.tokens_per_modality : (m + 1) * self.tokens_per_modality, : 250 | ] 251 | x_m = self.cross_attention(self.query_norm(x_m), self.context_norm(x_)) 
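            # queries from modality m cross-attend over the normed context of tokens from all modalities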
252 | x_m = x_m + self.mlp(self.out_norm(x_m)) 253 | x_m = decoder(x_m) 254 | x_m_s.append(x_m) 255 | x_m_s = torch.cat(x_m_s, dim=1) # concat all tokens 256 | # x_m_s = self.norm(x_m_s) # we decided to drop the last layer norm 257 | x_m_s = torch.cat([x[:, :1, :], x_m_s], dim=1) # add back class token 258 | 259 | return x_m_s 260 | 261 | def forward_masked( 262 | self, x: torch.Tensor, ind_restore: torch.Tensor 263 | ) -> torch.Tensor: 264 | mask_tokens = self.mask_token.repeat( 265 | x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1 266 | ) 267 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # remove class token 268 | x_ = torch.gather( 269 | x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]) 270 | ) # unshuffle 271 | x = torch.cat([x[:, :1, :], x_], dim=1) # add class token 272 | x = self.forward(x) 273 | return x 274 | -------------------------------------------------------------------------------- /mae_utils.py: -------------------------------------------------------------------------------- 1 | # © Recursion Pharmaceuticals 2024 2 | import math 3 | 4 | import torch 5 | 6 | 7 | def flatten_images( 8 | img: torch.Tensor, patch_size: int, channel_agnostic: bool = False 9 | ) -> torch.Tensor: 10 | """ 11 | Flattens 2D images into tokens with the same pixel values 12 | 13 | Parameters 14 | ---------- 15 | img : input image tensor (N, C, H, W) 16 | 17 | Returns 18 | ------- 19 | flattened_img: flattened image tensor (N, L, patch_size**2 * C) 20 | """ 21 | 22 | if (img.shape[2] != img.shape[3]) or (img.shape[2] % patch_size != 0): 23 | raise ValueError("image H must equal image W and be divisible by patch_size") 24 | in_chans = img.shape[1] 25 | 26 | h = w = int(img.shape[2] // patch_size) 27 | x = img.reshape(shape=(img.shape[0], in_chans, h, patch_size, w, patch_size)) 28 | 29 | if channel_agnostic: 30 | x = torch.permute(x, (0, 1, 2, 4, 3, 5)) # NCHPWQ -> NCHWPQ 31 | x = x.reshape(shape=(img.shape[0], in_chans * h * w, int(patch_size**2))) 32 | else: 33 | x = torch.permute(x, (0, 2, 4, 3, 5, 1)) # NCHPWQ -> NHWPQC 34 | x = x.reshape(shape=(img.shape[0], h * w, int(patch_size**2 * in_chans))) 35 | return x 36 | 37 | 38 | def unflatten_tokens( 39 | tokens: torch.Tensor, 40 | patch_size: int, 41 | num_modalities: int = 1, 42 | channel_agnostic: bool = False, 43 | ) -> torch.Tensor: 44 | """ 45 | Unflattens tokens (N,L,patch_size**2 * C) into image tensor (N,C,H,W) with the pixel values 46 | 47 | Parameters 48 | ---------- 49 | tokens : input token tensor (N,L,patch_size**2 * C) 50 | 51 | Returns 52 | ------- 53 | img: image tensor (N,C,H,W) 54 | """ 55 | if num_modalities > 1 and not channel_agnostic: 56 | raise ValueError("Multiple modalities requires channel agnostic unflattening.") 57 | 58 | h = w = int(math.sqrt(tokens.shape[1] // num_modalities)) 59 | if h * w != (tokens.shape[1] // num_modalities): 60 | raise ValueError("sqrt of number of tokens not integer") 61 | 62 | if channel_agnostic: 63 | x = tokens.reshape(shape=(tokens.shape[0], -1, h, w, patch_size, patch_size)) 64 | x = torch.permute(x, (0, 1, 2, 4, 3, 5)) # NCHWPQ -> NCHPWQ 65 | else: 66 | x = tokens.reshape(shape=(tokens.shape[0], h, w, patch_size, patch_size, -1)) 67 | x = torch.permute(x, (0, 5, 1, 3, 2, 4)) # NHWPQC -> NCHPWQ 68 | img = x.reshape(shape=(x.shape[0], -1, h * patch_size, h * patch_size)) 69 | 70 | return img 71 | -------------------------------------------------------------------------------- /masking.py: 
--------------------------------------------------------------------------------
1 | # © Recursion Pharmaceuticals 2024
2 | from typing import Tuple, Union
3 | 
4 | import torch
5 | 
6 | 
7 | def transformer_random_masking(
8 |     x: torch.Tensor, mask_ratio: float, constant_noise: Union[torch.Tensor, None] = None
9 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
10 |     """
11 |     Randomly mask patches per sample
12 | 
13 |     Parameters
14 |     ----------
15 |     x : token tensor (N, L, D)
16 |     mask_ratio: float - ratio of tokens to mask
17 |     constant_noise: optional tensor of shape (N, L); if provided, it is used in place of random noise to produce consistent masks
18 | 
19 |     Returns
20 |     -------
21 |     x_masked : sub-sampled version of x (N, int(L * (1 - mask_ratio)), D)
22 |     mask : binary mask indicating masked tokens (1 where masked) (N, L)
23 |     ind_restore : indices that restore the original token order, needed by the decoder
24 |     """
25 | 
26 |     N, L, D = x.shape  # batch, length, dim
27 |     len_keep = int(L * (1 - mask_ratio))
28 | 
29 |     # use random noise to generate batch based random masks
30 |     if constant_noise is not None:
31 |         noise = constant_noise
32 |     else:
33 |         noise = torch.rand(N, L, device=x.device)
34 | 
35 |     shuffled_tokens = torch.argsort(noise, dim=1)  # shuffled index
36 |     ind_restore = torch.argsort(shuffled_tokens, dim=1)  # unshuffled index
37 | 
38 |     # get masked input
39 |     tokens_to_keep = shuffled_tokens[:, :len_keep]  # keep the first len_keep indices
40 |     x_masked = torch.gather(
41 |         x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D)
42 |     )
43 | 
44 |     # get binary mask used for loss masking: 0 is keep, 1 is remove
45 |     mask = torch.ones([N, L], device=x.device)
46 |     mask[:, :len_keep] = 0
47 |     mask = torch.gather(
48 |         mask, dim=1, index=ind_restore
49 |     )  # unshuffle to get the binary mask
50 | 
51 |     return x_masked, mask, ind_restore
52 | 
--------------------------------------------------------------------------------
/normalizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | class Normalizer(torch.nn.Module):
5 |     def forward(self, pixels: torch.Tensor) -> torch.Tensor:
6 |         pixels = pixels.float()
7 |         return pixels / 255.0
8 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools >= 61.0"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [project]
6 | name = "maes_microscopy_project"
7 | version = "0.1.0"
8 | authors = [
9 |     {name = "kian-kd", email = "kian.kd@recursionpharma.com"},
10 |     {name = "Laksh47", email = "laksh.arumugam@recursionpharma.com"},
11 | ]
12 | requires-python = ">=3.10.4"
13 | 
14 | dependencies = [
15 |     "huggingface-hub",
16 |     "timm",
17 |     "torch>=2.3",
18 |     "torchmetrics",
19 |     "torchvision",
20 |     "tqdm",
21 |     "transformers",
22 |     "zarr",
23 |     "pytorch-lightning>=2.1",
24 |     "matplotlib",
25 |     "scikit-image",
26 |     "ipykernel",
27 |     "isort",
28 |     "ruff",
29 |     "pytest",
30 | ]
31 | 
32 | [tool.setuptools]
33 | py-modules = []
--------------------------------------------------------------------------------
/sample/AA41_s1_1.jp2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recursionpharma/maes_microscopy/06d7705b3c832b0a84baa14c69c71ebe8b19c81f/sample/AA41_s1_1.jp2
--------------------------------------------------------------------------------
/sample/AA41_s1_2.jp2:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/maes_microscopy/06d7705b3c832b0a84baa14c69c71ebe8b19c81f/sample/AA41_s1_2.jp2 -------------------------------------------------------------------------------- /sample/AA41_s1_3.jp2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/maes_microscopy/06d7705b3c832b0a84baa14c69c71ebe8b19c81f/sample/AA41_s1_3.jp2 -------------------------------------------------------------------------------- /sample/AA41_s1_4.jp2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/maes_microscopy/06d7705b3c832b0a84baa14c69c71ebe8b19c81f/sample/AA41_s1_4.jp2 -------------------------------------------------------------------------------- /sample/AA41_s1_5.jp2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/maes_microscopy/06d7705b3c832b0a84baa14c69c71ebe8b19c81f/sample/AA41_s1_5.jp2 -------------------------------------------------------------------------------- /sample/AA41_s1_6.jp2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/maes_microscopy/06d7705b3c832b0a84baa14c69c71ebe8b19c81f/sample/AA41_s1_6.jp2 -------------------------------------------------------------------------------- /test_huggingface_mae.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from huggingface_mae import MAEModel 5 | 6 | huggingface_openphenom_model_dir = "." 7 | # huggingface_modelpath = "recursionpharma/OpenPhenom" 8 | 9 | 10 | @pytest.fixture 11 | def huggingface_model(): 12 | # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/OpenPhenom to this directory 13 | # huggingface-cli download recursionpharma/OpenPhenom --local-dir=. 
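    # from_pretrained below then loads the downloaded config and weights from that local directory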
14 |     huggingface_model = MAEModel.from_pretrained(huggingface_openphenom_model_dir)
15 |     huggingface_model.eval()
16 |     return huggingface_model
17 | 
18 | 
19 | @pytest.mark.parametrize("C", [1, 4, 6, 11])
20 | @pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
21 | def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
22 |     example_input_array = torch.randint(
23 |         low=0,
24 |         high=255,
25 |         size=(2, C, 256, 256),
26 |         dtype=torch.uint8,
27 |         device=huggingface_model.device,
28 |     )
29 |     huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
30 |     embeddings = huggingface_model.predict(example_input_array)
31 |     expected_output_dim = 384 * C if return_channelwise_embeddings else 384
32 |     assert embeddings.shape == (2, expected_output_dim)
33 | 
--------------------------------------------------------------------------------
/vit.py:
--------------------------------------------------------------------------------
1 | # © Recursion Pharmaceuticals 2024
2 | import timm.models.vision_transformer as vit
3 | import torch
4 | 
5 | 
6 | def generate_2d_sincos_pos_embeddings(
7 |     embedding_dim: int,
8 |     length: int,
9 |     scale: float = 10000.0,
10 |     use_class_token: bool = True,
11 |     num_modality: int = 1,
12 | ) -> torch.nn.Parameter:
13 |     """
14 |     Generate 2-dimensional sin/cos positional embeddings
15 | 
16 |     Parameters
17 |     ----------
18 |     embedding_dim : int
19 |         embedding dimension used in vit
20 |     length : int
21 |         number of tokens along height or width of image after patching (assuming square)
22 |     scale : float
23 |         scale for sin/cos functions
24 |     use_class_token : bool
25 |         True - prepend a zero vector for the class_token, False - no vector added
26 |     num_modality: number of modalities (default 1 for a single modality).
27 |         The same spatial sin/cos encoding is tiled across all modalities.
28 | 
29 |     Returns
30 |     -------
31 |     positional_encoding : torch.Tensor
32 |         positional encoding to add to vit patch encodings
33 |         [1, num_modality*length*length, embedding_dim] or [1, 1+num_modality*length*length, embedding_dim]
34 |         (w/o or w/ cls_token)
35 |     """
36 | 
37 |     linear_positions = torch.arange(length, dtype=torch.float32)
38 |     height_mesh, width_mesh = torch.meshgrid(
39 |         linear_positions, linear_positions, indexing="ij"
40 |     )
41 |     positional_dim = embedding_dim // 4  # accommodate h and w x cos and sin embeddings
42 |     positional_weights = (
43 |         torch.arange(positional_dim, dtype=torch.float32) / positional_dim
44 |     )
45 |     positional_weights = 1.0 / (scale**positional_weights)
46 | 
47 |     height_weights = torch.outer(height_mesh.flatten(), positional_weights)
48 |     width_weights = torch.outer(width_mesh.flatten(), positional_weights)
49 | 
50 |     positional_encoding = torch.cat(
51 |         [
52 |             torch.sin(height_weights),
53 |             torch.cos(height_weights),
54 |             torch.sin(width_weights),
55 |             torch.cos(width_weights),
56 |         ],
57 |         dim=1,
58 |     )[None, :, :]
59 | 
60 |     # repeat positional encoding for multiple channel modalities
61 |     positional_encoding = positional_encoding.repeat(1, num_modality, 1)
62 | 
63 |     if use_class_token:
64 |         class_token = torch.zeros([1, 1, embedding_dim], dtype=torch.float32)
65 |         positional_encoding = torch.cat([class_token, positional_encoding], dim=1)
66 | 
67 |     positional_encoding = torch.nn.Parameter(positional_encoding, requires_grad=False)
68 | 
69 |     return positional_encoding
70 | 
71 | 
72 | class ChannelAgnosticPatchEmbed(vit.PatchEmbed):  # type: ignore[misc]
73 |     def __init__(
74 |         self,
75 |         img_size: int,
76 |         patch_size: int,
77 |         embed_dim: int,
78 |         bias: bool = True,
79 |     ) -> None:
80 |         super().__init__(
81 |             img_size=img_size,
82 |             patch_size=patch_size,
83 |             in_chans=1,  # in_chans is used by self.proj, which we override anyway
84 |             embed_dim=embed_dim,
85 |             norm_layer=None,
86 |             flatten=False,
87 |             bias=bias,
88 |         )
89 |         # channel-agnostic MAE has a single projection for all chans
90 |         self.proj = torch.nn.Conv2d(
91 |             1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
92 |         )
93 | 
94 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
95 |         in_chans = x.shape[1]
96 |         x = torch.stack(
97 |             [self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2
98 |         )  # the single projection is applied to every channel
99 |         x = x.flatten(2).transpose(1, 2)  # BCMHW -> BNC
100 |         return x
101 | 
102 | 
103 | class ChannelAgnosticViT(vit.VisionTransformer):  # type: ignore[misc]
104 |     def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
105 |         # rewrite https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L586
106 |         to_cat = []
107 |         if self.cls_token is not None:
108 |             to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
109 | 
110 |         # TODO: upgrade timm to get access to register tokens
111 |         # if self.vit_backbone.reg_token is not None:
112 |         #     to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
113 | 
114 |         # MAIN DIFFERENCE from timm - we DYNAMICALLY ADD POS EMBEDDINGS based on the shape of the inputs;
115 |         # this supports having CA-MAEs actually be channel-agnostic at inference time
116 |         if self.no_embed_class:
117 |             x = x + self.pos_embed[:, : x.shape[1]]
118 |             if to_cat:
119 |                 x = torch.cat(to_cat + [x], dim=1)
120 |         else:
121 |             if to_cat:
122 |                 x = torch.cat(to_cat + [x], dim=1)
123 |             x = x + self.pos_embed[:, : x.shape[1]]
124 |         return self.pos_drop(x)  # type: ignore[no-any-return]
125 | 
126 | 
127 | def
channel_agnostic_vit( 128 | vit_backbone: vit.VisionTransformer, max_in_chans: int 129 | ) -> vit.VisionTransformer: 130 | # replace patch embedding with channel-agnostic version 131 | vit_backbone.patch_embed = ChannelAgnosticPatchEmbed( 132 | img_size=vit_backbone.patch_embed.img_size[0], 133 | patch_size=vit_backbone.patch_embed.patch_size[0], 134 | embed_dim=vit_backbone.embed_dim, 135 | ) 136 | 137 | # replace positional embedding with channel-agnostic version 138 | vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings( 139 | embedding_dim=vit_backbone.embed_dim, 140 | length=vit_backbone.patch_embed.grid_size[0], 141 | use_class_token=vit_backbone.cls_token is not None, 142 | num_modality=max_in_chans, 143 | ) 144 | 145 | # change the class to be ChannelAgnostic so that it actually uses the new _pos_embed 146 | vit_backbone.__class__ = ChannelAgnosticViT 147 | return vit_backbone 148 | 149 | 150 | def sincos_positional_encoding_vit( 151 | vit_backbone: vit.VisionTransformer, scale: float = 10000.0 152 | ) -> vit.VisionTransformer: 153 | """Attaches no-grad sin-cos positional embeddings to a pre-constructed ViT backbone model. 154 | 155 | Parameters 156 | ---------- 157 | vit_backbone : timm.models.vision_transformer.VisionTransformer 158 | the constructed vision transformer from timm 159 | scale : float (default 10000.0) 160 | hyperparameter for sincos positional embeddings, recommend keeping at 10,000 161 | 162 | Returns 163 | ------- 164 | timm.models.vision_transformer.VisionTransformer 165 | the same ViT but with fixed no-grad positional encodings to add to vit patch encodings 166 | """ 167 | # length: number of tokens along height or width of image after patching (assuming square) 168 | length = ( 169 | vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0] 170 | ) 171 | pos_embeddings = generate_2d_sincos_pos_embeddings( 172 | vit_backbone.embed_dim, 173 | length=length, 174 | scale=scale, 175 | use_class_token=vit_backbone.cls_token is not None, 176 | ) 177 | # note, if the model had weight_init == 'skip', this might get overwritten 178 | vit_backbone.pos_embed = pos_embeddings 179 | return vit_backbone 180 | 181 | 182 | def vit_small_patch16_256(**kwargs): 183 | default_kwargs = dict( 184 | img_size=256, 185 | in_chans=6, 186 | num_classes=0, 187 | fc_norm=None, 188 | class_token=True, 189 | drop_path_rate=0.1, 190 | init_values=0.0001, 191 | block_fn=vit.ParallelScalingBlock, 192 | qkv_bias=False, 193 | qk_norm=True, 194 | ) 195 | for k, v in kwargs.items(): 196 | default_kwargs[k] = v 197 | return vit.vit_small_patch16_224(**default_kwargs) 198 | 199 | 200 | def vit_small_patch32_512(**kwargs): 201 | default_kwargs = dict( 202 | img_size=512, 203 | in_chans=6, 204 | num_classes=0, 205 | fc_norm=None, 206 | class_token=True, 207 | drop_path_rate=0.1, 208 | init_values=0.0001, 209 | block_fn=vit.ParallelScalingBlock, 210 | qkv_bias=False, 211 | qk_norm=True, 212 | ) 213 | for k, v in kwargs.items(): 214 | default_kwargs[k] = v 215 | return vit.vit_small_patch32_384(**default_kwargs) 216 | 217 | 218 | def vit_base_patch8_256(**kwargs): 219 | default_kwargs = dict( 220 | img_size=256, 221 | in_chans=6, 222 | num_classes=0, 223 | fc_norm=None, 224 | class_token=True, 225 | drop_path_rate=0.1, 226 | init_values=0.0001, 227 | block_fn=vit.ParallelScalingBlock, 228 | qkv_bias=False, 229 | qk_norm=True, 230 | ) 231 | for k, v in kwargs.items(): 232 | default_kwargs[k] = v 233 | return vit.vit_base_patch8_224(**default_kwargs) 234 | 235 | 236 | def 
vit_base_patch16_256(**kwargs): 237 | default_kwargs = dict( 238 | img_size=256, 239 | in_chans=6, 240 | num_classes=0, 241 | fc_norm=None, 242 | class_token=True, 243 | drop_path_rate=0.1, 244 | init_values=0.0001, 245 | block_fn=vit.ParallelScalingBlock, 246 | qkv_bias=False, 247 | qk_norm=True, 248 | ) 249 | for k, v in kwargs.items(): 250 | default_kwargs[k] = v 251 | return vit.vit_base_patch16_224(**default_kwargs) 252 | 253 | 254 | def vit_base_patch32_512(**kwargs): 255 | default_kwargs = dict( 256 | img_size=512, 257 | in_chans=6, 258 | num_classes=0, 259 | fc_norm=None, 260 | class_token=True, 261 | drop_path_rate=0.1, 262 | init_values=0.0001, 263 | block_fn=vit.ParallelScalingBlock, 264 | qkv_bias=False, 265 | qk_norm=True, 266 | ) 267 | for k, v in kwargs.items(): 268 | default_kwargs[k] = v 269 | return vit.vit_base_patch32_384(**default_kwargs) 270 | 271 | 272 | def vit_large_patch8_256(**kwargs): 273 | default_kwargs = dict( 274 | img_size=256, 275 | in_chans=6, 276 | num_classes=0, 277 | fc_norm=None, 278 | class_token=True, 279 | patch_size=8, 280 | embed_dim=1024, 281 | depth=24, 282 | num_heads=16, 283 | drop_path_rate=0.3, 284 | init_values=0.0001, 285 | block_fn=vit.ParallelScalingBlock, 286 | qkv_bias=False, 287 | qk_norm=True, 288 | ) 289 | for k, v in kwargs.items(): 290 | default_kwargs[k] = v 291 | return vit.VisionTransformer(**default_kwargs) 292 | 293 | 294 | def vit_large_patch16_256(**kwargs): 295 | default_kwargs = dict( 296 | img_size=256, 297 | in_chans=6, 298 | num_classes=0, 299 | fc_norm=None, 300 | class_token=True, 301 | drop_path_rate=0.3, 302 | init_values=0.0001, 303 | block_fn=vit.ParallelScalingBlock, 304 | qkv_bias=False, 305 | qk_norm=True, 306 | ) 307 | for k, v in kwargs.items(): 308 | default_kwargs[k] = v 309 | return vit.vit_large_patch16_384(**default_kwargs) 310 | -------------------------------------------------------------------------------- /vit_encoder.py: -------------------------------------------------------------------------------- 1 | # © Recursion Pharmaceuticals 2024 2 | from typing import Dict 3 | 4 | import timm.models.vision_transformer as vit 5 | import torch 6 | 7 | 8 | def build_imagenet_baselines() -> Dict[str, torch.jit.ScriptModule]: 9 | """This returns the prepped imagenet encoders from timm, not bad for microscopy data.""" 10 | vit_backbones = [ 11 | _make_vit(vit.vit_small_patch16_384), 12 | _make_vit(vit.vit_base_patch16_384), 13 | _make_vit(vit.vit_base_patch8_224), 14 | _make_vit(vit.vit_large_patch16_384), 15 | ] 16 | model_names = [ 17 | "vit_small_patch16_384", 18 | "vit_base_patch16_384", 19 | "vit_base_patch8_224", 20 | "vit_large_patch16_384", 21 | ] 22 | imagenet_encoders = list(map(_make_torchscripted_encoder, vit_backbones)) 23 | return {name: model for name, model in zip(model_names, imagenet_encoders)} 24 | 25 | 26 | def _make_torchscripted_encoder(vit_backbone) -> torch.jit.ScriptModule: 27 | dummy_input = torch.testing.make_tensor( 28 | (2, 6, 256, 256), 29 | low=0, 30 | high=255, 31 | dtype=torch.uint8, 32 | device=torch.device("cpu"), 33 | ) 34 | encoder = torch.nn.Sequential( 35 | Normalizer(), 36 | torch.nn.LazyInstanceNorm2d( 37 | affine=False, track_running_stats=False 38 | ), # this module performs self-standardization, very important 39 | vit_backbone, 40 | ).to(device="cpu") 41 | _ = encoder(dummy_input) # get those lazy modules built 42 | return torch.jit.freeze(torch.jit.script(encoder.eval())) 43 | 44 | 45 | def _make_vit(constructor): 46 | return constructor( 47 | pretrained=True, # 
download imagenet weights 48 | img_size=256, # 256x256 crops 49 | in_chans=6, # we expect 6-channel microscopy images 50 | num_classes=0, 51 | fc_norm=None, 52 | class_token=True, 53 | global_pool="avg", # minimal perf diff btwn "cls" and "avg" 54 | ) 55 | 56 | 57 | class Normalizer(torch.nn.Module): 58 | def forward(self, pixels: torch.Tensor) -> torch.Tensor: 59 | pixels = pixels.float() 60 | pixels /= 255.0 61 | return pixels 62 | --------------------------------------------------------------------------------
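A minimal usage sketch for the ImageNet baseline encoders above, assuming this repository's modules are importable from the working directory and that timm can download the pretrained weights; the batch shape mirrors the dummy input used in _make_torchscripted_encoder.

import torch

from vit_encoder import build_imagenet_baselines

encoders = build_imagenet_baselines()  # four TorchScript-frozen ViT encoders keyed by name
batch = torch.randint(0, 255, (2, 6, 256, 256), dtype=torch.uint8)  # 6-channel uint8 crops
with torch.no_grad():
    for name, encoder in encoders.items():
        embeddings = encoder(batch)  # pooled embeddings of shape (2, embed_dim)
        print(name, tuple(embeddings.shape))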