├── LICENSE
├── README.md
├── evaluation
│   ├── Performance_of_pseudo-cell_generation_expr.py
│   ├── Performance_of_pseudo-cell_generation_lev.py
│   ├── Performance_of_random_cell_generation.py
│   ├── README.md
│   └── performance_of_classification.py
├── figure
│   ├── example1.jpg
│   ├── example2.jpg
│   ├── example3.jpg
│   ├── example4.jpg
│   ├── intro.gif
│   ├── logo.png
│   └── overview.jpg
├── finetune.py
├── inference_batch.py
├── inference_one.py
├── inference_web.py
├── raw_data
│   └── README.md
├── vocabulary_adaptation.py
└── workflow_data
    ├── GSE117872_to_json.py
    ├── GSE149383_to_json.py
    ├── README.md
    ├── extract_gene_generation.py
    ├── merge.py
    ├── mouse_to_json.py
    ├── sentence_to_experssion.py
    ├── split.py
    ├── src
    │   ├── __init__.py
    │   ├── csdata.py
    │   ├── prompts.py
    │   └── utils.py
    └── transform.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial-ShareAlike 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58 | Public License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License
63 | ("Public License"). To the extent this Public License may be
64 | interpreted as a contract, You are granted the Licensed Rights in
65 | consideration of Your acceptance of these terms and conditions, and the
66 | Licensor grants You such rights in consideration of benefits the
67 | Licensor receives from making the Licensed Material available under
68 | these terms and conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. BY-NC-SA Compatible License means a license listed at
88 | creativecommons.org/compatiblelicenses, approved by Creative
89 | Commons as essentially the equivalent of this Public License.
90 |
91 | d. Copyright and Similar Rights means copyright and/or similar rights
92 | closely related to copyright including, without limitation,
93 | performance, broadcast, sound recording, and Sui Generis Database
94 | Rights, without regard to how the rights are labeled or
95 | categorized. For purposes of this Public License, the rights
96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
97 | Rights.
98 |
99 | e. Effective Technological Measures means those measures that, in the
100 | absence of proper authority, may not be circumvented under laws
101 | fulfilling obligations under Article 11 of the WIPO Copyright
102 | Treaty adopted on December 20, 1996, and/or similar international
103 | agreements.
104 |
105 | f. Exceptions and Limitations means fair use, fair dealing, and/or
106 | any other exception or limitation to Copyright and Similar Rights
107 | that applies to Your use of the Licensed Material.
108 |
109 | g. License Elements means the license attributes listed in the name
110 | of a Creative Commons Public License. The License Elements of this
111 | Public License are Attribution, NonCommercial, and ShareAlike.
112 |
113 | h. Licensed Material means the artistic or literary work, database,
114 | or other material to which the Licensor applied this Public
115 | License.
116 |
117 | i. Licensed Rights means the rights granted to You subject to the
118 | terms and conditions of this Public License, which are limited to
119 | all Copyright and Similar Rights that apply to Your use of the
120 | Licensed Material and that the Licensor has authority to license.
121 |
122 | j. Licensor means the individual(s) or entity(ies) granting rights
123 | under this Public License.
124 |
125 | k. NonCommercial means not primarily intended for or directed towards
126 | commercial advantage or monetary compensation. For purposes of
127 | this Public License, the exchange of the Licensed Material for
128 | other material subject to Copyright and Similar Rights by digital
129 | file-sharing or similar means is NonCommercial provided there is
130 | no payment of monetary compensation in connection with the
131 | exchange.
132 |
133 | l. Share means to provide material to the public by any means or
134 | process that requires permission under the Licensed Rights, such
135 | as reproduction, public display, public performance, distribution,
136 | dissemination, communication, or importation, and to make material
137 | available to the public including in ways that members of the
138 | public may access the material from a place and at a time
139 | individually chosen by them.
140 |
141 | m. Sui Generis Database Rights means rights other than copyright
142 | resulting from Directive 96/9/EC of the European Parliament and of
143 | the Council of 11 March 1996 on the legal protection of databases,
144 | as amended and/or succeeded, as well as other essentially
145 | equivalent rights anywhere in the world.
146 |
147 | n. You means the individual or entity exercising the Licensed Rights
148 | under this Public License. Your has a corresponding meaning.
149 |
150 |
151 | Section 2 -- Scope.
152 |
153 | a. License grant.
154 |
155 | 1. Subject to the terms and conditions of this Public License,
156 | the Licensor hereby grants You a worldwide, royalty-free,
157 | non-sublicensable, non-exclusive, irrevocable license to
158 | exercise the Licensed Rights in the Licensed Material to:
159 |
160 | a. reproduce and Share the Licensed Material, in whole or
161 | in part, for NonCommercial purposes only; and
162 |
163 | b. produce, reproduce, and Share Adapted Material for
164 | NonCommercial purposes only.
165 |
166 | 2. Exceptions and Limitations. For the avoidance of doubt, where
167 | Exceptions and Limitations apply to Your use, this Public
168 | License does not apply, and You do not need to comply with
169 | its terms and conditions.
170 |
171 | 3. Term. The term of this Public License is specified in Section
172 | 6(a).
173 |
174 | 4. Media and formats; technical modifications allowed. The
175 | Licensor authorizes You to exercise the Licensed Rights in
176 | all media and formats whether now known or hereafter created,
177 | and to make technical modifications necessary to do so. The
178 | Licensor waives and/or agrees not to assert any right or
179 | authority to forbid You from making technical modifications
180 | necessary to exercise the Licensed Rights, including
181 | technical modifications necessary to circumvent Effective
182 | Technological Measures. For purposes of this Public License,
183 | simply making modifications authorized by this Section 2(a)
184 | (4) never produces Adapted Material.
185 |
186 | 5. Downstream recipients.
187 |
188 | a. Offer from the Licensor -- Licensed Material. Every
189 | recipient of the Licensed Material automatically
190 | receives an offer from the Licensor to exercise the
191 | Licensed Rights under the terms and conditions of this
192 | Public License.
193 |
194 | b. Additional offer from the Licensor -- Adapted Material.
195 | Every recipient of Adapted Material from You
196 | automatically receives an offer from the Licensor to
197 | exercise the Licensed Rights in the Adapted Material
198 | under the conditions of the Adapter's License You apply.
199 |
200 | c. No downstream restrictions. You may not offer or impose
201 | any additional or different terms or conditions on, or
202 | apply any Effective Technological Measures to, the
203 | Licensed Material if doing so restricts exercise of the
204 | Licensed Rights by any recipient of the Licensed
205 | Material.
206 |
207 | 6. No endorsement. Nothing in this Public License constitutes or
208 | may be construed as permission to assert or imply that You
209 | are, or that Your use of the Licensed Material is, connected
210 | with, or sponsored, endorsed, or granted official status by,
211 | the Licensor or others designated to receive attribution as
212 | provided in Section 3(a)(1)(A)(i).
213 |
214 | b. Other rights.
215 |
216 | 1. Moral rights, such as the right of integrity, are not
217 | licensed under this Public License, nor are publicity,
218 | privacy, and/or other similar personality rights; however, to
219 | the extent possible, the Licensor waives and/or agrees not to
220 | assert any such rights held by the Licensor to the limited
221 | extent necessary to allow You to exercise the Licensed
222 | Rights, but not otherwise.
223 |
224 | 2. Patent and trademark rights are not licensed under this
225 | Public License.
226 |
227 | 3. To the extent possible, the Licensor waives any right to
228 | collect royalties from You for the exercise of the Licensed
229 | Rights, whether directly or through a collecting society
230 | under any voluntary or waivable statutory or compulsory
231 | licensing scheme. In all other cases the Licensor expressly
232 | reserves any right to collect such royalties, including when
233 | the Licensed Material is used other than for NonCommercial
234 | purposes.
235 |
236 |
237 | Section 3 -- License Conditions.
238 |
239 | Your exercise of the Licensed Rights is expressly made subject to the
240 | following conditions.
241 |
242 | a. Attribution.
243 |
244 | 1. If You Share the Licensed Material (including in modified
245 | form), You must:
246 |
247 | a. retain the following if it is supplied by the Licensor
248 | with the Licensed Material:
249 |
250 | i. identification of the creator(s) of the Licensed
251 | Material and any others designated to receive
252 | attribution, in any reasonable manner requested by
253 | the Licensor (including by pseudonym if
254 | designated);
255 |
256 | ii. a copyright notice;
257 |
258 | iii. a notice that refers to this Public License;
259 |
260 | iv. a notice that refers to the disclaimer of
261 | warranties;
262 |
263 | v. a URI or hyperlink to the Licensed Material to the
264 | extent reasonably practicable;
265 |
266 | b. indicate if You modified the Licensed Material and
267 | retain an indication of any previous modifications; and
268 |
269 | c. indicate the Licensed Material is licensed under this
270 | Public License, and include the text of, or the URI or
271 | hyperlink to, this Public License.
272 |
273 | 2. You may satisfy the conditions in Section 3(a)(1) in any
274 | reasonable manner based on the medium, means, and context in
275 | which You Share the Licensed Material. For example, it may be
276 | reasonable to satisfy the conditions by providing a URI or
277 | hyperlink to a resource that includes the required
278 | information.
279 | 3. If requested by the Licensor, You must remove any of the
280 | information required by Section 3(a)(1)(A) to the extent
281 | reasonably practicable.
282 |
283 | b. ShareAlike.
284 |
285 | In addition to the conditions in Section 3(a), if You Share
286 | Adapted Material You produce, the following conditions also apply.
287 |
288 | 1. The Adapter's License You apply must be a Creative Commons
289 | license with the same License Elements, this version or
290 | later, or a BY-NC-SA Compatible License.
291 |
292 | 2. You must include the text of, or the URI or hyperlink to, the
293 | Adapter's License You apply. You may satisfy this condition
294 | in any reasonable manner based on the medium, means, and
295 | context in which You Share Adapted Material.
296 |
297 | 3. You may not offer or impose any additional or different terms
298 | or conditions on, or apply any Effective Technological
299 | Measures to, Adapted Material that restrict exercise of the
300 | rights granted under the Adapter's License You apply.
301 |
302 |
303 | Section 4 -- Sui Generis Database Rights.
304 |
305 | Where the Licensed Rights include Sui Generis Database Rights that
306 | apply to Your use of the Licensed Material:
307 |
308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309 | to extract, reuse, reproduce, and Share all or a substantial
310 | portion of the contents of the database for NonCommercial purposes
311 | only;
312 |
313 | b. if You include all or a substantial portion of the database
314 | contents in a database in which You have Sui Generis Database
315 | Rights, then the database in which You have Sui Generis Database
316 | Rights (but not its individual contents) is Adapted Material,
317 | including for purposes of Section 3(b); and
318 |
319 | c. You must comply with the conditions in Section 3(a) if You Share
320 | all or a substantial portion of the contents of the database.
321 |
322 | For the avoidance of doubt, this Section 4 supplements and does not
323 | replace Your obligations under this Public License where the Licensed
324 | Rights include other Copyright and Similar Rights.
325 |
326 |
327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328 |
329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339 |
340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349 |
350 | c. The disclaimer of warranties and limitation of liability provided
351 | above shall be interpreted in a manner that, to the extent
352 | possible, most closely approximates an absolute disclaimer and
353 | waiver of all liability.
354 |
355 |
356 | Section 6 -- Term and Termination.
357 |
358 | a. This Public License applies for the term of the Copyright and
359 | Similar Rights licensed here. However, if You fail to comply with
360 | this Public License, then Your rights under this Public License
361 | terminate automatically.
362 |
363 | b. Where Your right to use the Licensed Material has terminated under
364 | Section 6(a), it reinstates:
365 |
366 | 1. automatically as of the date the violation is cured, provided
367 | it is cured within 30 days of Your discovery of the
368 | violation; or
369 |
370 | 2. upon express reinstatement by the Licensor.
371 |
372 | For the avoidance of doubt, this Section 6(b) does not affect any
373 | right the Licensor may have to seek remedies for Your violations
374 | of this Public License.
375 |
376 | c. For the avoidance of doubt, the Licensor may also offer the
377 | Licensed Material under separate terms or conditions or stop
378 | distributing the Licensed Material at any time; however, doing so
379 | will not terminate this Public License.
380 |
381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382 | License.
383 |
384 |
385 | Section 7 -- Other Terms and Conditions.
386 |
387 | a. The Licensor shall not be bound by any additional or different
388 | terms or conditions communicated by You unless expressly agreed.
389 |
390 | b. Any arrangements, understandings, or agreements regarding the
391 | Licensed Material not stated herein are separate from and
392 | independent of the terms and conditions of this Public License.
393 |
394 |
395 | Section 8 -- Interpretation.
396 |
397 | a. For the avoidance of doubt, this Public License does not, and
398 | shall not be interpreted to, reduce, limit, restrict, or impose
399 | conditions on any use of the Licensed Material that could lawfully
400 | be made without permission under this Public License.
401 |
402 | b. To the extent possible, if any provision of this Public License is
403 | deemed unenforceable, it shall be automatically reformed to the
404 | minimum extent necessary to make it enforceable. If the provision
405 | cannot be reformed, it shall be severed from this Public License
406 | without affecting the enforceability of the remaining terms and
407 | conditions.
408 |
409 | c. No term or condition of this Public License will be waived and no
410 | failure to comply consented to unless expressly agreed to by the
411 | Licensor.
412 |
413 | d. Nothing in this Public License constitutes or may be interpreted
414 | as a limitation upon, or waiver of, any privileges and immunities
415 | that apply to the Licensor or You, including from the legal
416 | processes of any jurisdiction or authority.
417 |
418 | =======================================================================
419 |
420 | Creative Commons is not a party to its public
421 | licenses. Notwithstanding, Creative Commons may elect to apply one of
422 | its public licenses to material it publishes and in those instances
423 | will be considered the “Licensor.” The text of the Creative Commons
424 | public licenses is dedicated to the public domain under the CC0 Public
425 | Domain Dedication. Except for the limited purpose of indicating that
426 | material is shared under a Creative Commons public license or as
427 | otherwise permitted by the Creative Commons policies published at
428 | creativecommons.org/policies, Creative Commons does not authorize the
429 | use of the trademark "Creative Commons" or any other trademark or logo
430 | of Creative Commons without its prior written consent including,
431 | without limitation, in connection with any unauthorized modifications
432 | to any of its public licenses or any other arrangements,
433 | understandings, or agreements concerning use of licensed material. For
434 | the avoidance of doubt, this paragraph does not form part of the
435 | public licenses.
436 |
437 | Creative Commons may be contacted at creativecommons.org.
438 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
18 |
19 | The project ChatCell aims to facilitate single-cell analysis with natural language. It builds on the [Cell2Sentence](https://github.com/vandijklab/cell2sentence-ft) technique to obtain cell language tokens and applies cell vocabulary adaptation for T5-based pre-training. Have a try with the demo at the [GPTStore App](https://chat.openai.com/g/g-vUwj222gQ-chatcell)!
20 |
21 |
22 |
23 | ## ✨ Acknowledgements
24 |
25 | Special thanks to the authors of [Cell2Sentence: Teaching Large Language Models the Language of Biology](https://github.com/vandijklab/cell2sentence-ft) and [Representing cells as sentences enables natural-language processing for single-cell transcriptomics](https://github.com/rahuldhodapkar/cell2sentence) for their inspiring work.
26 |
27 |
28 | The [`workflow_data/src`](./workflow_data/src) folder and [`transform.py`](./workflow_data/transform.py) in this project are grounded in their research. We are grateful for their valuable contributions to the field.
29 |
30 |
31 | ## 🆕 News
32 |
33 | - **\[Feb 2024\]** Our [ChatCell app](https://chat.openai.com/g/g-vUwj222gQ-chatcell) is now live on GPTStore, give it a try📱!
34 | - **\[Feb 2024\]** We released the model weights based on T5 in [small](https://huggingface.co/zjunlp/chatcell-small), [base](https://huggingface.co/zjunlp/chatcell-base), and [large](https://huggingface.co/zjunlp/chatcell-large) configurations on Huggingface 🤗.
35 | - **\[Feb 2024\]** We released the [instructions of ChatCell](https://huggingface.co/datasets/zjunlp/ChatCell-Instructions) on Huggingface 🤗.
36 |
37 |
38 | ## 📌 Table of Contents
39 |
40 | - [⌚️ QuickStart](#2)
41 | - [🛠️ Usage](#3)
42 | - [🚀 Evaluation](#4)
43 | - [🧬 Single-cell Analysis Tasks](#5)
44 | - [📝 Cite](#6)
45 |
46 | ---
47 |
48 | ## ⌚️ Quickstart
49 |
50 | ```python
51 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
52 |
53 | tokenizer = AutoTokenizer.from_pretrained("zjunlp/chatcell-small")
54 | model = AutoModelForSeq2SeqLM.from_pretrained("zjunlp/chatcell-small")
55 | input_text="Distinguish between resistant and sensitive cancer cells in response to Cisplatin, using the data from the 100 most expressed genes in descending order MYL12B FTL MYL12A HIST1H4C RPL23 GSTP1 RPS3 ENO1 RPLP1 TXN ANXA2 PPP1CB B2M RPLP0 HSPA8 H2AFZ TPI1 ANXA1 RPL7 GAPDH CHP1 LDHA RPL3 S100A11 PRDX1 CALM2 CAPZA1 SLC25A5 RPS27 YWHAZ GNB2L1 PTBP3 RPS6 MOB1A S100A2 ACTG1 BROX SAT1 RPL35A CA2 PSMB4 RPL8 TBL1XR1 RPS18 HNRNPH1 RPL27 RPS14 RPS11 ANP32E RPL19 C6ORF62 RPL9 EEF1A1 RPL5 COLGALT1 NPM1 CCT6A RQCD1 CACUL1 RPL4 HSP90AA1 MALAT1 ALDOA PSMA4 SEC61G RPL38 PSMB5 FABP5 HSP90AB1 RPL35 CHCHD2 EIF3E COX4I1 RPL21 PAFAH1B2 PTMA TMED4 PSMB3 H3F3B AGO1 DYNLL1 ATP5A1 LDHB COX7B ACTB RPS27A PSME2 ELMSAN1 NDUFA1 HMGB2 PSMB6 TMSB10 SET RPL12 RPL37A RPS13 EIF1 ATP5G1 RPS3A TOB1."
56 |
57 | # Encode the input text and generate a response with specified generation parameters
58 | input_ids = tokenizer(input_text,return_tensors="pt").input_ids
59 | output_ids = model.generate(input_ids, max_length=512, num_return_sequences=1, no_repeat_ngram_size=2, top_k=50, top_p=0.95, do_sample=True)
60 |
61 | # Decode and print the generated output text
62 | output_text = tokenizer.decode(output_ids[0],skip_special_tokens=True)
63 | print(output_text)
64 | ```
65 |
66 | ## 🛠️ Usage
67 |
68 | ### 📚 Step 1: Prepare the data
69 |
70 | ❗️Note: You can obtain the original data by following the instructions in the `raw_data` directory. Alternatively, you can directly download the [pre-processed data we provide on Hugging Face](https://huggingface.co/datasets/zjunlp/ChatCell-Instructions) to **skip Step 1 of the process**.
71 |
72 | Change to the data workflow directory with the command: `cd workflow_data`.
73 |
74 | **1. For tasks such as random cell sentence generation, pseudo-cell generation, and cell type annotation, we utilize cells from the SHARE-seq mouse skin dataset.**
75 |
76 | - Follow these steps to use the `transform.py` script (this file was initially developed by the [Cell2Sentence](https://github.com/vandijklab/cell2sentence-ft) team, thanks for their great work!🤗) to **translate scRNA-seq data into cell sentences**:
77 |
78 |   - Define `data_filepath` to specify the path to your downloaded SHARE-seq mouse skin dataset `.h5ad` file.
79 |   - Define `output_dir` to specify the directory where the generated cell sentences will be saved.
80 |   - Define `eval_output_dir` to specify where figures and evaluation metrics will be stored.
81 |   - Run the transformation process by executing the following command in your terminal: `python transform.py`.
82 |
83 | - Then **convert cell sentences to instructions** with `mouse_to_json.py` (a minimal sketch of the overall pipeline follows this list):
84 |
85 |   - Set `input_path` to the `output_dir` specified in `transform.py`.
86 |   - Define `train_json_file_path`, `val_json_file_path`, and `test_json_file_path` to specify the paths where you want to save your train, validation, and test datasets in JSON format, respectively.
87 |   - Run the following command in your terminal to start the conversion process: `python mouse_to_json.py`.
88 |
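For intuition, the two scripts above implement the Cell2Sentence idea: a cell sentence is the cell's gene names ranked by descending expression, which is then wrapped into an instruction record. Below is a minimal, hypothetical sketch of that pipeline; the file path, prompt wording, and the `source`/`target` field names are illustrative, and the real logic lives in `transform.py` and `mouse_to_json.py`:

```python
import anndata
import numpy as np

adata = anndata.read_h5ad("mouse_skin.h5ad")  # hypothetical path to the SHARE-seq data
X = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)

records = []
for i in range(adata.n_obs):
    # Rank genes by descending expression and keep the top 100 as the cell sentence
    top = np.argsort(-X[i])[:100]
    sentence = " ".join(adata.var_names[j] for j in top)
    records.append({
        "source": f"Identify the cell type of the 100 most expressed genes: {sentence}.",
        "target": str(adata.obs["cell_type"].iloc[i]),  # assumes a cell_type column
    })
```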
89 |
90 | **2. For the drug sensitivity prediction task, we select GSE149383 and GSE117872 datasets.**
91 |
92 | - For GSE149383: Open `GSE149383_to_json.py` and set `expression_data_path` and `cell_info_path` to the locations of your downloaded `erl_total_data_2K.csv` and `erl_total_2K_meta.csv` files.
93 | - For GSE117872: Open `GSE117872_to_json.py` and set `expression_data_path` and `cell_info_path` to the locations of your downloaded `GSE117872_good_Data_TPM.txt` and `GSE117872_good_Data_cellinfo.txt` files.
94 | - Update `output_json_path` with the desired location for the JSON output files.
95 | - Execute the conversion script:
96 |   - Run `python GSE149383_to_json.py` for the GSE149383 dataset.
97 |   - Run `python GSE117872_to_json.py` for the GSE117872 dataset.
98 | - Open `split.py` and set `input_path` to the same location as the `output_json_path` used above. Specify the locations for `train_json_file_path`, `val_json_file_path`, and `test_json_file_path` where you want the split datasets to be saved.
99 | - Run the script with `python split.py` to split the dataset into training, validation, and test sets (a minimal sketch follows this list).
100 |
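A minimal sketch of such a split, assuming an 8:1:1 ratio and hypothetical file names (the actual ratios and paths are configured inside `split.py`):

```python
import json
import random

with open("output.json", "r", encoding="utf-8") as f:  # hypothetical output_json_path
    data = json.load(f)

random.seed(42)  # fix the shuffle for reproducibility
random.shuffle(data)
n = len(data)
splits = {
    "train.json": data[:int(0.8 * n)],
    "val.json": data[int(0.8 * n):int(0.9 * n)],
    "test.json": data[int(0.9 * n):],
}
for name, subset in splits.items():
    with open(name, "w", encoding="utf-8") as f:
        json.dump(subset, f, indent=4, ensure_ascii=False)
```
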
101 | **3. After preparing instructions for each specific task, follow the steps below to merge the datasets using the `merge.py` script.**
102 |
103 | - Ensure that the paths for `train_json_file_path`, `val_json_file_path`, and `test_json_file_path` are correctly set to point to the JSON files you previously generated for each dataset, such as `GSE117872`, `GSE149383`, and `mouse`.
104 | - Run `python merge.py` to start the merging process. This combines the specified training, validation, and testing datasets into a unified format, ready for model training (see the sketch below).
105 |
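A minimal sketch of the merge for one split, with hypothetical per-dataset file names (`merge.py` repeats this for the train, validation, and test files):

```python
import json

merged = []
for path in ["mouse_train.json", "GSE117872_train.json", "GSE149383_train.json"]:
    with open(path, "r", encoding="utf-8") as f:
        merged.extend(json.load(f))

with open("train.json", "w", encoding="utf-8") as f:
    json.dump(merged, f, indent=4, ensure_ascii=False)
```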
106 |
107 | ### 📜 Step 2: Vocabulary Adaptation
108 |
109 | To adapt the tokenizer vocabulary with new terms from cell biology, follow these steps using the `vocabulary_adaptation.py` script.
110 |
111 | - Ensure you have the following parameters configured in the script before running it:
112 |
113 |   - `tokenizer_last`: The path to the directory containing the pre-existing tokenizer.
114 |
115 |   - `tokenizer_now`: The destination path where the updated tokenizer will be saved.
116 |
117 |   - `GSE117872_json_file_path`: This should be set to the `output_json_path` variable from the `GSE117872_to_json.py` script.
118 |
119 |   - `GSE149383_json_file_path`: Similarly, this should match the `output_json_path` variable in the `GSE149383_to_json.py` script.
120 |
121 |   - `cell_sentences_hf_path`: This path should correspond to the `cell_sentences_hf` directory, which is specified as the `output_dir` variable within the `transform.py` script.
122 |
123 | - Once all parameters are configured, update the tokenizer's vocabulary with new cell biology terms by running `python vocabulary_adaptation.py` in your terminal. The core idea is sketched below.
124 |
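A minimal sketch of the idea behind vocabulary adaptation, with hypothetical paths and an illustrative gene list (in practice the gene symbols are collected from the JSON instruction files and the cell sentence vocabulary):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")

# Gene symbols to register as whole tokens (placeholder list)
gene_symbols = ["MALAT1", "GAPDH", "ACTB"]
num_added = tokenizer.add_tokens(gene_symbols)
print(f"Added {num_added} new tokens")

# Keep the embedding matrix in sync with the enlarged vocabulary
model.resize_token_embeddings(len(tokenizer))

tokenizer.save_pretrained("your_new_tokenizer_path")
```
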
125 | ### 🛠️ Step 3: Train and generate
126 |
127 | **1. Training**
128 |
129 | - Open the `finetune.py` script. Update the script with the paths for your training and validation JSON files (`train_json_path` and `valid_json_path`), the tokenizer location (`tokenizer_path`), the base model directory (`model_path`), and the directory where you want to save the fine-tuned model (`output_dir`).
130 | - Execute the fine-tuning process by running the following command in your terminal: `python finetune.py`.
131 |
132 | **2. Generation**
133 |
134 | - Single-Instance Inference:
135 |   - To run inference on a single instance, set the necessary parameters in `inference_one.py`.
136 |   - Execute the script with: `python inference_one.py`.
137 | - Web Interface Inference:
138 |   - For interactive web interface inference using Gradio, configure `inference_web.py` with the required parameters.
139 |   - Launch the web demo by running: `python inference_web.py`.
140 | - Batch Inference:
141 |   - For inference on a batch of instances, adjust the parameters in `inference_batch.py` as needed.
142 |   - Start the batch inference process with: `python inference_batch.py`.
143 |
144 |
145 | ### ⌨️ Step 4: Translating sentences into gene expressions
146 |
147 | **For the pseudo-cell generation task, we also translate sentences into gene expressions, including data extraction and transformation stages.**
148 |
149 | - Data Extraction:
150 |   - Open `extract_gene_generation.py` and set up the necessary parameters for generating cells based on cell type. This step is intended for training datasets larger than 500 samples, covering 16 cell types.
151 |   - Run the following command in your terminal to start the data extraction process: `python extract_gene_generation.py`.
152 |
153 | - Transformation Process:
154 |   - After generating the necessary files, configure `sentence_to_expression.py` with the appropriate parameters for the translation process (a minimal sketch of the idea follows below).
155 |   - Execute the transformation script with the command: `python sentence_to_expression.py`.
156 |
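A minimal sketch of the translation idea: following Cell2Sentence, a gene's expression value is approximated as a fitted function of its (log) rank in the sentence. The helper below is hypothetical, and `slope`/`intercept` stand in for coefficients fit on real data:

```python
import numpy as np

def sentence_to_expr_vector(sentence: str, gene_index: dict, n_genes: int,
                            slope: float = -1.0, intercept: float = 5.0) -> np.ndarray:
    """Map a space-separated cell sentence back to an expression vector."""
    expr = np.zeros(n_genes)
    for rank, gene in enumerate(sentence.split()):
        if gene in gene_index:
            # Predicted log expression from log rank, clipped at zero
            expr[gene_index[gene]] = max(intercept + slope * np.log1p(rank), 0.0)
    return expr
```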
157 |
158 | ## 🚀 Evaluation
159 |
160 | To evaluate the performance of various tasks, follow these steps:
161 |
162 | - Change to the evaluation directory with the command: `cd evaluation`.
163 |
164 | - Random Cell Generation Task:
165 |   - Open `Performance_of_random_cell_generation.py`.
166 |   - Set `json_path` to the JSON file with the generated data.
167 |   - Set `global_path` to the global gene vocabulary file, usually named `vocab_human.txt` and located in the `cell_sentences` subdirectory within the `output_dir` specified by the `transform.py` script.
168 |   - Run the command: `python Performance_of_random_cell_generation.py`.
169 |
170 | - Pseudo-cell Generation Task:
171 |   - Depending on the format of your data, open `Performance_of_pseudo-cell_generation_lev.py` or `Performance_of_pseudo-cell_generation_expr.py`.
172 |   - Set `my_data_path` to the file with the generated pseudo-cell data (a JSON file for the `lev` script, an `.h5ad` file for the `expr` script).
173 |   - Set `ground_truth_data_path` to the corresponding ground truth file.
174 |   - Set `k` to the number of neighbors for the KNN analysis.
175 |   - Run the matching script: `python Performance_of_pseudo-cell_generation_lev.py` or `python Performance_of_pseudo-cell_generation_expr.py`.
176 |
177 | - Cell Type Annotation and Drug Sensitivity Prediction Tasks:
178 |   - Open `performance_of_classification.py`.
179 |   - Set `my_data_path` to the JSON file containing the generated data for the task.
180 |   - Run the command: `python performance_of_classification.py`.
181 |
182 | ## 🧬 Single-cell Analysis Tasks
183 |
184 | ChatCell can handle the following single-cell tasks:
185 |
186 | - Random Cell Sentence Generation.
187 | Random cell sentence generation challenges the model to create cell sentences devoid of predefined biological conditions or constraints. This task aims to evaluate the model's ability to generate valid and contextually appropriate cell sentences, potentially simulating natural variations in cellular behavior.
188 |
189 |
190 |
191 |
192 |
193 |
194 | - Pseudo-cell Generation.
195 | Pseudo-cell generation focuses on generating gene sequences tailored to specific cell type labels. This task is vital for unraveling gene expression and regulation across different cell types, offering insights for medical research and disease studies, particularly in the context of diseased cell types.
196 |
197 |
198 |
199 |
200 |
201 |
202 | - Cell Type Annotation.
203 | For cell type annotation, the model is tasked with precisely classifying cells into their respective types based on gene expression patterns encapsulated in cell sentences. This task is fundamental for understanding cellular functions and interactions within tissues and organs, playing a crucial role in developmental biology and regenerative medicine.
204 |
205 |
206 |
207 |
208 |
209 | - Drug Sensitivity Prediction.
210 | The drug sensitivity prediction task aims to predict the response of different cells to various drugs. It is pivotal in designing effective, personalized treatment plans and contributes significantly to drug development, especially in optimizing drug efficacy and safety.
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 | ### Other Related Projects
220 |
221 | - [Cell2Sentence](https://github.com/rahuldhodapkar/cell2sentence)
222 | - [CellPLM](https://github.com/OmicsML/CellPLM)
223 | - [scGPT](https://github.com/bowang-lab/scGPT)
224 | - [scBERT](https://github.com/TencentAILabHealthcare/scBERT)
225 | - [GenePT](https://github.com/yiqunchen/GenePT)
226 | - [scMulan](https://github.com/SuperBianC/scMulan)
227 | - [bulk2space](https://github.com/ZJUFanLab/bulk2space)
228 |
--------------------------------------------------------------------------------
/evaluation/Performance_of_pseudo-cell_generation_expr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.neighbors import KNeighborsClassifier
3 | from sklearn.metrics import accuracy_score
4 | import anndata
9 |
10 | # Load data
11 | my_data_path="path_to_your_data.h5ad"
12 | ground_truth_data_path="path_to_ground_truth_data.h5ad"
13 | k = 5
14 |
15 | my_data = anndata.read_h5ad(my_data_path)
16 | ground_truth_data = anndata.read_h5ad(ground_truth_data_path)
17 |
18 | # Check if gene names from my_data match those from ground_truth_data
19 | genes_my_data = my_data.var_names
20 | genes_ground_truth_data = ground_truth_data.var_names
21 | if not np.array_equal(genes_my_data, genes_ground_truth_data):
22 |     # Align my_data to the gene order of ground_truth_data
23 |     my_data = my_data[:, genes_ground_truth_data]
24 |
25 | # Extract features and labels
26 | features = my_data.X
27 | labels = my_data.obs['cell_type']
28 | features_gt = ground_truth_data.X
29 | labels_gt = ground_truth_data.obs['cell_type']
30 |
31 | # KNN Classifier
32 | knn = KNeighborsClassifier(n_neighbors=k)
33 | knn_gt = KNeighborsClassifier(n_neighbors=k)
34 |
35 | # Train and predict
36 | knn.fit(features, labels)
37 | knn_gt.fit(features_gt, labels_gt)
38 | test_pred = knn.predict(features)
39 | test_pred_gt = knn_gt.predict(features)
40 |
41 | # Quality: a KNN trained on ground-truth data should assign the generated cells their intended labels
42 | accuracy_quality = accuracy_score(labels, test_pred_gt)
43 | print("accuracy_quality:", round(100 * accuracy_quality, 2))
44 |
45 | # Discriminability: count cells where both the generated-data KNN and the ground-truth KNN recover the intended label
46 | length = len(labels)
47 | count = sum(1 for i in range(length) if test_pred_gt[i] == labels[i] and test_pred[i] == labels[i])
48 | accuracy_discriminability = round(100 * count / length, 2)
49 | print("accuracy_discriminability:", accuracy_discriminability)
50 |
--------------------------------------------------------------------------------
/evaluation/Performance_of_pseudo-cell_generation_lev.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import accuracy_score
3 | import anndata
4 | from Levenshtein import distance
5 | from collections import Counter
6 | from tqdm import tqdm
7 | import json
8 |
9 | class KNNLevenshtein:
10 |     def __init__(self, k=3):
11 |         self.k = k
12 |
13 |     def fit(self, X_train, y_train):
14 |         self.X_train = X_train
15 |         self.y_train = y_train
16 |
17 |     def predict(self, X_test):
18 |         predictions = []
19 |         for x in tqdm(X_test):
20 |             predictions.append(self._predict(x))
21 |         return np.array(predictions)
22 |
23 |     def _predict(self, x):
24 |         # Rank training sentences by Levenshtein distance and vote among the k nearest
25 |         distances = [distance(x, x_train) for x_train in self.X_train]
26 |         k_neighbors_indices = np.argsort(distances)[:self.k]
27 |         k_neighbor_labels = [self.y_train[i] for i in k_neighbors_indices]
28 |         most_common = Counter(k_neighbor_labels).most_common(1)
29 |         return most_common[0][0]
30 |
31 | my_data_path='path_to_your_data.json'
32 | ground_truth_data_path='path_to_ground_truth_data.json'
33 | k = 5
34 |
35 | # Load data from the first file
36 | with open(my_data_path, "r", encoding="utf-8") as file:
37 |     data = json.load(file)[0:100]  # evaluate on the first 100 records
38 |
39 | features = [x['gene'].strip() for x in data]
40 | labels = [x['cell_type'].strip() for x in data]
41 |
42 | # Train and predict on the first dataset
43 | knn_model = KNNLevenshtein(k=k)
44 | knn_model.fit(features, labels)
45 | test_pred = knn_model.predict(features)
46 |
47 | # Load data from the second file
48 | with open(ground_truth_data_path, "r", encoding="utf-8") as file:
49 |     data = json.load(file)[0:100]  # evaluate on the first 100 records
50 | features_gt = [x['gene'].strip() for x in data]
51 | labels_gt = [x['cell_type'].strip() for x in data]
52 |
53 | # Train and predict on the second dataset
54 | knn_model_gt = KNNLevenshtein(k=k)
55 | knn_model_gt.fit(features_gt, labels_gt)
56 | test_pred_gt = knn_model_gt.predict(features)
57 |
58 | # Quality: the KNN fit on ground-truth sentences should assign generated cells their intended labels
59 |
60 | accuracy_quality = accuracy_score(labels, test_pred_gt)
61 | print("accuracy_quality:", round(100 * accuracy_quality, 2))
62 | length = len(labels)
63 | count = sum(1 for i in range(length) if test_pred_gt[i] == labels[i] and test_pred[i] == labels[i])
64 | accuracy_discriminability = round(100 * count / len(test_pred_gt), 2)
65 | print("accuracy_discriminability:", accuracy_discriminability)
66 |
--------------------------------------------------------------------------------
/evaluation/Performance_of_random_cell_generation.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import Counter
3 | from tqdm import tqdm
4 |
5 | # Function to get the gene vocabulary
6 | def get_gene_vocab(path):
7 |     # Load the gene vocabulary from a file; a set gives O(1) membership checks
8 |     gene_vocab = set()
9 |     with open(path, "r") as file:
10 |         for line in file:
11 |             gene_name = line.strip().split()[0].upper()
12 |             gene_vocab.add(gene_name)
13 |     return gene_vocab
14 |
15 | # Function to calculate statistics from the input JSON data
16 | def calculate_statistics(json_path, global_path):
17 |     # Load generated data from the JSON file
18 |     with open(json_path, "r", encoding="utf-8") as file:
19 |         data = json.load(file)
20 |
21 |     num_cases = len(data)  # Total number of cases
22 |     invalid_gene_count = 0  # Count of invalid gene names
23 |     total_gene_count = 0  # Total count of gene names
24 |     unique_gene_count = 0  # Total count of unique gene names
25 |     global_vocab = get_gene_vocab(global_path)
26 |
27 |     # Iterate over each generated cell sentence
28 |     for cell in tqdm(data):
29 |         generated_gene_names = [gene.upper() for gene in cell["my_target"].split()]
30 |         gene_name_to_occurrences = Counter(generated_gene_names)
31 |
32 |         # Count gene names that are absent from the global vocabulary
33 |         for gene_name in generated_gene_names:
34 |             if gene_name not in global_vocab:
35 |                 invalid_gene_count += 1
36 |
37 |         unique_gene_count += len(gene_name_to_occurrences)
38 |         total_gene_count += len(generated_gene_names)
39 |
40 |     print("Total number of cases:", num_cases)
41 |     print("Total gene count:", total_gene_count, round(total_gene_count / num_cases, 2))
42 |     print("Valid gene count:", total_gene_count - invalid_gene_count, round(100 * (total_gene_count - invalid_gene_count) / total_gene_count, 2))
43 |     print("Unique gene count:", unique_gene_count, round(100 * unique_gene_count / total_gene_count, 2))
48 |
49 | # Paths to input files
50 | json_path = 'path_to_your_data.json'
51 | global_path = 'vocab_human.txt'
52 |
53 | # Call the function to calculate statistics
54 | calculate_statistics(json_path, global_path)
55 |
--------------------------------------------------------------------------------
/evaluation/README.md:
--------------------------------------------------------------------------------
1 | ## 🚀 Evaluation
2 |
3 | To evaluate the performance of various tasks, follow these steps:
4 |
5 | - Random Cell Generation Task:
6 |   - Open `Performance_of_random_cell_generation.py`.
7 |   - Set `json_path` to the JSON file with the generated data.
8 |   - Set `global_path` to the global gene vocabulary file, usually named `vocab_human.txt` and located in the `cell_sentences` subdirectory within the `output_dir` specified by the `transform.py` script.
9 |   - Run the command: `python Performance_of_random_cell_generation.py`.
10 |
11 | - Pseudo-cell Generation Task:
12 |   - Depending on the format of your data, open `Performance_of_pseudo-cell_generation_lev.py` or `Performance_of_pseudo-cell_generation_expr.py`.
13 |   - Set `my_data_path` to the file with the generated pseudo-cell data (a JSON file for the `lev` script, an `.h5ad` file for the `expr` script).
14 |   - Set `ground_truth_data_path` to the corresponding ground truth file.
15 |   - Set `k` to the number of neighbors for the KNN analysis.
16 |   - Run the matching script: `python Performance_of_pseudo-cell_generation_lev.py` or `python Performance_of_pseudo-cell_generation_expr.py`.
17 |
18 | - Cell Type Annotation and Drug Sensitivity Prediction Tasks:
19 |   - Open `performance_of_classification.py`.
20 |   - Set `my_data_path` to the JSON file containing the generated data for the task.
21 |   - Run the command: `python performance_of_classification.py`.
22 |
--------------------------------------------------------------------------------
/evaluation/performance_of_classification.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
2 | import json
3 |
4 | # Path to your data file
5 | my_data_path = 'path_to_your_data.json'
6 |
7 | # Load data from the JSON file
8 | with open(my_data_path, "r", encoding="utf-8") as file:
9 |     data = json.load(file)
10 |
11 | # Extract true labels and predicted labels
12 | y_true = [x['target'].strip().lower() for x in data] # Replace with your true labels
13 | y_pred = [x['my_target'].strip().lower() for x in data] # Replace with your predicted labels
14 |
15 | # Calculate evaluation metrics
16 | accuracy = accuracy_score(y_true, y_pred)
17 | precision = precision_score(y_true, y_pred, average='weighted')
18 | recall = recall_score(y_true, y_pred, average='weighted')
19 | f1 = f1_score(y_true, y_pred, average='weighted')
20 |
21 | # Print evaluation metrics
22 | print("Accuracy: ", round(100 * accuracy, 2))
23 | print("Precision: ", round(100 * precision, 2))
24 | print("Recall: ", round(100 * recall, 2))
25 | print("F1 Score: ", round(100 * f1, 2))
26 |
--------------------------------------------------------------------------------
/figure/example1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/ChatCell/f7203340709c31a36fda0a350f9b8c7eac636258/figure/example1.jpg
--------------------------------------------------------------------------------
/figure/example2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/ChatCell/f7203340709c31a36fda0a350f9b8c7eac636258/figure/example2.jpg
--------------------------------------------------------------------------------
/figure/example3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/ChatCell/f7203340709c31a36fda0a350f9b8c7eac636258/figure/example3.jpg
--------------------------------------------------------------------------------
/figure/example4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/ChatCell/f7203340709c31a36fda0a350f9b8c7eac636258/figure/example4.jpg
--------------------------------------------------------------------------------
/figure/intro.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/ChatCell/f7203340709c31a36fda0a350f9b8c7eac636258/figure/intro.gif
--------------------------------------------------------------------------------
/figure/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/ChatCell/f7203340709c31a36fda0a350f9b8c7eac636258/figure/logo.png
--------------------------------------------------------------------------------
/figure/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/ChatCell/f7203340709c31a36fda0a350f9b8c7eac636258/figure/overview.jpg
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from random import choice
3 | from datasets import Dataset, DatasetDict
4 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
5 | import numpy as np
6 | from rouge import Rouge
7 | import os
8 | from datasets import load_metric
9 | import json
10 |
11 | os.environ["WANDB_DISABLED"]="true"
12 | metric = load_metric("rouge")
13 |
14 | train_json_path = "sum/train.json"
15 | valid_json_path = "sum/valid.json"
16 | tokenizer_path="your_new_tokenizer_path"
17 | model_path="google-t5/t5-base"
18 | output_dir="save_model_path"
19 |
20 |
21 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
22 |
23 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
24 | print(len(tokenizer))
25 |
26 | if model.config.vocab_size != len(tokenizer):
27 |     print("Token embedding size differs from the tokenizer vocabulary size; resizing.")
28 |     model.resize_token_embeddings(len(tokenizer))
29 | else:
30 |     print("Token embedding size already matches the tokenizer; no resize needed.")
31 |
32 | # Function to load and prepare datasets
33 |
34 | def load_json_datasets(train_json_path, valid_json_path):  # renamed to avoid shadowing datasets.load_dataset
35 |     # Load datasets from JSON
36 |     with open(train_json_path, 'r', encoding='utf-8') as file:
37 |         train_data = json.load(file)
38 |     train_dataset = Dataset.from_dict({'source': [item['source'] for item in train_data],
39 |                                        'target': [item['target'] for item in train_data]})
40 |
41 |     with open(valid_json_path, 'r', encoding='utf-8') as file:
42 |         valid_data = json.load(file)
43 |     valid_dataset = Dataset.from_dict({'source': [item['source'] for item in valid_data],
44 |                                        'target': [item['target'] for item in valid_data]})
45 |
46 |     # Combine into a DatasetDict
47 |     dataset_dict = DatasetDict({
48 |         'train': train_dataset,
49 |         'valid': valid_dataset
50 |     })
51 |
52 |     return dataset_dict
53 | # Load and prepare the dataset
54 | dataset_dict = load_json_datasets(train_json_path, valid_json_path)
55 | print(dataset_dict)
56 |
57 | # Function to preprocess the datasets
58 |
59 | def process_func(examples):
60 |     inputs = examples['source']
61 |     targets = examples['target']
62 |     model_inputs = tokenizer(inputs)
63 |     labels = tokenizer(targets)
64 |     model_inputs["labels"] = labels["input_ids"]
65 |     return model_inputs
66 |
67 | # Apply preprocessing
68 |
69 | dataset_all = dataset_dict.map(process_func, batched=True)
70 |
71 | print(dataset_all)
72 | # Training arguments
73 | batch_size = 8
74 |
75 | args = Seq2SeqTrainingArguments(
76 | output_dir=output_dir,
77 | evaluation_strategy="steps",
78 | eval_steps=5000,
79 | logging_strategy="steps",
80 | logging_steps=100,
81 | save_strategy="steps",
82 | save_steps=5000,
83 | learning_rate=4e-5,
84 | per_device_train_batch_size=batch_size,
85 | per_device_eval_batch_size=batch_size,
86 | weight_decay=0.01,
87 | save_total_limit=3,
88 | num_train_epochs=50,
89 | predict_with_generate=True,
90 | fp16=False,
91 | load_best_model_at_end=True,
92 | metric_for_best_model="rouge1",
93 | )
94 |
95 | # Function to compute metrics for evaluation
96 |
97 | def compute_metrics(eval_pred):
98 |     predictions, labels = eval_pred
99 |     decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
100 |     # Replace -100 in the labels as we can't decode them.
101 |     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
102 |     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
103 |     # Strip surrounding whitespace before scoring
104 |     decoded_preds = [pred.strip() for pred in decoded_preds]
105 |     decoded_labels = [label.strip() for label in decoded_labels]
106 |
107 |
108 |     # Compute ROUGE scores
109 |     result = metric.compute(predictions=decoded_preds, references=decoded_labels,
110 |                             use_stemmer=True)
111 |
112 |     # Extract ROUGE F1 scores
113 |     result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
114 |
115 |     # Add mean generated length to metrics
116 |     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id)
117 |                        for pred in predictions]
118 |     result["gen_len"] = np.mean(prediction_lens)
119 |
120 |     return {k: round(v, 4) for k, v in result.items()}
121 | trainer = Seq2SeqTrainer(
122 | model=model,
123 | args=args,
124 | train_dataset=dataset_all['train'],
125 | eval_dataset=dataset_all['valid'],
126 | data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
127 | tokenizer=tokenizer,
128 | compute_metrics=compute_metrics
129 | )
130 |
131 | # Start training
132 |
133 | trainer.train()
134 | trainer.save_model()
135 |
136 |
--------------------------------------------------------------------------------
/inference_batch.py:
--------------------------------------------------------------------------------
1 | from transformers import (
2 | AutoModelForCausalLM,
3 | AutoTokenizer,
4 | HfArgumentParser,
5 | Trainer,
6 | TrainingArguments,
7 | AutoModelForSeq2SeqLM,
8 | )
9 | import torch
10 | import json
11 | from tqdm import tqdm
12 | # Define paths and load the model and tokenizer
13 |
14 | model_folder = "zjunlp/chatcell-small"
15 | input_path="input_path.json"
16 | output_path="output_path.json"
17 | tokenizer = AutoTokenizer.from_pretrained(model_folder)
18 | model = AutoModelForSeq2SeqLM.from_pretrained(model_folder)
19 | print(f"Tokenizer vocabulary size: {len(tokenizer.vocab.keys())}")
20 |
21 | # Determine the execution device based on availability of CUDA and move the model to it
22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23 | model.to(device)
24 |
29 | # Load data from a JSON file
30 | with open(input_path, "r", encoding="utf-8") as file:
31 |     data = json.load(file)
32 |
33 | # Prepare the list of input texts
34 | input_texts = [item["source"] for item in data]
35 |
36 | batch_size = 128
37 |
38 | model.eval()
39 | output_texts = []
40 | for i in tqdm(range(0, len(input_texts), batch_size)):
41 |     batch_input_texts = input_texts[i:i+batch_size]
42 |     input_ids = tokenizer.batch_encode_plus(batch_input_texts, padding=True, truncation=True, return_tensors="pt").input_ids.to(device)
43 |     output_ids = model.generate(input_ids, max_length=512, num_return_sequences=1, no_repeat_ngram_size=2, top_k=50, top_p=0.95, do_sample=True)
44 |     batch_output_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
45 |     output_texts.extend(batch_output_texts)
46 |
47 | for item, output_text in zip(data, output_texts):
48 |     item['my_target'] = output_text
49 |
50 | # Save the updated data to a new JSON file
51 | with open(output_path, "w", encoding="utf-8") as output_file:
52 |     json.dump(data, output_file, indent=4, ensure_ascii=False)
53 |
54 |
--------------------------------------------------------------------------------
/inference_one.py:
--------------------------------------------------------------------------------
1 | from transformers import (
2 | AutoModelForCausalLM,
3 | AutoTokenizer,
4 | HfArgumentParser,
5 | Trainer,
6 | TrainingArguments,
7 | AutoModelForSeq2SeqLM,
8 | )
9 | import torch
10 |
11 | # Set the path to the model and specify the input text
12 | model_folder = "zjunlp/chatcell-small"
13 | input_text="Detail the 100 starting genes for a Mix, ranked by expression level: "
14 | tokenizer = AutoTokenizer.from_pretrained(model_folder)
15 | model = AutoModelForSeq2SeqLM.from_pretrained(model_folder)
16 | print(f"Tokenizer vocabulary size: {len(tokenizer.vocab.keys())}")
17 |
18 | # Determine the execution device based on availability of CUDA and move the model to it
19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 | model.to(device)
21 |
22 |
27 | model.eval()
28 | # Encode the input text and generate a response with specified generation parameters
29 | input_ids = tokenizer(input_text,return_tensors="pt").input_ids.to(device)
30 | output_ids = model.generate(input_ids, max_length=512, num_return_sequences=1, no_repeat_ngram_size=2, top_k=50, top_p=0.95, do_sample=True)
31 | # Decode and print the generated output text
32 | output_text = tokenizer.decode(output_ids[0],skip_special_tokens=True)
33 |
34 | print(output_text)
35 |
--------------------------------------------------------------------------------
/inference_web.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from transformers import (
3 | AutoModelForCausalLM,
4 | AutoTokenizer,
5 | HfArgumentParser,
6 | Trainer,
7 | TrainingArguments,
8 | AutoModelForSeq2SeqLM,
9 | )
10 | import torch
11 | model_folder = "zjunlp/chatcell-small"
12 | tokenizer = AutoTokenizer.from_pretrained(model_folder)
13 | model = AutoModelForSeq2SeqLM.from_pretrained(model_folder)
14 | model.eval()
15 | # Select the execution device and move the model onto it
16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 | model.to(device)
18 | def generate_response(input_text):
19 | # Encode the input text and generate a response with specified generation parameters
20 |     input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
21 |     output_ids = model.generate(input_ids, max_length=512, num_return_sequences=1, no_repeat_ngram_size=2, top_k=50, top_p=0.95, do_sample=True)
22 |     # Decode and return the generated output text
23 |     output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
24 |
25 | return output_text
26 |
27 |
28 |
29 |
30 | css = """
31 | .green { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;}
32 | .red { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;}
33 | .hyperlinks {
34 | display: flex;
35 | align-items: center;
36 | align-content: center;
37 | padding-top: 12px;
38 | justify-content: flex-end;
39 | margin: 0 10px; /* Adjust the margin as needed */
40 | text-decoration: none;
41 | color: #000; /* Set the desired text color */
42 | }
43 | """
44 |
45 | default_input_text = ''
46 | with gr.Blocks(css=css,
47 | theme=gr.themes.Soft(text_size="sm")) as app:
48 |
49 | with gr.Row():
50 |
51 |         gr.HTML("""
52 |         <h1 align="center">
53 |         ChatCell: Facilitating Single-Cell Analysis with Natural Language
54 |         </h1>
55 |
56 |         """)
57 |     with gr.Row():
58 |         gr.Markdown("🐣Project 📃Paper 🥳Code")
59 |
60 | with gr.Row():
61 |         input_box = gr.Textbox(value=default_input_text, placeholder="Enter text here", lines=4, label="Input Text", )
62 | with gr.Row():
63 | output_text = gr.Textbox(label="Prediction")
64 |
65 | with gr.Row():
66 | clear_button = gr.ClearButton()
67 | submit_button = gr.Button(variant="primary")
68 |
69 | examples = gr.Examples(
70 | examples=[
71 | ["List the initial 100 genes associated with a TAC-1:"],
72 | ["Enumerate the 100 most abundantly expressed starting genes in a cell:"],
73 | ["Could you determine the likely cell type based on these 100 most expressed genes? GM42418 IL1RAPL1 CDK8 MALAT1 JARID2 CAMK1D ZC3H7A GPHN LARS2 HEXB FGFR2 BRWD1 CASC5 MCCC2 NEAT1 PCNT NFIA NIPBL KIF23 GM26917 BZW1 MYOF PRPF38B HSPA9 HNRNPAB RORA ANLN AHNAK CIT ATRX ADGRG6 RTF1 SMC1A TENM1 HMCN1 LDLRAD4 QK AKAP13 LUC7L3 COL1A2 STX12 PTPN14 AKIRIN1 SNRNP48 MYH9 ATXN1 TRAPPC8 MKL1 MAN1A2 S100A14 DPM1 VPS13C FAM132A AMOT ITGA9 TCF4 ARF6 MBNL1 RPS6KC1 ANXA1 NAA35 SRSF6 GGTA1 2410089E03RIK CRYBG3 SMURF1 LITAF CERS6 BEND6 SRSF4 MTUS1 PLCH2 RBM27 ABCB7 PIEZO1 CUL2 RBMS1 RIC8B PTMA CEP128 HNRNPH1 HMMR KPNA4 MTDH EFNA5 EIF2B2 LARP4B SFSWAP CEP83 SLCO3A1 POLR2A KIF20A PGLYRP4 SLC39A11 ITPR2 CDC42SE1 COX7C NCAPG FKBP5 RIOK3 These genes are commonly found in:"],
74 | ["Distinguish between resistant and sensitive cancer cells in response to Cisplatin, using the data from the 100 most expressed genes in descending order MYL12B FTL MYL12A HIST1H4C RPL23 GSTP1 RPS3 ENO1 RPLP1 TXN ANXA2 PPP1CB B2M RPLP0 HSPA8 H2AFZ TPI1 ANXA1 RPL7 GAPDH CHP1 LDHA RPL3 S100A11 PRDX1 CALM2 CAPZA1 SLC25A5 RPS27 YWHAZ GNB2L1 PTBP3 RPS6 MOB1A S100A2 ACTG1 BROX SAT1 RPL35A CA2 PSMB4 RPL8 TBL1XR1 RPS18 HNRNPH1 RPL27 RPS14 RPS11 ANP32E RPL19 C6ORF62 RPL9 EEF1A1 RPL5 COLGALT1 NPM1 CCT6A RQCD1 CACUL1 RPL4 HSP90AA1 MALAT1 ALDOA PSMA4 SEC61G RPL38 PSMB5 FABP5 HSP90AB1 RPL35 CHCHD2 EIF3E COX4I1 RPL21 PAFAH1B2 PTMA TMED4 PSMB3 H3F3B AGO1 DYNLL1 ATP5A1 LDHB COX7B ACTB RPS27A PSME2 ELMSAN1 NDUFA1 HMGB2 PSMB6 TMSB10 SET RPL12 RPL37A RPS13 EIF1 ATP5G1 RPS3A TOB1."],
75 | ["Evaluate a cancer cell's response to Erlotinib (resistant or sensitive), based on the cell's 100 most actively expressed genes in descending order MT-RNR2 MT-CO3 MT-CO1 RPL13 RPLP1 FTH1 GAPDH RPS8 PABPC1 FTL ANXA2 NCL RPS12 PTMA RPS14 CAV1 RPS3 RPS4X MT-CO2 RPL37 KRT7 RPS2 TMSB4X HSP90AA1 GSTP1 MT-ND4 MT-ATP6 S100A6 NAP1L1 RPL31 HSP90AB1 B2M RPL19 PPIA NUCKS1 MT-ND5 RPS27A TPM1 RPL18 ATP5G3 RPLP0 RPL8 TXN GNAS PSMA7 RPL30 MYL6 SLC25A5 RAD21 RTN4 RPL37A HSPD1 LRRFIP1 DEK MT-ND1 RPL11 TPT1 TMSB10 RPS24 SSRP1 LUC7L3 RPS20 KRT18 RPS16 RPL5 RPL35 RPS21 HNRNPC RPLP2 NACA GNB2L1 LGALS1 UQCRH FAU CHCHD2 RPL23 LDHB UBA52 HSP90B1 SEC62 RPL6 DSTN RPL27A RPL23A RPL22L1 KTN1 RPL14 CALM2 PRDX1 ADIPOR2 ZFAS1 HIST1H4C UQCRQ CALU PTTG1 EIF1 RPL26 RPL10A RAN ARPC2."]
76 | ],
77 | examples_per_page=3,
78 | inputs=[input_box],
79 | )
80 |
81 |
82 |
83 | with gr.Accordion("Disclaimer", open=False):
84 | gr.Markdown(
85 | """
86 |             - `Accuracy`:
87 | - This model aims for accuracy but cannot guarantee 100% precision. Results should be used as a guide, not as definitive tools.
88 | - `Use Restrictions`:
89 | - Intended for educational and research purposes only. Not for clinical or commercial use without prior verification.
90 | - `Responsibility`:
91 | - Users are solely responsible for any outcomes resulting from the use of this model.
92 | """
93 | )
94 | with gr.Accordion("Cite our work", open=False):
95 | gr.Markdown(
96 | """
97 | ```bibtex
98 | @article{fang2024chatcell,
99 | title={ChatCell: Facilitating Single-Cell Analysis with Natural Language},
100 | author={Fang, Yin and Liu, Kangwei and Zhang, Ningyu and Deng, Xinle and Yang, Penghui and Chen, Zhuo and Tang, Xiangru and Gerstein, Mark and Fan, Xiaohui and Chen, Huajun},
101 | journal={arXiv preprint arXiv:2402.08961},
102 | year={2024}
103 | }
104 | ```
105 | """
106 | )
107 |
108 | submit_button.click(generate_response, inputs=input_box, outputs=output_text)
109 | clear_button.click(lambda: ("", ""), outputs=[input_box, output_text])
110 |
111 | app.launch()
112 |
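113 | # app.launch() serves the demo at a local URL; pass share=True to app.launch()
114 | # if a temporary public Gradio link is needed.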
--------------------------------------------------------------------------------
/raw_data/README.md:
--------------------------------------------------------------------------------
1 | - The SHARE-seq mouse skin dataset can be downloaded from: [rna.h5ad.zip](https://drive.google.com/file/d/1xXfjqN1wtCg5GqLMnIUih5b9J_wVGGYf/view?usp=sharing)
2 |
3 |
4 | - The GSE149383 and GSE117872 datasets can be downloaded from: [GSE149383.zip](https://drive.google.com/file/d/1Qp8Hmb2wARISDsmwqbnndGwJ2BXfamEa/view?usp=sharing) and [GSE117872.zip](https://drive.google.com/file/d/1clBoSwfJKRAlauUwDC3xAV9i1jC6XU5r/view?usp=sharing), respectively.
5 |
6 |
7 |
8 |
9 |
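10 | The archives can also be fetched programmatically. A minimal sketch using the third-party `gdown` package (assumed installed via `pip install gdown`; the output filenames are illustrative):
11 |
12 | ```python
13 | # Sketch: download the raw-data archives via their Google Drive file IDs
14 | # (IDs taken from the links above; output filenames are assumptions).
15 | import gdown
16 |
17 | files = {
18 |     "rna.h5ad.zip": "1xXfjqN1wtCg5GqLMnIUih5b9J_wVGGYf",
19 |     "GSE149383.zip": "1Qp8Hmb2wARISDsmwqbnndGwJ2BXfamEa",
20 |     "GSE117872.zip": "1clBoSwfJKRAlauUwDC3xAV9i1jC6XU5r",
21 | }
22 | for name, file_id in files.items():
23 |     gdown.download(id=file_id, output=name, quiet=False)
24 | ```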
--------------------------------------------------------------------------------
/vocabulary_adaptation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from tqdm import tqdm
4 | from datasets import load_from_disk, concatenate_datasets
5 | from transformers import AutoTokenizer
6 |
7 | # Paths for tokenizers and data files
8 | tokenizer_last = 'google-t5/t5-base'
9 | tokenizer_now = 'your_new_tokenizer_path'
10 | GSE117872_json_file_path = 'GSE117872.json'
11 | GSE149383_json_file_path = 'GSE149383.json'
12 | cell_sentences_hf_path = 'cell_sentences_hf_path'
13 |
14 | # List of essential cell biology terms
15 | cell_vocab=['Dermal Fibroblast', 'Dermal Papilla', 'TAC-1', 'IRS', 'Basal', 'K6+ Bulge Companion Layer', 'Medulla', 'alowCD34+ bulge', 'Mix', 'Isthmus', 'ORS', 'Infundibulum', 'Spinous', 'ahighCD34+ bulge', 'TAC-2', 'Macrophage DC', 'Endothelial', 'Dermal Sheath', 'Sebaceous Gland', 'Granular', 'Hair Shaft-cuticle.cortex', 'Schwann Cell', 'Melanocyte']
16 | cell_vocab.extend(['PBMC','Erlotinib','Cisplatin'])
17 |
18 | # Load mouse datasets
19 | train_ds = load_from_disk(os.path.join(cell_sentences_hf_path, 'train'))
20 | val_ds = load_from_disk(os.path.join(cell_sentences_hf_path, 'valid'))
21 | test_ds = load_from_disk(os.path.join(cell_sentences_hf_path, 'test'))
22 | # Concatenate datasets and preprocess
23 | total_ds = concatenate_datasets([train_ds, val_ds, test_ds])
24 | total_ds = total_ds.map(lambda example: {"first_100_gene_words": example["input_ids"].split(" ")[:100]})
25 | for cell_idx in tqdm(range(len(total_ds))):
26 | cell_sentence_list = total_ds[cell_idx]["first_100_gene_words"]
27 | cell_vocab.extend(cell_sentence_list)
28 |
29 | # Load GSE datasets
30 | with open(GSE117872_json_file_path, 'r') as f1, open(GSE149383_json_file_path, 'r') as f2:
31 | data1 = json.load(f1)
32 | data2 = json.load(f2)
33 | GSE_data = data1 + data2
34 | for x in GSE_data:
35 |     gene = x['source']
36 |     gene = gene.split()[-100:]  # the last 100 tokens of each prompt are the gene names
37 |     gene[-1] = gene[-1][:-1]  # drop the trailing period from the final gene name
38 |     cell_vocab.extend(gene)
39 |
40 | # Load the tokenizer and update its vocabulary
41 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_last, from_slow=True)
42 | dif = list(set(cell_vocab) - set(tokenizer.vocab.keys()))
43 | dif.sort()
44 | tokenizer.add_tokens(dif)
45 | # Save the updated tokenizer
46 | tokenizer.save_pretrained(tokenizer_now)
47 |
48 |
49 |
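50 | # Note: after extending the tokenizer, the model's embedding matrix must be resized
51 | # to match before fine-tuning, e.g. model.resize_token_embeddings(len(tokenizer));
52 | # this is assumed to happen in finetune.py when the adapted tokenizer is loaded.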
--------------------------------------------------------------------------------
/workflow_data/GSE117872_to_json.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import json
4 | import random
5 |
6 | # Seed for reproducibility
7 | random.seed(42)
8 |
9 | # Function to construct a template for drug sensitivity prediction
10 | def construct_template(drug_name, gene_list):
11 | prompts=[
12 | "Predict the drug sensitivity of a cancer cell to {} as resistant or sensitive, using its top 100 genes expressed in descending order of expression {}.",
13 | "Determine whether a cancer cell is likely to be resistant or sensitive to {}, based on its 100 highest expressed genes in descending order {}.",
14 | "Identify the drug sensitivity classification (resistant/sensitive) for a cancer cell to {}, using its top 100 genes sorted by decreasing expression levels {}.",
15 | "Assess the likelihood of a cancer cell being resistant or sensitive to {}, considering its 100 genes with the highest expression in descending order {}.",
16 | "Evaluate a cancer cell's response to {} (resistant or sensitive), based on the cell's 100 most actively expressed genes in descending order {}.",
17 | "Analyze the drug sensitivity (resistant/sensitive) of a cancer cell to {}, by examining its top 100 genes with the highest expression levels in descending order {}.",
18 | "Predict the efficacy of {} on a cancer cell, classifying it as resistant or sensitive, based on the cell's 100 most expressed genes in descending order {}.",
19 | "Distinguish between resistant and sensitive cancer cells in response to {}, using the data from the 100 most expressed genes in descending order {}.",
20 | "Classify a cancer cell's reaction to {} as resistant or sensitive, analyzing its top 100 genes by highest expression in descending order {}.",
21 | "Forecast the drug sensitivity outcome (resistant/sensitive) of a cancer cell to {}, guided by its 100 most expressed genes in descending order {}."
22 | ]
23 | selected_template = random.choice(prompts)
24 | return selected_template.format(drug_name, gene_list)
25 |
26 | # Read and process gene expression data
27 | expression_data_path = 'GSE117872/GSE117872_good_Data_TPM.txt'
28 | # Read drug sensitivity data (adjust the file path and column names as necessary)
29 | cell_info_path = 'GSE117872/GSE117872_good_Data_cellinfo.txt'
30 | output_json_path = 'GSE117872.json'
31 |
32 | expression_data = pd.read_csv(expression_data_path, sep='\t').transpose()
33 |
34 | # Extract the top 100 expressed genes for each sample
35 | top100genes = [' '.join(row.sort_values(ascending=False).head(100).index).upper() for _, row in expression_data.iterrows()]
36 |
37 | cell_info_data = pd.read_csv(cell_info_path, sep='\t')
38 |
39 | # Generate structured JSON data
40 | json_data = []
41 | for i, row in cell_info_data.iterrows():
42 | drug_name = 'Cisplatin' # Example; adjust as needed
43 | gene_list = top100genes[i].upper()
44 |     sensitivity = cell_info_data.iat[i, 5]  # column 6 of the cellinfo file holds the sensitivity label
45 |     if sensitivity == 'Holiday':  # drug-holiday cells are treated as sensitive
46 |         sensitivity = 'sensitive'
47 |     sensitivity = sensitivity.lower()
48 | json_data.append({
49 | "source": construct_template(drug_name, gene_list),
50 | "target": sensitivity,
51 | })
52 |
53 | # Save JSON data to file
54 | with open(output_json_path, 'w', encoding='utf-8') as file:
55 | json.dump(json_data, file, indent=4, ensure_ascii=False)
56 |
57 | print("JSON file has been created.")
58 |
--------------------------------------------------------------------------------
/workflow_data/GSE149383_to_json.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import random
3 | import json
4 | import sys
5 | # Set random seed for reproducibility
6 | random.seed(42)
7 | # Function to construct a template for drug sensitivity prediction
8 | def construct_template(drug_name, gene_list):
9 | prompts=[
10 | "Predict the drug sensitivity of a cancer cell to {} as resistant or sensitive, using its top 100 genes expressed in descending order of expression {}.",
11 | "Determine whether a cancer cell is likely to be resistant or sensitive to {}, based on its 100 highest expressed genes in descending order {}.",
12 | "Identify the drug sensitivity classification (resistant/sensitive) for a cancer cell to {}, using its top 100 genes sorted by decreasing expression levels {}.",
13 | "Assess the likelihood of a cancer cell being resistant or sensitive to {}, considering its 100 genes with the highest expression in descending order {}.",
14 | "Evaluate a cancer cell's response to {} (resistant or sensitive), based on the cell's 100 most actively expressed genes in descending order {}.",
15 | "Analyze the drug sensitivity (resistant/sensitive) of a cancer cell to {}, by examining its top 100 genes with the highest expression levels in descending order {}.",
16 | "Predict the efficacy of {} on a cancer cell, classifying it as resistant or sensitive, based on the cell's 100 most expressed genes in descending order {}.",
17 | "Distinguish between resistant and sensitive cancer cells in response to {}, using the data from the 100 most expressed genes in descending order {}.",
18 | "Classify a cancer cell's reaction to {} as resistant or sensitive, analyzing its top 100 genes by highest expression in descending order {}.",
19 | "Forecast the drug sensitivity outcome (resistant/sensitive) of a cancer cell to {}, guided by its 100 most expressed genes in descending order {}."
20 | ]
21 | selected_template = random.choice(prompts)
22 | return selected_template.format(drug_name, gene_list)
23 |
24 | # Define file paths for CSV data and output JSON
25 | expression_data_path = 'GSE149383/erl_total_data_2K.csv'
26 | cell_info_path = 'GSE149383/erl_total_2K_meta.csv'
27 | output_json_path = 'GSE149383.json'
28 |
29 | # Load expression data from CSV
30 | expression_data = pd.read_csv(expression_data_path)
31 | print(expression_data.head())
32 |
33 | # Transpose data to have genes as columns
34 | expression_data = expression_data.T
35 | rows, columns = expression_data.shape
36 | print(expression_data.head(), f"Rows: {rows}", f"Columns: {columns}", sep='\n')
37 |
38 | # Initialize a list to store top 100 gene names for each sample
39 | top100genes = []
40 |
41 | # Iterate over rows, skipping the first row with gene names
42 | flag=0
43 | for index, row in expression_data.iterrows():
44 | if flag==0:# Skip the header row with gene names
45 | flag=1
46 | continue
47 | # Sort and select top 100 genes
48 | top_100_genes = row.sort_values(ascending=False).head(100).index.tolist()
49 |     top_100_genes = [expression_data.iloc[0, int(gene)] for gene in top_100_genes]  # map column positions back to the gene names stored in the first row
50 |
51 | top100genes.append(' '.join(top_100_genes))
52 |
53 | print(f"Total samples processed: {len(top100genes)}")
54 |
55 |
56 | # Load metadata
57 | metadata = pd.read_csv(cell_info_path, header=None)
58 | print(metadata.head(), f"Rows: {metadata.shape[0]}", f"Columns: {metadata.shape[1]}", sep='\n')
59 |
60 | # Prepare JSON data
61 | json_data = []
62 | for i in range(len(metadata)):
63 | drug_name = 'Erlotinib' # Example drug name
64 | gene_list = top100genes[i].upper()
65 | sensitivity_value = metadata.iloc[i, 3]
66 | entry = {"source": construct_template(drug_name, gene_list), "target": sensitivity_value}
67 | json_data.append(entry)
68 |
69 | # Save JSON data to file
70 | with open(output_json_path, 'w', encoding='utf-8') as file:
71 | json.dump(json_data, file, indent=4, ensure_ascii=False)
72 |
73 | print("JSON file has been created.")
74 |
--------------------------------------------------------------------------------
/workflow_data/README.md:
--------------------------------------------------------------------------------
1 | ## ✨ Acknowledgements
2 |
3 | Special thanks to the authors of [Cell2Sentence: Teaching Large Language Models the Language of Biology](https://github.com/vandijklab/cell2sentence-ft)
4 | and [Representing cells as sentences enables natural-language processing for single-cell transcriptomics](https://github.com/rahuldhodapkar/cell2sentence) for their inspiring work.
5 |
6 | The [`src`](src) folder and [`transform.py`](transform.py) in this project are grounded in their research. We are grateful for their valuable contributions to the field. This portion of the intellectual property belongs to the Cell2Sentence team and adheres to the same Attribution-NonCommercial-ShareAlike 4.0 International License.
7 |
8 | If you use this code, please cite the following related papers:
9 | ```
10 | @article{levine2023cell2sentence,
11 | title={Cell2sentence: Teaching large language models the language of biology},
12 | author={Levine, Daniel and Rizvi, Syed Asad and L{\'e}vy, Sacha and Pallikkavaliyaveetil, Nazreen and Wu, Ruiming and Han, Insu and Zheng, Zihe and Oliveira Fonseca, Antonio Henrique de and Chen, Xingyu and Ghadermarzi, Sina and others},
13 | journal={bioRxiv},
14 | pages={2023--09},
15 | year={2023},
16 | publisher={Cold Spring Harbor Laboratory}
17 | }
18 |
19 | @article{dhodapkar2022representing,
20 | title={Representing cells as sentences enables natural-language processing for single-cell transcriptomics},
21 | author={Dhodapkar, Rahul M},
22 | journal={bioRxiv},
23 | pages={2022--09},
24 | year={2022},
25 | publisher={Cold Spring Harbor Laboratory}
26 | }
27 | ```
28 |
--------------------------------------------------------------------------------
/workflow_data/extract_gene_generation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import anndata
3 | import json
4 | from tqdm import tqdm
5 | import scanpy as sc
6 |
7 | # List of cell types that have appeared at least 500 times in the training dataset
8 | cell500=['Dermal Fibroblast', 'TAC-1', 'IRS', 'Basal', 'Medulla', 'alowCD34+ bulge', 'Mix', 'ORS', 'Infundibulum', 'Spinous', 'ahighCD34+ bulge', 'TAC-2', 'Hair Shaft-cuticle.cortex','Endothelial','Isthmus','Dermal Papilla']
9 | for i in range(len(cell500)):
10 |     cell500[i] = ' ' + cell500[i]  # a leading space enforces a word boundary when matching inside prompts
11 |
12 | initial_prompt_templates = [
13 | "Identify the cell type most likely associated with these 100 highly expressed genes listed in descending order.",
14 | "Determine the probable cell type for the following 100 genes with the highest expression levels.",
15 | "Indicate the cell type typically linked to these 100 top-expressed genes.",
16 | "Specify the most likely cell type based on these 100 genes sorted by decreasing expression.",
17 | "Find the cell type that corresponds to these top 100 highly expressed genes.",
18 | "Point out the cell type these 100 genes with peak expression levels most commonly represent.",
19 | "Deduce the cell type likely to have these 100 highly expressed genes.",
20 | "Pinpoint the cell type that these 100 genes with the highest expression levels are most likely associated with.",
21 | "Ascertain the cell type from which these 100 highly expressed genes likely originate.",
22 | "Reveal the likely cell type linked to these 100 genes, listed by decreasing expression levels.",
23 | "Uncover the most probable cell type related to these 100 highly expressed genes.",
24 | "Indicate the cell type that would logically have these 100 top-expressed genes.",
25 | "Provide the likely cell type based on these 100 genes with high expression levels.",
26 | "Isolate the cell type commonly associated with these 100 top genes.",
27 | "Establish the cell type that these 100 genes with the highest expression levels are most likely from.",
28 | "Discern the likely cell type for these 100 genes sorted by expression level.",
29 | "Note the cell type typically associated with these 100 most expressed genes.",
30 | "Report the cell type most probably linked to these 100 genes with peak expression.",
31 | "Conclude the most likely cell type these 100 genes are associated with.",
32 | "State the probable cell type connected to these 100 top-expressed genes.",
33 | "What cell type is most likely represented by these top 100 highly expressed genes?",
34 | "Identify the probable cell type for these 100 genes with the highest expression levels.",
35 | "Which cell type is typically associated with these 100 most expressed genes?",
36 | "Can you deduce the cell type based on this list of 100 highly expressed genes?",
37 | "Given these 100 genes sorted by decreasing expression, what is the likely cell type?",
38 | "Based on these top 100 genes, which cell type are they most commonly found in?",
39 | "What type of cell is most likely to express these 100 genes in decreasing order of expression?",
40 | "What is the probable cell type these 100 highly expressed genes are associated with?",
41 | "From which cell type do these 100 most expressed genes likely originate?",
42 | "Determine the cell type likely associated with these 100 genes listed by decreasing expression.",
43 | "Given these 100 highly expressed genes, can you identify the likely cell type?",
44 | "Infer the cell type based on these 100 genes with the highest expression levels.",
45 | "Which cell type is likely to have these 100 genes with the highest expression?",
46 | "Could you specify the cell type most likely associated with these top 100 genes?",
47 | "What cell type would you associate with these 100 highly expressed genes?",
48 | "Can you tell the likely cell type for these 100 genes, sorted by decreasing expression?",
49 | "What is the likely cell type based on these 100 top expressed genes?",
50 | "Identify the cell type most commonly associated with these 100 genes.",
51 | "Based on these genes listed by decreasing expression, what cell type are they likely from?",
52 | "Given these 100 genes with high expression levels, what is the probable cell type?",
53 | "Which cell type is expected to have these 100 genes with the highest levels of expression?",
54 | "What is the most probable cell type based on these 100 genes with peak expression levels?",
55 | "What cell type would most likely have these 100 top expressed genes?",
56 | "Which cell type most probably corresponds to these 100 highly expressed genes?",
57 | "Could you determine the likely cell type based on these 100 most expressed genes?",
58 | "What type of cell would most likely contain these 100 genes with highest expression?",
59 | "Based on the list of 100 genes, what is the most likely corresponding cell type?",
60 | "Please identify the cell type that these 100 highly expressed genes are most likely linked to.",
61 | "Given these 100 genes ranked by expression, what would be the associated cell type?",
62 | "What would be the probable cell type for these 100 genes, listed by decreasing expression?",
63 | "Can you deduce the most likely cell type for these top 100 highly expressed genes?",
64 | "Identify the likely cell type these 100 genes with top expression could represent.",
65 | "Based on the following 100 genes, can you determine the cell type they are commonly found in?",
66 | "What is the likely originating cell type of these 100 top expressed genes?",
67 | "Specify the cell type most commonly linked with these 100 highly expressed genes.",
68 | "Which cell type would you expect to find these 100 genes with high expression levels?",
69 | "Indicate the probable cell type these 100 genes are commonly associated with.",
70 | "According to these 100 genes with highest expression, what cell type are they most likely from?",
71 | "Which cell type is these 100 genes with the highest expression levels most commonly found in?",
72 | "Could you point out the likely cell type linked with these 100 genes sorted by decreasing expression?",
73 | ## add
74 | "Ascertain which cell type is most closely associated with these 100 genes exhibiting the highest levels of expression.",
75 | "Elucidate the cell type that these 100 genes, ranked by their expression, most closely correlate with.",
76 | "Predict the cell type associated with the highest expression of these 100 genes.",
77 | "Identify the cell lineage that these 100 genes, ordered by expression magnitude, suggest.",
78 | "Decipher the cell type connected to the top 100 genes by expression level.",
79 | "Clarify the cell type that most likely expresses these 100 genes at high levels.",
80 | "Characterize the cell type associated with the highest expression among these 100 genes.",
81 | "Trace the likely cell type for these 100 genes, characterized by their elevated expression levels.",
82 | "Profile the cell type most aligned with the expression patterns of these 100 genes.",
83 | "Outline the cell type that is most probably expressed by these top 100 genes.",
84 | "Summarize the cell type indicative of these 100 genes with the highest expression rankings.",
85 | "Highlight the cell type that these 100 genes, when highly expressed, most likely indicate.",
86 | "Interpret the cell type likely reflected by the top 100 genes according to their expression levels.",
87 | "Sketch the probable cell type that these 100 genes with elevated expression levels delineate.",
88 | "Extrapolate the cell type likely exemplified by these top 100 expressed genes.",
89 | "Map out the cell type that is suggested by the high expression of these 100 genes.",
90 | "Predict the cell type signified by the highest expression levels in these 100 genes.",
91 | "Synthesize the probable cell type from the expression data of these 100 top genes.",
92 | "Derive the cell type most indicative of these 100 genes with the highest expression signatures.",
93 | "Elaborate on the cell type that these 100 genes with top expression levels are suggesting.",
94 | "Formulate the cell type hypothesis based on the expression profile of these 100 genes.",
95 | "Project the cell type that would typically express these 100 genes at high levels.",
96 | "Render the likely cell type associated with the expression pattern of these 100 genes.",
97 | "Dissect the probable cell type based on the high expression of these 100 genes.",
98 | "Propose the cell type that is inferred by the expression data of these 100 top genes.",
99 | "Assess the cell type that these 100 highly expressed genes most plausibly suggest.",
100 | "Conceptualize the cell type that would manifest these 100 genes with their peak expression levels.",
101 | "Analyze the cell type that is most resonant with these 100 genes’ expression profiles.",
102 | "Frame the cell type likely to be identified by the expression patterns of these 100 genes.",
103 | "Articulate the cell type that these 100 genes with the highest expression might typify."
104 | ]
105 |
106 | print(len(cell500))
107 |
108 |
109 | def get_sentence(path_input,path_output):
110 | """
111 | Processes input JSON data to filter and annotate items based on two criteria:
112 | 1. The item must be part of a Pseudo-cell Generation task.
113 | 2. The cell type associated with the item must have appeared at least 500 times in the training set.
114 |
115 | This ensures that only relevant data for Pseudo-cell Generation tasks with sufficiently represented cell types are included.
116 |
117 | Args:
118 | - path_input: Path to the input JSON file.
119 | - path_output: Path for saving the processed output JSON file.
120 | """
121 | with open(path_input, "r", encoding="utf-8") as file:
122 | data = json.load(file)
123 | filtered_data = []
124 | for item in tqdm(data, desc="Processing", unit="data"):
125 |         flag = 0
126 |         for prompt in initial_prompt_templates:
127 |             if item['source'].startswith(prompt):
128 |                 flag = 1
129 |         # If no template matched, the item is a Pseudo-cell Generation task
130 |         if flag == 0:
131 | for cell_type in cell500:
132 | if cell_type in item['source']:
133 | item['cell_type']=cell_type.strip()
134 | item['gene']=item['my_target']
135 | filtered_data.append(item)
136 | break
137 | print(len(filtered_data))
138 | # Output the filtered and annotated data to a new JSON file
139 | with open(path_output, "w", encoding="utf-8") as output_file:
140 | json.dump(filtered_data, output_file, indent=4, ensure_ascii=False)
141 |
142 | # Define input and output file paths
143 | input_path = 'your_sum_test.json'
144 | output_pseudo_cell_path = 'your_Pseudo-cell_test.json'  # hyphens are not valid in Python identifiers
145 |
146 | # Call the function with the specified input and output paths
147 | get_sentence(input_path, output_pseudo_cell_path)
148 |
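149 | # The filtered items keep their generated sentence in 'my_target' (copied to 'gene')
150 | # plus a 'cell_type' label; sentence_to_experssion.py consumes these fields downstream.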
--------------------------------------------------------------------------------
/workflow_data/merge.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 |
5 | def merge_and_shuffle_json(file1_json_file_path, file2_json_file_path, file3_json_file_path):
6 |     """
7 |     Merges and shuffles data from three JSON files and returns the combined list.
8 |
9 |     :param file1_json_file_path: Path to the first JSON file.
10 |     :param file2_json_file_path: Path to the second JSON file.
11 |     :param file3_json_file_path: Path to the third JSON file.
12 |
13 |     :return: Merged and shuffled list of data from all three files.
14 |     """
15 | with open(file1_json_file_path, 'r') as f1, open(file2_json_file_path, 'r') as f2, open(file3_json_file_path, 'r') as f3:
16 | data1 = json.load(f1)
17 | data2 = json.load(f2)
18 | data3 = json.load(f3)
19 |     # Merge the first two datasets and shuffle
20 |     merged_data = data1 + data2
21 |     random.seed(42)
22 |     random.shuffle(merged_data)
23 |     # Prepend the third dataset and shuffle again
24 |     merged_data = data3 + merged_data
25 |     random.seed(42)
26 |     random.shuffle(merged_data)
27 | return merged_data
28 |
29 | # File paths
30 | mouse_train_json_file_path = 'mouse/train.json'
31 | mouse_val_json_file_path = 'mouse/valid.json'
32 | mouse_test_json_file_path = 'mouse/test.json'
33 |
34 | GSE117872_train_json_file_path = 'GSE117872/train.json'
35 | GSE117872_val_json_file_path = 'GSE117872/valid.json'
36 | GSE117872_test_json_file_path = 'GSE117872/test.json'
37 |
38 | GSE149383_train_json_file_path = 'GSE149383/train.json'
39 | GSE149383_val_json_file_path = 'GSE149383/valid.json'
40 | GSE149383_test_json_file_path = 'GSE149383/test.json'
41 | # Output file paths
42 | sum_train_json_file_path = 'sum/train.json'
43 | sum_val_json_file_path = 'sum/valid.json'
44 | sum_test_json_file_path = 'sum/test.json'
45 |
46 | # Merging and shuffling
47 | sum_train_data = merge_and_shuffle_json(GSE117872_train_json_file_path, GSE149383_train_json_file_path, mouse_train_json_file_path)
48 | sum_val_data = merge_and_shuffle_json(GSE117872_val_json_file_path, GSE149383_val_json_file_path, mouse_val_json_file_path)
49 | sum_test_data = merge_and_shuffle_json(GSE117872_test_json_file_path, GSE149383_test_json_file_path, mouse_test_json_file_path)
50 |
51 | # Helper to write the merged and shuffled data to new JSON files
52 | def write_to_file(data, file_path):
53 | """
54 | Writes given data to a file specified by file_path.
55 |
56 | :param data: Data to be written.
57 | :param file_path: Path for the output file.
58 | """
59 | with open(file_path, 'w') as outfile:
60 | json.dump(data, outfile, indent=4)
61 |
62 |
63 | # Writing the merged and shuffled data to files
64 | write_to_file(sum_train_data, sum_train_json_file_path)
65 | write_to_file(sum_val_data, sum_val_json_file_path)
66 | write_to_file(sum_test_data, sum_test_json_file_path)
67 |
--------------------------------------------------------------------------------
/workflow_data/mouse_to_json.py:
--------------------------------------------------------------------------------
1 | import json
2 | from datasets import concatenate_datasets, load_from_disk
3 | from random import choice,seed
4 | from src.prompts import construct_cell_type_template, construct_prediction_template
5 | from tqdm import tqdm
6 |
7 | random_seed = 42 # You can use any integer value as the seed
8 | input_path = 'output_dir_in_transform.py_path'
9 | train_json_file_path = 'mouse/train.json'
10 | val_json_file_path = 'mouse/valid.json'
11 | test_json_file_path = 'mouse/test.json'
12 |
13 | # Load the dataset
14 | dataset = load_from_disk(input_path)
15 | train_dataset, val_dataset, test_dataset = dataset["train"], dataset["valid"], dataset["test"]
16 |
17 | print(dataset)
18 | seed(random_seed)
19 |
20 |
21 | def preprocess_function(examples):
22 | """Preprocess the dataset to generate source and target texts.
23 |
24 | Args:
25 | examples (Dataset): The slice of the dataset to process.
26 |
27 | Returns:
28 | list[dict]: A list of processed data containing source and target texts.
29 | """
30 | text_column = "cell_type"
31 | label_column = "input_ids"
32 | max_length = 1024
33 |
34 | batch_size = len(examples[text_column])
35 | inputs = []
36 | targets = []
37 | for i in tqdm(range(batch_size)):
38 |         prompt_type = choice([0, 1, 2])  # 0: typed pseudo-cell generation, 1: untyped cell generation, 2: cell type prediction
39 | if prompt_type == 0:
40 | input = construct_cell_type_template(examples["cell_type"][i])
41 | target = " ".join(examples["input_ids"][i].split(" ")[:100])
42 | elif prompt_type == 1:
43 | input = construct_cell_type_template("cell")
44 | target = " ".join(examples["input_ids"][i].split(" ")[:100])
45 | else:
46 | input = construct_prediction_template(
47 | " ".join(examples["input_ids"][i].split(" ")[:100])
48 | )
49 | target = examples["cell_type"][i]
50 |
51 | inputs.append(input)
52 | targets.append(target)
53 |
54 | data_list = []
55 | for input, target in zip(inputs, targets):
56 | data_list.append({"source": input, "target": target})
57 | return data_list
58 |
59 | def save_to_json(data, file_path):
60 | """Save data to a JSON file.
61 |
62 | Args:
63 | data (list): The data to be saved.
64 | file_path (str): The path where the JSON file will be saved.
65 | """
66 | with open(file_path, 'w', encoding='utf-8') as json_file:
67 | json.dump(data, json_file, ensure_ascii=False, indent=4)
68 |
69 | # Process and save datasets
70 | for dataset, path in zip([train_dataset, val_dataset, test_dataset], [train_json_file_path, val_json_file_path, test_json_file_path]):
71 | preprocessed_data = preprocess_function(dataset)
72 | save_to_json(preprocessed_data, path)
73 |
74 |
--------------------------------------------------------------------------------
/workflow_data/sentence_to_experssion.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import anndata
3 | import json
4 | from tqdm import tqdm
5 | import scanpy as sc
6 | import sys
7 | import pandas as pd
8 |
9 | from src.utils import post_process_generated_cell_sentences, convert_cell_sentence_back_to_expression_vector
10 | def get_vocab(path):
11 | """
12 | Load and process the gene vocabulary from a specified file.
13 |
14 | Parameters:
15 | - path: Path to the vocabulary file.
16 |
17 | Returns:
18 | - List of unique, uppercase gene names.
19 | """
20 | global_vocab = set()
21 | with open(path, "r") as fp:
22 | for line in fp:
23 | line = line.rstrip() # remove end whitespace, e.g. newline
24 | line_elements = line.split(" ")
25 | gene_name = line_elements[0]
26 | global_vocab.add(gene_name)
27 |
28 | global_vocab_list = list(global_vocab)
29 | global_vocab_list = [gene_name.upper() for gene_name in global_vocab_list]
30 | return global_vocab_list
31 |
32 |
33 |
34 | def reconstruct(path_input,path_output):
35 | """
36 | Reconstruct cell expression data from sentences and save to an AnnData object.
37 |
38 | Parameters:
39 | - path_input: Path to the input JSON file containing cell sentences.
40 | - path_output: Path to the output file where the AnnData object will be saved.
41 | """
42 | with open(path_input, "r", encoding="utf-8") as file:
43 | data = json.load(file)
44 |
45 |
46 | all_cell_sentences_converted_back_to_expression = []
47 | # Load gene vocabulary
48 |
49 | global_vocab_list=get_vocab('vocab_human.txt')
50 | #The text file named vocab_human.txt is typically located within the cell_sentences subdirectory of the output_dir, as specified by the transform.py script.
51 |
52 | # Load transformation parameters
53 | dataset_df = pd.read_csv("transformation_metrics_and_parameters.csv")
54 | #The CSV file named transformation_metrics_and_parameters.csv is typically located within the eval_output_dir, as specified by the transform.py script.
55 |
56 | slope = dataset_df.iloc[0, 2].item()
57 | intercept = dataset_df.iloc[0, 3].item()
58 | print(f"slope: {slope:.4f}, intercept: {intercept:.4f}")
59 |
60 |     for cell_item in tqdm(data):
61 |         cell_sentence = cell_item["my_target"]
62 |         words = cell_sentence.split()
63 |         cell_sentence_str = " ".join(words)  # normalize whitespace in the generated sentence
64 | post_processed_sentence, num_genes_replaced = post_process_generated_cell_sentences(
65 | cell_sentence=cell_sentence_str,
66 | global_dictionary=global_vocab_list,
67 | replace_nonsense_string="NOT_A_GENE",
68 | )
69 |         post_processed_sentence = post_processed_sentence[:100]  # truncate to the first 100 entries
70 | reconstructed_expr_vec = convert_cell_sentence_back_to_expression_vector(
71 | cell_sentence=post_processed_sentence,
72 | global_dictionary=global_vocab_list,
73 | slope=slope,
74 | intercept=intercept
75 | )
76 | all_cell_sentences_converted_back_to_expression.append(reconstructed_expr_vec)
77 |     all_cell_sentences_converted_back_to_expression = np.stack(all_cell_sentences_converted_back_to_expression, dtype=np.float32)  # the dtype kwarg requires NumPy >= 1.24
78 |     print(all_cell_sentences_converted_back_to_expression.shape)
79 | reconstructed_adata = sc.AnnData(X=all_cell_sentences_converted_back_to_expression)
80 |
81 | if 'cell_type' not in reconstructed_adata.obs.columns:
82 | reconstructed_adata.obs['cell_type'] = ''
83 | reconstructed_adata.var.index = global_vocab_list
84 |
85 |     # Assign cell types in one vectorized step (avoids pandas chained-assignment warnings)
86 |     reconstructed_adata.obs["cell_type"] = [cell_data["cell_type"] for cell_data in data]
87 |
88 |
89 | reconstructed_adata.write_h5ad(path_output)
90 | json_path = 'your_json_path.json'
91 | h5ad_path = "your_h5ad_path.h5ad"
92 | reconstruct(json_path, h5ad_path)
93 |
94 |
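95 | # Note: each item in the input JSON is expected to provide the generated cell sentence
96 | # in "my_target" and a "cell_type" annotation (see extract_gene_generation.py).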
--------------------------------------------------------------------------------
/workflow_data/split.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | # Set a fixed seed for reproducibility
5 | random.seed(42)
6 |
7 | # Load the JSON data
8 | input_path='GSE117872.json'
9 | train_json_file_path='GSE117872/train.json'
10 | val_json_file_path='GSE117872/valid.json'
11 | test_json_file_path='GSE117872/test.json'
12 | with open(input_path, 'r') as file:
13 | data = json.load(file)
14 |
15 | # Shuffle the data so the splits are randomly distributed
16 | random.shuffle(data)
17 |
18 | # Calculate the split indices for training, validation, and test sets
19 | train_size = int(0.8 * len(data)) # 80% for training
20 | val_size = int(0.1 * len(data)) # 10% for validation
21 |
22 | # Split the data into training, validation, and test sets
23 | train_data = data[:train_size]
24 | val_data = data[train_size:train_size + val_size]
25 | test_data = data[train_size + val_size:]
26 |
27 | # Print the sizes of the full dataset and each split to verify
28 | print(f"Total: {len(data)}, Train: {len(train_data)}, Validation: {len(val_data)}, Test: {len(test_data)}")
29 |
30 | # Function to save datasets back to JSON files with indentation for readability
31 | def save_to_json(file_name, data):
32 | with open(file_name, 'w') as file:
33 | json.dump(data, file, indent=4)
34 |
35 | # Save each dataset to its respective JSON file
36 | save_to_json(train_json_file_path, train_data)
37 | save_to_json(val_json_file_path, val_data)
38 | save_to_json(test_json_file_path, test_data)
39 |
--------------------------------------------------------------------------------
/workflow_data/src/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This file was initially developed by the project at https://github.com/vandijklab/cell2sentence-ft.
3 | Many thanks for their contributions to this field. It adheres to the Attribution-NonCommercial-ShareAlike
4 | 4.0 International License.
5 |
6 | If you use this file, please cite the papers "Levine et al., Cell2Sentence: Teaching Large Language
7 | Models the Language of Biology. 2023 (https://www.biorxiv.org/content/10.1101/2023.09.11.557287v3)" and
8 | "Rahul M Dhodapkar. Representing cells as sentences enables natural-language processing for single-cell
9 | transcriptomics. 2022 (https://www.biorxiv.org/content/10.1101/2022.09.18.508438)."
10 | """
11 |
12 | __version__ = "0.0.1"
13 | from src.csdata import CSData
14 |
--------------------------------------------------------------------------------
/workflow_data/src/csdata.py:
--------------------------------------------------------------------------------
1 | """
2 | This file was initially developed by the project at https://github.com/vandijklab/cell2sentence-ft.
3 | Many thanks for their contributions to this field. It adheres to the Attribution-NonCommercial-ShareAlike
4 | 4.0 International License.
5 |
6 | If you use this file, please cite the papers "Levine et al., Cell2Sentence: Teaching Large Language
7 | Models the Language of Biology. 2023 (https://www.biorxiv.org/content/10.1101/2023.09.11.557287v3)" and
8 | "Rahul M Dhodapkar. Representing cells as sentences enables natural-language processing for single-cell
9 | transcriptomics. 2022 (https://www.biorxiv.org/content/10.1101/2022.09.18.508438)."
10 | """
11 |
12 | import zlib
13 |
14 | import igraph as ig
15 | import jellyfish
16 | import numpy as np
17 | import pandas as pd
18 | from scipy import stats
19 | from sklearn import model_selection
20 |
21 |
22 | def zlib_ncd(s1, s2):
23 | """
24 | Return the zlib normalized compression distance between two strings
25 | """
26 | bs1 = bytes(s1, "utf-8")
27 | bs2 = bytes(s2, "utf-8")
28 |
29 | comp_cat = zlib.compress(bs1 + bs2)
30 | comp_bs1 = zlib.compress(bs1)
31 | comp_bs2 = zlib.compress(bs2)
32 |
33 | return (len(comp_cat) - min(len(comp_bs1), len(comp_bs2))) / max(
34 | len(comp_bs1), len(comp_bs2)
35 | )
36 |
37 |
38 | class CSData:
39 | """
40 | Lightweight wrapper class to wrap cell2sentence results.
41 | """
42 |
43 | def __init__(self, vocab, sentences, cell_names, feature_names):
44 | self.vocab = vocab # Ordered Dictionary: {gene_name: num_expressed_cells}
45 | self.sentences = sentences # list of sentences
46 | self.cell_names = cell_names # list of cell names
47 | self.feature_names = feature_names # list of gene names
48 | self.distance_matrix = None
49 | self.distance_params = None
50 | self.knn_graph = None
51 |
52 | def create_distance_matrix(self, dist_type="jaro", prefix_len=20):
53 |         """
54 |         Calculate the distance matrix for the CSData object with the specified
55 |         edit distance method. Currently supported: "levenshtein",
56 |         "damerau_levenshtein", "jaro", "jaro_winkler", and "zlib_ncd".
57 |         Similarity-based metrics (jaro, jaro_winkler) are converted to distances as d = 1 - s.
58 |         """
59 | if self.distance_matrix is not None and (
60 | self.distance_params["dist_type"] == dist_type
61 | and self.distance_params["prefix_len"] == prefix_len
62 | ):
63 | return self.distance_matrix
64 |
65 | dist_funcs = {
66 | "levenshtein": jellyfish.levenshtein_distance,
67 | "damerau_levenshtein": jellyfish.damerau_levenshtein_distance,
68 | "jaro": lambda x, y: 1 - jellyfish.jaro_similarity(x, y), # NOQA
69 | "jaro_winkler": lambda x, y: 1
70 | - jellyfish.jaro_winkler_similarity(x, y), # NOQA
71 | "zlib_ncd": zlib_ncd,
72 | }
73 |
74 | is_symmetric = {
75 | "levenshtein": True,
76 | "damerau_levenshtein": True,
77 | "jaro": True,
78 | "jaro_winkler": True,
79 | "zlib_ncd": False,
80 | }
81 |
82 | mat = np.zeros(shape=(len(self.sentences), len(self.sentences)))
83 |
84 | for i, s_i in enumerate(self.sentences):
85 | for j, s_j in enumerate(self.sentences):
86 | if j < i and is_symmetric[dist_type]:
87 | mat[i, j] = mat[j, i]
88 | continue
89 |
90 | mat[i, j] = dist_funcs[dist_type](s_i[:prefix_len], s_j[:prefix_len])
91 |
92 | self.distance_params = {"dist_type": dist_type, "prefix_len": prefix_len}
93 | self.distance_matrix = mat
94 | # reset KNN graph if previously computed on old distance
95 | self.knn_graph = None
96 |
97 | return self.distance_matrix
98 |
99 | def create_knn_graph(self, k=15):
100 | """
101 | Create KNN graph
102 | """
103 | if self.distance_matrix is None:
104 | raise RuntimeError(
105 | 'cannot "build_knn_graph" without running "create_distance_matrix" first'
106 | )
107 |
108 | adj_matrix = 1 / (1 + self.distance_matrix)
109 | knn_mask = np.zeros(shape=adj_matrix.shape)
110 |
111 | for i in range(adj_matrix.shape[0]):
112 | for j in np.argsort(-adj_matrix[i])[:k]:
113 | knn_mask[i, j] = 1
114 |
115 | masked_adj_matrix = knn_mask * adj_matrix
116 |
117 | self.knn_graph = ig.Graph.Weighted_Adjacency(masked_adj_matrix).as_undirected()
118 | return self.knn_graph
119 |
120 | def create_rank_matrix(self):
121 | """
122 | Generates a per-cell rank matrix for use with matrix-based tools. Features with zero
123 | expression are zero, while remaining features are ranked according to distance from
124 | the end of the rank list.
125 | """
126 | full_rank_matrix = np.zeros((len(self.cell_names), len(self.feature_names)))
127 |
128 |         for i, s in enumerate(self.sentences):
129 | for rank_position, c in enumerate(s):
130 | full_rank_matrix[i, ord(c)] = len(s) - rank_position
131 |
132 | return full_rank_matrix
133 |
134 | def find_differential_features(self, ident_1, ident_2=None, min_pct=0.1):
135 | """
136 | Perform differential feature rank testing given a set of sentence indexes.
137 | If only one group is given, the remaining sentences are automatically used
138 | as the comparator group.
139 | """
140 |
141 | if ident_2 is None:
142 | ident_2 = list(set(range(len(self.sentences))).difference(set(ident_1)))
143 |
144 | full_rank_matrix = self.create_rank_matrix()
145 | feature_ixs_to_test = np.array(
146 | np.sum(full_rank_matrix > 0, axis=0) > min_pct * len(self.cell_names)
147 | ).nonzero()[0]
148 |
149 | stats_results = []
150 | for f in feature_ixs_to_test:
151 | wilcox_stat, pval = stats.ranksums(
152 | x=full_rank_matrix[ident_1, f], y=full_rank_matrix[ident_2, f]
153 | )
154 | stats_results.append(
155 | {
156 | "feature": self.feature_names[f],
157 | "w_stat": wilcox_stat,
158 | "p_val": pval,
159 | "mean_rank_group_1": np.mean(full_rank_matrix[ident_1, f]),
160 | "mean_rank_group_2": np.mean(full_rank_matrix[ident_2, f]),
161 | }
162 | )
163 | return pd.DataFrame(stats_results)
164 |
165 | def get_rank_data_for_feature(self, feature_name, invert=False):
166 | """
167 |         Return an array of ranks corresponding to the presence of a gene within
168 | each cell sentence. If a gene is not present in a cell sentence, np.nan
169 | is returned for that cell.
170 |
171 | Note that this returns rank (1-indexed), not position within the underlying
172 | gene rank list string (0-indexed).
173 | """
174 | feature_code = -1
175 | for i, k in enumerate(self.vocab.keys()):
176 | if k == feature_name:
177 | feature_code = i
178 | break
179 |
180 | if feature_code == -1:
181 | raise ValueError(
182 | "invalid feature {} not found in vocabulary".format(feature_name)
183 | )
184 | feature_enc = chr(feature_code)
185 |
186 | rank_data_vec = np.full((len(self.cell_names)), np.nan)
187 | for i, s in enumerate(self.sentences):
188 | ft_loc = s.find(feature_enc)
189 | if invert:
190 | rank_data_vec[i] = len(s) - ft_loc if ft_loc != -1 else np.nan
191 | else:
192 | rank_data_vec[i] = ft_loc + 1 if ft_loc != -1 else np.nan
193 |
194 | return rank_data_vec
195 |
196 | def create_sentence_strings(self, delimiter=" "):
197 | """
198 | Convert internal sentence representation (arrays of ints) to traditional
199 | delimited character strings for integration with text-processing utilities.
200 | """
201 | if np.any([delimiter in x for x in self.feature_names]):
202 | raise ValueError(
203 | (
204 | 'feature names cannot contain sentence delimiter "{}", '
205 | + "please re-format and try again"
206 | ).format(delimiter)
207 | )
208 |
209 | enc_map = list(self.vocab.keys())
210 |
211 | joined_sentences = []
212 | for s in self.sentences:
213 | joined_sentences.append(delimiter.join([enc_map[ord(x)] for x in s]))
214 |
215 | return np.array(joined_sentences, dtype=object)
216 |
217 | def create_sentence_lists(self):
218 | """
219 | Convert internal sentence representation (arrays of ints) to
220 | sentence lists compatible with gensim
221 | """
222 | enc_map = list(self.vocab.keys())
223 |
224 | joined_sentences = []
225 | for s in self.sentences:
226 | joined_sentences.append([enc_map[ord(x)] for x in s])
227 |
228 | return np.array(joined_sentences, dtype=object)
229 |
230 | def train_test_validation_split(
231 | self, train_pct=0.8, test_pct=0.1, val_pct=0.1, random_state=42
232 | ):
233 | """
234 | Create train, test, and validation splits of the data given the supplied
235 | percentages with a specified random state for reproducibility.
236 |
237 |         Arguments:
238 |             train_pct: Default = 0.8. The fraction of samples to assign to the training set.
239 |             test_pct: Default = 0.1. The fraction of samples to assign to the test set.
240 |             val_pct: Default = 0.1. The fraction of samples to assign to the validation set.
241 |             random_state: Seed for reproducible splitting.
242 |         Return:
243 |             (train_indices, test_indices, val_indices): index lists partitioning
244 |             this object's sentences.
245 |         """
246 | if train_pct + test_pct + val_pct != 1:
247 | raise ValueError(
248 | "train_pct = {} + test_pct = {} + val_pct = {} do not sum to 1.".format(
249 | train_pct, test_pct, val_pct
250 | )
251 | )
252 |
253 | s_1 = test_pct
254 | s_2 = val_pct / (1 - test_pct)
255 |
256 | X = range(len(self.sentences))
257 | X_train, X_test = model_selection.train_test_split(
258 | X, test_size=s_1, random_state=random_state
259 | )
260 |
261 | X_train, X_val = model_selection.train_test_split(
262 | X_train, test_size=s_2, random_state=random_state
263 | )
264 |
265 | return (X_train, X_test, X_val)
266 |
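267 | # Minimal usage sketch (illustrative variable names, not part of the original file):
268 | #   csd = CSData(vocab, sentences, cell_names, feature_names)
269 | #   dist = csd.create_distance_matrix(dist_type="levenshtein", prefix_len=20)
270 | #   knn = csd.create_knn_graph(k=15)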
--------------------------------------------------------------------------------
/workflow_data/src/prompts.py:
--------------------------------------------------------------------------------
1 | """
2 | This file was initially developed by the project at https://github.com/vandijklab/cell2sentence-ft.
3 | Many thanks for their contributions to this field. It adheres to the Attribution-NonCommercial-ShareAlike
4 | 4.0 International License.
5 |
6 | If you use this file, please cite the papers "Levine et al., Cell2Sentence: Teaching Large Language
7 | Models the Language of Biology. 2023 (https://www.biorxiv.org/content/10.1101/2023.09.11.557287v3)" and
8 | "Rahul M Dhodapkar. Representing cells as sentences enables natural-language processing for single-cell
9 | transcriptomics. 2022 (https://www.biorxiv.org/content/10.1101/2022.09.18.508438)."
10 | """
11 | import random
12 |
13 |
14 | def construct_cell_type_template(cell_type):
15 | vowels = {"a", "e", "i", "o", "u", "A", "E", "I", "O", "U"}
16 |     if cell_type[0] in vowels:
17 |         cell_type = f"n {cell_type}"  # templates embed "a{}", so this yields "an <cell type>"
18 |     else:
19 |         cell_type = f" {cell_type}"  # and this yields "a <cell type>"
20 |
21 | cell_type_templates = [
22 | "List the initial 100 genes associated with a{}: ",
23 | "Provide the top 100 genes for a{}: ",
24 | "Identify 100 genes corresponding to a{}: ",
25 | "Catalog the 100 primary genes for a{}: ",
26 | "Name the first set of 100 genes in a{}: ",
27 | "Outline 100 genes initially associated with a{}: ",
28 | "Specify the first 100 genes linked to a{}: ",
29 | "Show the leading 100 genes of a{}: ",
30 | "Enumerate the 100 starting genes for a{}: ",
31 | "List the 100 most highly expressed genes in a{}, ordered from highest to lowest expression: ",
32 | "Provide a ranking of the top 100 expressed genes in a{} by decreasing levels: ",
33 | "Detail the 100 genes with the greatest expression levels in a{}: ",
34 | "Identify the 100 genes in a{} with the highest expression, sorted in descending order: ",
35 | "Catalog the 100 most active genes in a{} cell, ordered by decreasing expression: ",
36 | "Name the 100 genes with peak expression levels in a{}, from highest to lowest: ",
37 | "Describe the 100 most abundantly expressed genes in a{}, in descending order: ",
38 | "Which 100 genes have the highest expression in a{}? List them in descending order: ",
39 | "Show the 100 most significantly expressed genes in a{}, organized by decreasing levels: ",
40 | "Enumerate the 100 genes with the strongest expression in a{}, ranked from highest to lowest: ",
41 | "List the top 100 genes by decreasing expression specifically for a{}: ",
42 | "Provide the first 100 highly expressed genes in a{}, sorted by expression level: ",
43 | "Identify the initial 100 genes with peak expression levels in a{}: ",
44 | "Catalog the top 100 expressed genes corresponding to a{}, from highest to lowest: ",
45 | "What are the first 100 genes with the greatest expression in a{}? ",
46 | "Name the top 100 genes sorted by decreasing expression for a{}: ",
47 | "Show the leading 100 genes by expression level associated with a{}: ",
48 | "Enumerate the first 100 genes by highest expression levels in a{}: ",
49 | "Specify the 100 primary genes in a{} ordered by decreasing expression: ",
50 | "Outline the initial set of 100 genes with high expression in a{}: ",
51 | "Detail the 100 starting genes for a{}, ranked by expression level: ",
52 | "Rank the first 100 genes by expression level found in a{}: ",
53 | "Describe the 100 most active genes initially associated with a{}: ",
54 | "Which are the first 100 genes in a{} sorted by decreasing expression? ",
55 | "Provide a ranking of the initial 100 genes by decreasing expression levels in a{}: ",
56 | "List 100 primary genes from highest to lowest expression in a{}: ",
57 | "Catalog the 100 initial genes for a{}, ordered by expression level from highest to lowest: ",
58 | "Identify the 100 leading genes in a{} based on expression levels: ",
59 | "Show the 100 primary genes for a{}, sorted by decreasing expression: ",
60 | "Enumerate the 100 most abundantly expressed starting genes in a{}: ",
61 | ## add
62 | "Reveal the first 100 genes with the highest expression in a{}, sorted in order of decreasing levels: ",
63 | "Compile a list of the top 100 genes by expression in a{}, descending from highest to lowest: ",
64 | "Pinpoint the 100 leading genes based on expression levels in a{}: ",
65 | "Disclose the initial 100 genes showing the highest expression in a{}, in a descending sequence: ",
66 | "Present the foremost 100 genes by expression in a{}, ranked from the highest to the lowest: ",
67 | "Highlight the 100 genes with the utmost expression in a{}, arranged by diminishing levels: ",
68 | "Unveil the top 100 genes characterized by their expression in a{}, in descending order: ",
69 | "Report on the first 100 genes with maximum expression in a{}, listed from highest to lowest: ",
70 | "Profile the 100 genes leading in expression in a{}, ordered by decreasing levels: ",
71 | "Summarize the primary 100 genes by expression level in a{}, from the highest to the lowest: ",
72 | "Expose the 100 genes at the pinnacle of expression in a{}, sequenced by decreasing order: ",
73 | "Discern the top 100 genes for expression in a{}, aligned from highest to lowest: ",
74 | "Uncover the first 100 genes exhibiting the highest levels of expression in a{}, organized by descending order: ",
75 | "Depict the leading 100 genes in terms of expression in a{}, ranked by decreasing levels: ",
76 | "Characterize the 100 most prominently expressed genes in a{}, in a downward sequence: ",
77 | "Account for the 100 genes with superior expression in a{}, sorted from highest to lowest: ",
78 | "Render the 100 most expressed genes in a{}, categorized by descending expression levels: ",
79 | "Illustrate the 100 top-expressed genes in a{}, ordered by decreasing levels of expression: ",
80 | "Convey the initial 100 genes with the highest level of expression in a{}, from highest to lowest: ",
81 | "Sketch the primary 100 genes based on their expression in a{}, descending from highest expression: ",
82 | "Clarify the 100 genes leading in expression within a{}, sorted in decreasing order of expression: ",
83 | "Communicate the first 100 genes ranked by expression in a{}, from the highest to the lowest levels: ",
84 | "Indicate the 100 most prominently featured genes in a{} based on expression, in reverse order: ",
85 | "Relate the top 100 genes according to expression in a{}, descending by level: ",
86 | "Itemize the 100 genes with the leading expression figures in a{}, ordered from the highest downward: ".
87 | ]
88 |
89 | selected_template = random.choice(cell_type_templates)
90 |
91 | formatted_template = selected_template.format(cell_type)
92 |
93 | return formatted_template
94 |
95 |
96 | def construct_prediction_template(genes):
97 | initial_prompt_templates = [
98 | "Identify the cell type most likely associated with these 100 highly expressed genes listed in descending order.",
99 | "Determine the probable cell type for the following 100 genes with the highest expression levels.",
100 | "Indicate the cell type typically linked to these 100 top-expressed genes.",
101 | "Specify the most likely cell type based on these 100 genes sorted by decreasing expression.",
102 | "Find the cell type that corresponds to these top 100 highly expressed genes.",
103 | "Point out the cell type these 100 genes with peak expression levels most commonly represent.",
104 | "Deduce the cell type likely to have these 100 highly expressed genes.",
105 | "Pinpoint the cell type that these 100 genes with the highest expression levels are most likely associated with.",
106 | "Ascertain the cell type from which these 100 highly expressed genes likely originate.",
107 | "Reveal the likely cell type linked to these 100 genes, listed by decreasing expression levels.",
108 | "Uncover the most probable cell type related to these 100 highly expressed genes.",
109 | "Indicate the cell type that would logically have these 100 top-expressed genes.",
110 | "Provide the likely cell type based on these 100 genes with high expression levels.",
111 | "Isolate the cell type commonly associated with these 100 top genes.",
112 | "Establish the cell type that these 100 genes with the highest expression levels are most likely from.",
113 | "Discern the likely cell type for these 100 genes sorted by expression level.",
114 | "Note the cell type typically associated with these 100 most expressed genes.",
115 | "Report the cell type most probably linked to these 100 genes with peak expression.",
116 | "Conclude the most likely cell type these 100 genes are associated with.",
117 | "State the probable cell type connected to these 100 top-expressed genes.",
118 | "What cell type is most likely represented by these top 100 highly expressed genes?",
119 | "Identify the probable cell type for these 100 genes with the highest expression levels.",
120 | "Which cell type is typically associated with these 100 most expressed genes?",
121 | "Can you deduce the cell type based on this list of 100 highly expressed genes?",
122 | "Given these 100 genes sorted by decreasing expression, what is the likely cell type?",
123 | "Based on these top 100 genes, which cell type are they most commonly found in?",
124 | "What type of cell is most likely to express these 100 genes in decreasing order of expression?",
125 | "What is the probable cell type these 100 highly expressed genes are associated with?",
126 | "From which cell type do these 100 most expressed genes likely originate?",
127 | "Determine the cell type likely associated with these 100 genes listed by decreasing expression.",
128 | "Given these 100 highly expressed genes, can you identify the likely cell type?",
129 | "Infer the cell type based on these 100 genes with the highest expression levels.",
130 | "Which cell type is likely to have these 100 genes with the highest expression?",
131 | "Could you specify the cell type most likely associated with these top 100 genes?",
132 | "What cell type would you associate with these 100 highly expressed genes?",
133 | "Can you tell the likely cell type for these 100 genes, sorted by decreasing expression?",
134 | "What is the likely cell type based on these 100 top expressed genes?",
135 | "Identify the cell type most commonly associated with these 100 genes.",
136 | "Based on these genes listed by decreasing expression, what cell type are they likely from?",
137 | "Given these 100 genes with high expression levels, what is the probable cell type?",
138 | "Which cell type is expected to have these 100 genes with the highest levels of expression?",
139 | "What is the most probable cell type based on these 100 genes with peak expression levels?",
140 | "What cell type would most likely have these 100 top expressed genes?",
141 | "Which cell type most probably corresponds to these 100 highly expressed genes?",
142 | "Could you determine the likely cell type based on these 100 most expressed genes?",
143 | "What type of cell would most likely contain these 100 genes with highest expression?",
144 | "Based on the list of 100 genes, what is the most likely corresponding cell type?",
145 | "Please identify the cell type that these 100 highly expressed genes are most likely linked to.",
146 | "Given these 100 genes ranked by expression, what would be the associated cell type?",
147 | "What would be the probable cell type for these 100 genes, listed by decreasing expression?",
148 | "Can you deduce the most likely cell type for these top 100 highly expressed genes?",
149 | "Identify the likely cell type these 100 genes with top expression could represent.",
150 | "Based on the following 100 genes, can you determine the cell type they are commonly found in?",
151 | "What is the likely originating cell type of these 100 top expressed genes?",
152 | "Specify the cell type most commonly linked with these 100 highly expressed genes.",
153 | "Which cell type would you expect to find these 100 genes with high expression levels?",
154 | "Indicate the probable cell type these 100 genes are commonly associated with.",
155 | "According to these 100 genes with highest expression, what cell type are they most likely from?",
156 | "Which cell type is these 100 genes with the highest expression levels most commonly found in?",
157 | "Could you point out the likely cell type linked with these 100 genes sorted by decreasing expression?",
158 | ## add
159 | "Ascertain which cell type is most closely associated with these 100 genes exhibiting the highest levels of expression.",
160 | "Elucidate the cell type that these 100 genes, ranked by their expression, most closely correlate with.",
161 | "Predict the cell type associated with the highest expression of these 100 genes.",
162 | "Identify the cell lineage that these 100 genes, ordered by expression magnitude, suggest.",
163 | "Decipher the cell type connected to the top 100 genes by expression level.",
164 | "Clarify the cell type that most likely expresses these 100 genes at high levels.",
165 | "Characterize the cell type associated with the highest expression among these 100 genes.",
166 | "Trace the likely cell type for these 100 genes, characterized by their elevated expression levels.",
167 | "Profile the cell type most aligned with the expression patterns of these 100 genes.",
168 | "Outline the cell type that is most probably expressed by these top 100 genes.",
169 | "Summarize the cell type indicative of these 100 genes with the highest expression rankings.",
170 | "Highlight the cell type that these 100 genes, when highly expressed, most likely indicate.",
171 | "Interpret the cell type likely reflected by the top 100 genes according to their expression levels.",
172 | "Sketch the probable cell type that these 100 genes with elevated expression levels delineate.",
173 | "Extrapolate the cell type likely exemplified by these top 100 expressed genes.",
174 | "Map out the cell type that is suggested by the high expression of these 100 genes.",
175 | "Predict the cell type signified by the highest expression levels in these 100 genes.",
176 | "Synthesize the probable cell type from the expression data of these 100 top genes.",
177 | "Derive the cell type most indicative of these 100 genes with the highest expression signatures.",
178 | "Elaborate on the cell type that these 100 genes with top expression levels are suggesting.",
179 | "Formulate the cell type hypothesis based on the expression profile of these 100 genes.",
180 | "Project the cell type that would typically express these 100 genes at high levels.",
181 | "Render the likely cell type associated with the expression pattern of these 100 genes.",
182 | "Dissect the probable cell type based on the high expression of these 100 genes.",
183 | "Propose the cell type that is inferred by the expression data of these 100 top genes.",
184 | "Assess the cell type that these 100 highly expressed genes most plausibly suggest.",
185 | "Conceptualize the cell type that would manifest these 100 genes with their peak expression levels.",
186 | "Analyze the cell type that is most resonant with these 100 genes’ expression profiles.",
187 | "Frame the cell type likely to be identified by the expression patterns of these 100 genes.",
188 | "Articulate the cell type that these 100 genes with the highest expression might typify."
189 |
190 | ]
191 |
192 | prediction_templates = [
193 | "This is the cell type corresponding to these genes: ",
194 | "These genes are most likely associated with the following cell type: ",
195 | "This is the probable cell type for these genes: ",
196 | "Based on these genes, the corresponding cell type is: ",
197 | "These genes suggest the cell type is most likely: ",
198 | "These genes are indicative of the cell type: ",
199 | "The associated cell type for these genes appears to be: ",
200 | "These genes typically correspond to: ",
201 | "The expected cell type based on these genes is: ",
202 | "These genes are commonly found in: ",
203 | "The cell type that these genes are most commonly linked with is: ",
204 | "Based on the expression levels, the cell type would likely be: ",
205 | "The genes provided are most commonly expressed in: ",
206 | "Given these genes, the likely corresponding cell type is: ",
207 | "The cell type these genes most likely originate from is: ",
208 | "These genes are most frequently associated with the cell type: ",
209 | "From these genes, it can be inferred that the cell type is: ",
210 | "The cell type best represented by these genes is: ",
211 | ## add
212 | "The cell type that typically exhibits these genes is: ",
213 | "The genes in question are characteristic of the following cell type: ",
214 | "Considering the genetic markers, the cell type is likely: ",
215 | "These genetic sequences suggest a strong association with the cell type: ",
216 | "The cellular identity associated with these genes is most likely: ",
217 | "Analysis of these genes points to the cell type being: ",
218 | "The genetic profile suggests the cell type as: ",
219 | "Identifying the cell type, these genes are predominantly linked to: ",
220 | "The predominant cell type for these gene expressions is: ",
221 | "Correlating these genes, we deduce the cell type to be: ",
222 | "Given the genetic evidence, the corresponding cell type is inferred as: ",
223 | "The genetic markers indicate the cell type is: ",
224 | "From the genetic signatures, the associated cell type is: ",
225 | "Linking these genes to a cell type leads us to conclude: ",
226 | "The cell type, as suggested by these genes, is likely: ",
227 | "Drawing from these genes, the cell type is identified as: ",
228 | "The gene analysis suggests a cell type affiliation with: ",
229 | "These genes delineate the cell type as: "
230 | ]
231 |
232 | # build prompt
233 |
234 | selected_initial_template = random.choice(initial_prompt_templates)
235 | selected_prediction_template = random.choice(prediction_templates)
236 |
237 | formatted_template = (
238 | selected_initial_template + " " + genes + " " + selected_prediction_template
239 | )
240 |
241 | return formatted_template
242 |
--------------------------------------------------------------------------------
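
A minimal usage sketch for the prompt builders in prompts.py above. The gene list
and the import path are illustrative assumptions (mirroring the src package layout
imported elsewhere in workflow_data), not code from this repository:

    # Hypothetical usage; the gene string is made up and would normally hold
    # 100 gene symbols ordered by decreasing expression.
    import random
    from src.prompts import construct_prediction_template

    random.seed(0)  # template selection is random
    genes = "CD3D CD3E TRAC IL7R LTB"
    prompt = construct_prediction_template(genes)
    print(prompt)
    # -> "<random initial template> CD3D CD3E TRAC IL7R LTB <random prediction template> "
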
/workflow_data/src/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This file was initially developed by the project at https://github.com/vandijklab/cell2sentence-ft.
3 | Many thanks for their contributions to this field. It adheres to the Attribution-NonCommercial-ShareAlike
4 | 4.0 International License.
5 |
6 | If you use this file, please cite the papers "Levine et al., Cell2Sentence: Teaching Large Language
7 | Models the Language of Biology. 2023 (https://www.biorxiv.org/content/10.1101/2023.09.11.557287v3)" and
8 | "Rahul M Dhodapkar. Representing cells as sentences enables natural-language processing for single-cell
9 | transcriptomics. 2022 (https://www.biorxiv.org/content/10.1101/2022.09.18.508438)."
10 | """
11 | import os
12 | import sys
13 | from collections import OrderedDict
14 | from pathlib import Path
15 | from typing import List
16 | from collections import Counter
17 |
18 | import numpy as np
19 | from scipy import sparse
20 | from sklearn.utils import shuffle
21 | from tqdm import tqdm
22 |
23 | from src.csdata import CSData
24 |
25 | DATA_DIR = Path("data/")
26 | DATA_DIR.mkdir(exist_ok=True, parents=True)
27 |
28 | BASE10_THRESHOLD = 3
29 | SEED = 42
30 |
31 |
32 | def generate_vocabulary(adata):
33 | """
34 | Create a vocabulary dictionary, where each key represents a single gene
35 | token and the value represents the number of non-zero cells in the provided
36 | count matrix.
37 |
38 | Arguments:
39 | adata: an AnnData object to generate cell sentences from. Expects that
40 | `obs` correspond to cells and `vars` correspond to genes.
41 | Return:
42 | a dictionary of gene vocabulary
43 | """
44 | if len(adata.var) > len(adata.obs):
45 | print(
46 | (
47 | "WARN: more variables ({}) than observations ({})... "
48 | + "did you mean to transpose the object (e.g. adata.T)?"
49 | ).format(len(adata.var), len(adata.obs)),
50 | file=sys.stderr,
51 | )
52 |
53 | vocabulary = OrderedDict()
54 | gene_sums = np.ravel(np.sum(adata.X > 0, axis=0))
55 |
56 | for i, name in enumerate(adata.var_names):
57 | vocabulary[name] = gene_sums[i]
58 |
59 | return vocabulary
60 |
61 |
62 | def generate_sentences(adata, prefix_len=None, random_state=42):
63 | """
64 | Transform expression matrix to sentences. Sentences contain gene "words"
65 | denoting genes with non-zero expression. Genes are ordered from highest
66 | expression to lowest expression.
67 |
68 | Arguments:
69 | adata: an AnnData object to generate cell sentences from. Expects that
70 | `obs` correspond to cells and `vars` correspond to genes.
71 | prefix_len: optionally truncate each sentence to its first prefix_len genes
72 | random_state: sets the numpy random state for breaking expression ties
73 | Return: a `numpy.ndarray` of encoded sentences, one per cell.
74 | """
75 | np.random.seed(random_state)
76 |
77 | if len(adata.var) > len(adata.obs):
78 | print(
79 | (
80 | "WARN: more variables ({}) than observations ({}), "
81 | + "did you mean to transpose the object (e.g. adata.T)?"
82 | ).format(len(adata.var), len(adata.obs)),
83 | file=sys.stderr,
84 | )
85 |
86 | mat = sparse.csr_matrix(adata.X)
87 | sentences = []
88 | for i in tqdm(range(mat.shape[0])):
89 | cols = mat.indices[mat.indptr[i] : mat.indptr[i + 1]]
90 | vals = mat.data[mat.indptr[i] : mat.indptr[i + 1]]
91 |
92 | cols, vals = shuffle(cols, vals)
93 |
94 | sentences.append(
95 | "".join([chr(x) for x in cols[np.argsort(-vals, kind="stable")]])
96 | )
97 |
98 | if prefix_len is not None:
99 | sentences = [s[:prefix_len] for s in sentences]
100 |
101 | return np.array(sentences, dtype=object)
102 |
103 |
104 | def csdata_from_adata(adata, prefix_len=None, random_state=42):
105 | """
106 | Generate a CSData object from an AnnData object.
107 |
108 | Arguments:
109 | adata: an AnnData object to generate cell sentences from. Expects that
110 | `obs` correspond to cells and `vars` correspond to genes.
111 | prefix_len: consider only rank substrings of length prefix_len
112 | random_state: sets the numpy random state for splitting ties
113 | Return:
114 | a CSData object containing a vocabulary, sentences, and associated name data.
115 | """
116 | return CSData(
117 | vocab=generate_vocabulary(adata),
118 | sentences=generate_sentences(
119 | adata, prefix_len=prefix_len, random_state=random_state
120 | ),
121 | cell_names=adata.obs_names,
122 | feature_names=adata.var_names,
123 | )
124 |
125 |
126 | def xlm_prepare_outpath(csdata, outpath, species_tag, params=None):
127 | """
128 | Write formatted data to the outpath file location, for direct processing
129 | by the XLM monolinguistic translation model. If creating an outpath for
130 | multiple species, use the same `outpath` with different `species_tag`
131 | values. They will not conflict so long as species_tags are appropriately
132 | assigned.
133 |
134 | Note that XLM requires a dictionary sorted in order of decreasing
135 | frequency of occurrence.
136 |
137 | Arguments:
138 | csdata: a CSData object from a single species to be written.
139 | outpath: directory to write files to. Will create this directory
140 | if it does not already exist.
141 | species_tag: a short string to be used as the species name in XLM.
142 | Fulfills a function analogous to language tags such as
143 | 'en', 'es', or 'zh'.
144 | params: a dictionary of keyword arguments passed through to
145 | train_test_validation_split.
146 | Return:
147 | None
148 | """
149 |
150 | if params is None:
151 | params = {}
152 |
153 | sentence_strings = csdata.create_sentence_strings(delimiter=" ")
154 | train, test, val = csdata.train_test_validation_split(**params)
155 |
156 | train_sentences = sentence_strings[train]
157 | test_sentences = sentence_strings[test]
158 | val_sentences = sentence_strings[val]
159 |
160 | os.makedirs(outpath, exist_ok=True)
161 | np.save(
162 | os.path.join(outpath, "train_partition_indices.npy"),
163 | np.array(train, dtype=np.int64),
164 | )
165 | np.save(
166 | os.path.join(outpath, "valid_partition_indices.npy"),
167 | np.array(val, dtype=np.int64),
168 | )
169 | np.save(
170 | os.path.join(outpath, "test_partition_indices.npy"),
171 | np.array(test, dtype=np.int64),
172 | )
173 |
174 | print("INFO: Writing Vocabulary File", file=sys.stderr)
175 | fn = "{}/vocab_{}.txt".format(outpath, species_tag)
176 | with open(fn, "w") as f:
177 | for k in tqdm(sorted(csdata.vocab, key=csdata.vocab.get, reverse=True)):
178 | if csdata.vocab[k] == 0:
179 | continue
180 | print("{} {}".format(k, csdata.vocab[k]), file=f)
181 |
182 | print("INFO: Writing Training Sentences", file=sys.stderr)
183 | fn = "{}/train_{}.txt".format(outpath, species_tag)
184 | with open(fn, "w") as f:
185 | for l in tqdm(train_sentences):
186 | print(l, file=f)
187 |
188 | print("INFO: Writing Training Cell Barcodes", file=sys.stderr)
189 | fn = "{}/train_barcodes_{}.txt".format(outpath, species_tag)
190 | with open(fn, "w") as f:
191 | for l in tqdm(csdata.cell_names[train]):
192 | print(l, file=f)
193 |
194 | print("INFO: Writing Testing Sentences", file=sys.stderr)
195 | fn = "{}/test_{}.txt".format(outpath, species_tag)
196 | with open(fn, "w") as f:
197 | for l in tqdm(test_sentences):
198 | print(l, file=f)
199 |
200 | print("INFO: Writing Testing Cell Barcodes", file=sys.stderr)
201 | fn = "{}/train_barcodes_{}.txt".format(outpath, species_tag)
202 | with open(fn, "w") as f:
203 | for l in tqdm(csdata.cell_names[test]):
204 | print(l, file=f)
205 |
206 | print("INFO: Writing Validation Sentences", file=sys.stderr)
207 | fn = "{}/valid_{}.txt".format(outpath, species_tag)
208 | with open(fn, "w") as f:
209 | for l in tqdm(val_sentences):
210 | print(l, file=f)
211 |
212 | print("INFO: Writing Validation Cell Barcodes", file=sys.stderr)
213 | fn = "{}/valid_barcodes_{}.txt".format(outpath, species_tag)
214 | with open(fn, "w") as f:
215 | for l in tqdm(csdata.cell_names[val]):
216 | print(l, file=f)
217 |
218 |
219 | def post_process_generated_cell_sentences(
220 | cell_sentence: str,
221 | global_dictionary: List,
222 | replace_nonsense_string: str = "NOT_A_GENE",
223 | ):
224 | """
225 | Post-processing function for generated cell sentences. Nonsense genes are replaced with
226 | some string, e.g. 'NOT_A_GENE', so that ranks are not changed in generated output.
227 |
228 | Current assumptions in this function:
229 | - We replace nonsense genes with some string, e.g. 'NOT_A_GENE', so that ranks are not
230 | changed in generated output.
231 |
232 | Steps:
233 | 1. Replace any nonsense genes with a specified token, e.g. 'NOT_A_GENE'
234 | 2. Average the ranks of duplicated genes in generated sentence
235 |
236 | Arguments:
237 | cell_sentence: generated cell sentence string
238 | global_dictionary: list of global gene vocabulary (all uppercase)
239 | replace_nonsense_string: string which will replace nonsense genes in generated output
240 |
241 | Returns:
242 | post_processed_sentence: generated cell sentence after post processing steps
243 | num_genes_replaced: number of genes replaced with the defined nonsense token
244 | """
245 | generated_gene_names = cell_sentence.split(" ")
246 | generated_gene_names = [generated_gene.upper() for generated_gene in generated_gene_names]
247 |
248 | # --- Replace nonsense genes ---#
249 | generated_gene_names = [
250 | gene_name if gene_name in global_dictionary else replace_nonsense_string
251 | for gene_name in generated_gene_names
252 | ]
253 | num_genes_replaced = generated_gene_names.count(replace_nonsense_string)
254 |
255 | # --- Average ranks ---#
256 | gene_name_to_occurrences = Counter(
257 | generated_gene_names
258 | ) # get mapping of gene name --> number of occurrences
259 | post_processed_sentence = generated_gene_names.copy() # copy of generated gene list
260 |
261 | for gene_name in gene_name_to_occurrences:
262 | if (
263 | gene_name_to_occurrences[gene_name] > 1
264 | and gene_name != replace_nonsense_string
265 | ):
266 | # Find positions of all occurrences of duplicated generated gene in list
267 | # Note: using post_processed_sentence here; since duplicates are being removed, list will be
268 | # getting shorter. Getting indices in original list will no longer be accurate positions
269 | occurrence_positions = [
270 | idx
271 | for idx, elem in enumerate(post_processed_sentence)
272 | if elem == gene_name
273 | ]
274 | average_position = int(
275 | sum(occurrence_positions) / len(occurrence_positions)
276 | )
277 |
278 | # Remove occurrences
279 | post_processed_sentence = [
280 | elem for elem in post_processed_sentence if elem != gene_name
281 | ]
282 | # Reinsert gene_name at average position
283 | post_processed_sentence.insert(
284 | average_position, gene_name
285 | )
286 |
287 | return post_processed_sentence, num_genes_replaced
288 |
289 |
290 | def convert_cell_sentence_back_to_expression_vector(
291 | cell_sentence: List, global_dictionary: List, slope: float, intercept: float
292 | ):
293 | """
294 | Convert a post-processed cell sentence back into a gene expression vector.
295 |
296 | Each gene's expression value is reconstructed from its rank in the
297 | sentence using the inverse of the linear fit between log rank and log
298 | expression: expression = intercept + slope * log10(1 + rank).
299 |
300 | Genes absent from the global vocabulary (e.g. the 'NOT_A_GENE'
301 | placeholder inserted during post-processing) are skipped, so their
302 | expression entries remain zero.
303 |
304 | Arguments:
305 | cell_sentence: generated cell sentence list, e.g. ['GENE1', 'GENE2']
306 | global_dictionary: list of global gene vocabulary
307 | slope: slope value to use in inverse rank->expression transformation
308 | intercept: intercept value to use in inverse rank->expression transformation
309 |
310 | Returns:
311 | expression_vector: expression vector for generated cell
312 | """
313 | expression_vector = np.zeros(len(global_dictionary), dtype=np.float32)
314 | for rank, gene_name in enumerate(cell_sentence):
315 | if gene_name in global_dictionary:
316 | log_rank = np.log10(1 + rank).item()
317 | gene_expr_val = intercept + (slope * log_rank)
318 | gene_idx_in_vector = global_dictionary.index(gene_name)
319 | expression_vector[gene_idx_in_vector] = gene_expr_val
320 |
321 | return expression_vector
322 |
--------------------------------------------------------------------------------
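
A minimal, self-contained sketch of the generation post-processing and
rank-to-expression inversion implemented in utils.py above. The vocabulary,
generated sentence, slope, and intercept are toy placeholders; in the real
workflow, slope and intercept come from transformation_metrics_and_parameters.csv
written by transform.py:

    from src.utils import (
        post_process_generated_cell_sentences,
        convert_cell_sentence_back_to_expression_vector,
    )

    vocab = ["AAA", "BBB", "CCC", "DDD"]  # toy global gene vocabulary (uppercase)
    generated = "aaa bbb FOO aaa ccc"     # raw model output: one duplicate, one nonsense token

    sentence, n_replaced = post_process_generated_cell_sentences(generated, vocab)
    # 'FOO' -> 'NOT_A_GENE' (n_replaced == 1); the duplicated 'AAA' is collapsed
    # to a single occurrence at its average rank:
    # sentence == ['BBB', 'AAA', 'NOT_A_GENE', 'CCC']

    # Invert rank -> expression with a hypothetical linear fit.
    expr = convert_cell_sentence_back_to_expression_vector(
        sentence, vocab, slope=-1.0, intercept=3.0
    )
    # expr holds intercept + slope * log10(1 + rank) at each vocabulary index
    # present in the sentence; 'NOT_A_GENE' and absent genes ('DDD') stay at 0.
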
/workflow_data/transform.py:
--------------------------------------------------------------------------------
1 | """
2 | This file was initially developed by the project at https://github.com/vandijklab/cell2sentence-ft.
3 | Many thanks for their contributions to this field. It adheres to the Attribution-NonCommercial-ShareAlike
4 | 4.0 International License.
5 |
6 | If you use this file, please cite the papers "Levine et al., Cell2Sentence: Teaching Large Language
7 | Models the Language of Biology. 2023 (https://www.biorxiv.org/content/10.1101/2023.09.11.557287v3)" and
8 | "Rahul M Dhodapkar. Representing cells as sentences enables natural-language processing for single-cell
9 | transcriptomics. 2022 (https://www.biorxiv.org/content/10.1101/2022.09.18.508438)."
10 | """
11 |
12 | import os
13 | import argparse
14 | from pathlib import Path
15 |
16 | import anndata
17 | import numpy as np
18 | import pandas as pd
19 | import plotnine as pn
20 | import scanpy as sc
21 | import sklearn.linear_model as lm
22 | from datasets import Dataset, load_dataset, concatenate_datasets
23 | from scipy.stats import pearsonr, spearmanr
24 | from sklearn.metrics import r2_score
25 | from sklearn.utils import shuffle
26 | from tqdm import tqdm
27 | import sys
28 | from src import utils
29 |
30 | ROW_SUM = 10000
31 |
32 |
33 | def normalize_and_rank_transform(data_matrix_X, normalize=True):
34 | """
35 | Helper function which accepts a data matrix, optionally row-normalizes it,
36 | and calculates a rank transformation of the data.
37 |
38 | Args:
39 | data_matrix_X: numpy matrix of shape [num_cells, num_genes]
40 | normalize: boolean flag for whether to normalize data
41 |
42 | Returns:
43 | data_matrix_X: normalized data matrix
44 | rank_matrix_X: matrix of rank values for each cell, same shape as data_matrix_X
45 | """
46 | if normalize:
47 | normalized_data_matrix_X = (
48 | np.diag(ROW_SUM / np.ravel(np.sum(data_matrix_X, axis=1))) @ data_matrix_X
49 | )
50 | data_matrix_X = np.asarray(normalized_data_matrix_X)
51 |
52 | rank_matrix_X = np.zeros(shape=data_matrix_X.shape)
53 | for i in tqdm(range(data_matrix_X.shape[0])):
54 | cols = np.ravel(range(data_matrix_X.shape[1]))
55 | vals = np.ravel(data_matrix_X[i, :])
56 | cols, vals = shuffle(cols, vals)
57 | ranks = cols[np.argsort(-vals, kind="stable")]
58 | for j in range(len(ranks)):
59 | rank_matrix_X[i, ranks[j]] = j
60 |
61 | return data_matrix_X, rank_matrix_X
62 |
63 |
64 | def evaluate_transformation(df, plotting_sample_size=10000):
65 | """
66 | Helper function which takes as input a pandas DataFrame of expression values and
67 | ranks, and fits a linear regression model to predict back expression value from
68 | log rank.
69 |
70 | Plots are created to show the relationship between log rank and log expression,
71 | as well as the performance of expression reconstruction by the linear model.
72 | Metrics for expression reconstruction, as well as the parameters of the linear
73 | model are saved in a CSV file.
74 |
75 | Args:
76 | df: pandas DataFrame with keys: 'preprocessed_transcript_count',
77 | 'preprocessed_rank', 'log_preprocessed_transcript_count',
78 | and 'log_preprocessed_rank'
79 | plotting_sample_size: how many values to sample for plotting
80 | """
81 | eval_output_dir = Path("Output_directory_filepath/eval")  # note: hard-coded, not derived from the --output_dir argument
82 | eval_output_dir.mkdir(exist_ok=True, parents=True)
83 |
84 | # (1) Fit linear regression between log rank (x-axis) and log expression (y-axis)
85 | x_axis_name = "log_preprocessed_rank"
86 | y_axis_name = "log_preprocessed_transcript_count"
87 | x = np.array(df.loc[df[x_axis_name] < utils.BASE10_THRESHOLD, x_axis_name]).reshape(
88 | -1, 1
89 | )
90 | y = df.loc[df[x_axis_name] < utils.BASE10_THRESHOLD, y_axis_name]
91 |
92 | reg = lm.LinearRegression().fit(x, y)
93 |
94 | # Plot relationship
95 | plot = (
96 | pn.ggplot(
97 | df.sample(plotting_sample_size),
98 | pn.aes(x="log_preprocessed_rank", y="log_preprocessed_transcript_count"),
99 | )
100 | + pn.geom_abline(slope=reg.coef_, intercept=reg.intercept_, color="red")
101 | + pn.geom_point(color="blue", size=0.5)
102 | + pn.labs(
103 | x="Gene Log Rank",
104 | y="Gene Log Expression",
105 | title="Log Rank vs Log Expression",
106 | )
107 | )
108 | plot.save(os.path.join(eval_output_dir, "plot_log_rank_vs_log_expr.png"), dpi=300)
109 |
110 | # (2) Reconstruct expression from log rank, calculate reconstruction performance metrics
111 | rank_reconstructed_X = reg.predict(
112 | np.array(df["log_preprocessed_rank"]).reshape(-1, 1)
113 | )
114 |
115 | r_squared_score = r2_score(
116 | np.asarray(df["log_preprocessed_transcript_count"]),
117 | np.asarray(rank_reconstructed_X),
118 | )
119 | pearson_r_score = pearsonr(
120 | np.asarray(df["log_preprocessed_transcript_count"]),
121 | np.asarray(rank_reconstructed_X),
122 | )
123 | spearman_r_score = spearmanr(
124 | np.asarray(df["log_preprocessed_transcript_count"]),
125 | np.asarray(rank_reconstructed_X),
126 | )
127 |
128 | reconstructed_expr_values_df = pd.DataFrame(
129 | {
130 | "Ground Truth Expression": df["log_preprocessed_transcript_count"],
131 | "Reconstructed Expression from Log Rank": rank_reconstructed_X,
132 | }
133 | )
134 | plot = (
135 | pn.ggplot(
136 | reconstructed_expr_values_df.sample(plotting_sample_size),
137 | pn.aes(
138 | x="Ground Truth Expression", y="Reconstructed Expression from Log Rank"
139 | ),
140 | )
141 | + pn.geom_point(color="blue", size=0.5)
142 | + pn.geom_abline(slope=1, intercept=0, color="red")
143 | + pn.labs(
144 | x="Ground Truth Expression",
145 | y="Reconstructed Expression from Log Rank",
146 | title="Ground Truth Expression vs Reconstruction from Rank",
147 | )
148 | )
149 | plot.save(
150 | os.path.join(
151 | eval_output_dir, "plot_gt_expr_vs_reconstructed_expr_from_rank.png"
152 | ),
153 | dpi=300,
154 | )
155 |
156 | # 3. Create results dataframe and return
157 | metrics_df = pd.DataFrame(
158 | {
159 | "threshold": [utils.BASE10_THRESHOLD],
160 | "slope": [reg.coef_.item()],
161 | "intercept": [reg.intercept_.item()],
162 | "R^2": [r_squared_score.item()],
163 | "Pearson_R_statistic": [pearson_r_score.statistic.item()],
164 | "Pearson_R_p_value": [pearson_r_score.pvalue.item()],
165 | "Spearman_R_statistic": [spearman_r_score.statistic.item()],
166 | "Spearman_R_p_value": [spearman_r_score.pvalue.item()],
167 | }
168 | )
169 | metrics_df.to_csv(
170 | os.path.join(eval_output_dir, "transformation_metrics_and_parameters.csv")
171 | )
172 |
173 |
174 | def main(data_filepath: Path, output_dir: Path):
175 | """Apply preprocessing steps and transform to cell sentences.
176 |
177 | Preprocessing follows https://scanpy-tutorials.readthedocs.io/en/latest/pbmc3k.html.
178 | """
179 | print(f"Loading data from {data_filepath}.")
180 | adata = anndata.read_h5ad(data_filepath)
181 |
182 | # reach for raw transcript counts in the .raw attribute
183 | if hasattr(adata, "raw") and adata.raw is not None:
184 | adata.X = adata.raw.X
185 | print(f"Done loading data for {len(adata)} cells.")
186 |
187 | adata.var["feature_name"]=adata.var["features"].copy().str.upper()
188 | adata.var["feature_name"] = adata.var["feature_name"].astype(str)
189 | duplicates = adata.var["feature_name"].duplicated(keep=False)
190 | adata.var.loc[duplicates, "feature_name"] = (
191 | adata.var.loc[duplicates, "feature_name"]
192 | + '_'
193 | + adata.var.loc[duplicates, "feature_name"].groupby(adata.var["feature_name"]).cumcount().add(1).astype(str)
194 | )
195 |
196 | adata.var["ensembl_ids"] = adata.var.index
197 | adata.var_names = adata.var["feature_name"]
198 | adata.var_names_make_unique(join="_")
199 |
200 | sc.pp.filter_cells(adata, min_genes=200)
201 | sc.pp.filter_genes(adata, min_cells=3)
202 |
203 | # annotate the group of mitochondrial genes as 'mt'
204 | adata.var["mt"] = adata.var_names.str.startswith("MT-")
205 | sc.pp.calculate_qc_metrics(
206 | adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
207 | )
208 |
209 | adata = adata[adata.obs.n_genes_by_counts < 2500, :]
210 | adata = adata[adata.obs.pct_counts_mt < 200, :]  # pct_counts_mt is a percentage (max 100), so this threshold keeps all cells; the pbmc3k tutorial uses < 5
211 | print(f"Done filtering cells, remaining data of shape {adata.shape}.")
212 |
213 | raw_X = np.copy(adata.X.toarray())
214 | norm_X, rank_norm_X = normalize_and_rank_transform(
215 | np.copy(adata.X.todense()), normalize=True
216 | )
217 | # update adata object with normalized expression
218 | adata.X = np.log10(1 + norm_X)
219 | # create dataframe of ranks and expression values for plotting
220 | expr_and_rank_df = pd.DataFrame(
221 | {
222 | "raw_transcript_count": np.ravel(raw_X),
223 | "preprocessed_transcript_count": np.ravel(norm_X),
224 | "preprocessed_rank": np.ravel(rank_norm_X),
225 | "log_preprocessed_transcript_count": np.log10(1 + np.ravel(norm_X)),
226 | "log_preprocessed_rank": np.log10(1 + np.ravel(rank_norm_X)),
227 | }
228 | )
229 | # remove 0 expression entries in the cellxgene matrix
230 | expr_and_rank_df = expr_and_rank_df[expr_and_rank_df["raw_transcript_count"] != 0]
231 | print(f"Done normalizing data, {len(expr_and_rank_df)} data points remaining.")
232 |
233 | # compute metrics for transformation to cells and back
234 | evaluate_transformation(df=expr_and_rank_df, plotting_sample_size=10000)
235 |
236 | preprocessed_output_filepath = data_filepath.parent / (
237 | data_filepath.stem + data_filepath.suffix.replace(".h5ad", "_preprocessed.h5ad")
238 | )
239 | print(f"Saving preprocessed transcript counts to {preprocessed_output_filepath}.")
240 | del adata.raw
241 | adata.write_h5ad(preprocessed_output_filepath)
242 |
243 | # convert the adata into ranked sequences of gene names ("cell sentences")
244 | csdata = utils.csdata_from_adata(adata)
245 |
246 | # make text files containing the cell sentences
247 | txt_output_dir = output_dir / "cell_sentences"
248 | txt_output_dir.mkdir(exist_ok=True, parents=True)
249 | utils.xlm_prepare_outpath(csdata, txt_output_dir, species_tag="human")
250 | print(f"Done writing cell sentences to file.")
251 |
252 | # make arrow-formatted dataset compatible with HuggingFace's datasets
253 | hf_output_dir = output_dir / "cell_sentences_hf"
254 | hf_output_dir.mkdir(exist_ok=True, parents=True)
255 | data_splits = ["train", "valid", "test"]
256 | data_files = {
257 | data_split: str(txt_output_dir / f"{data_split}_human.txt")
258 | for data_split in data_splits
259 | }
260 | dataset = load_dataset("text", data_files=data_files)
261 |
262 | # load cell type labels if available with transcript counts
263 | for data_split in data_splits:
264 | dataset[data_split] = dataset[data_split].rename_column("text", "input_ids")
265 | # retrieve split chunk from preprocessed transcript counts
266 | dataset_split_sample_indices = np.load(
267 | txt_output_dir / f"{data_split}_partition_indices.npy"
268 | )
269 | adata_split = adata[dataset_split_sample_indices, :].copy()
270 | if "cell_type" in adata_split.obs.columns:
271 | cell_type_labels = {"cell_type": adata_split.obs["cell_type"].tolist()}
272 | cell_type_dataset = Dataset.from_dict(cell_type_labels)
273 | dataset[data_split] = concatenate_datasets(
274 | [dataset[data_split], cell_type_dataset], axis=1
275 | )
276 |
277 | dataset.save_to_disk(hf_output_dir)
278 | print(f"Done transforming data to cell sentences.")
279 |
280 |
281 | if __name__ == "__main__":
282 | parser = argparse.ArgumentParser()
283 | parser.add_argument(
284 | "--data_filepath",
285 | type=Path,
286 | help="Input data filepath.",
287 | default='SHARE-seq_mouse_skin_dataset.h5ad',
288 | )
289 | parser.add_argument(
290 | "--output_dir",
291 | type=Path,
292 | help="Output directory filepath.",
293 | default='Output_directory_filepath',
294 | )
295 | args = parser.parse_args()
296 |
297 | main(args.data_filepath, args.output_dir)
298 |
--------------------------------------------------------------------------------
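
A minimal sketch of driving the transform.py pipeline above from Python, assuming
the working directory is workflow_data/ and that "my_dataset.h5ad" is a placeholder
for a real AnnData file (equivalently, run
python transform.py --data_filepath my_dataset.h5ad --output_dir output):

    from pathlib import Path

    from transform import main

    main(Path("my_dataset.h5ad"), Path("output"))
    # Outputs, per the code above:
    #   my_dataset_preprocessed.h5ad    (written next to the input file)
    #   output/cell_sentences/          XLM-style .txt files plus partition indices
    #   output/cell_sentences_hf/       HuggingFace dataset with 'input_ids'
    #                                   (and 'cell_type' when present in adata.obs)
    #   Output_directory_filepath/eval/ fit metrics and plots (hard-coded path)
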