├── source
    ├── __init__.py
    ├── Matcher
    │   ├── __init__.py
    │   ├── config
    │   │   ├── __init__.py
    │   │   ├── config_loader.py
    │   │   └── config.json
    │   ├── models
    │   │   ├── __init__.py
    │   │   ├── llm
    │   │   │   ├── __init__.py
    │   │   │   ├── llm_loader.py
    │   │   │   └── llm_reranker.py
    │   │   └── embedding
    │   │   │   ├── __init__.py
    │   │   │   ├── query_embedder.py
    │   │   │   └── sentence_embedder.py
    │   ├── utils
    │   │   ├── __init__.py
    │   │   ├── logging_config.py
    │   │   ├── file_utils.py
    │   │   └── temporal_utils.py
    │   ├── pipeline
    │   │   ├── __init__.py
    │   │   ├── trial_search
    │   │   │   └── __init__.py
    │   │   └── trial_ranker.py
    │   └── services
    │   │   ├── __init__.py
    │   │   └── biomedner_service.py
    ├── Parser
    │   ├── __init__.py
    │   ├── normalizers
    │   │   ├── __init__.py
    │   │   ├── cellline_normalizer.py
    │   │   ├── species_normalizer.py
    │   │   ├── readme.md
    │   │   ├── celltype_normalizer.py
    │   │   ├── chemical_normalizer.py
    │   │   └── normalizer_all.py
    │   ├── lib
    │   │   ├── libcrfpp.so
    │   │   ├── libcrfpp.so.0
    │   │   ├── libcrfpp.a
    │   │   ├── libcrfpp.so.0.0.0
    │   │   └── libcrfpp.la
    │   ├── bin
    │   │   ├── crf_learn
    │   │   └── crf_test
    │   ├── input
    │   │   ├── 89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace.PubTator.gner.PubTator
    │   │   ├── 525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e.PubTator.gner.PubTator
    │   │   ├── 611019227d9ea65a714df4e8c4498bfc13e21c67879025c3b03d8637.PubTator.gner.PubTator
    │   │   ├── 6db4a18d11e3899c21b4cc11489cf3c8b457a40273c48ecd39ab4377.PubTator.gner.PubTator
    │   │   ├── 89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace.PubTator.biomedner.PubTator
    │   │   ├── 427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453.PubTator.gner.PubTator
    │   │   ├── 525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e.PubTator.biomedner.PubTator
    │   │   ├── 611019227d9ea65a714df4e8c4498bfc13e21c67879025c3b03d8637.PubTator.biomedner.PubTator
    │   │   ├── 61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc.PubTator.gner.PubTator
    │   │   ├── 6db4a18d11e3899c21b4cc11489cf3c8b457a40273c48ecd39ab4377.PubTator.biomedner.PubTator
    │   │   ├── 7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3.PubTator.gner.PubTator
    │   │   ├── b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c.PubTator.gner.PubTator
    │   │   ├── 427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453.PubTator.biomedner.PubTator
    │   │   ├── 61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc.PubTator.biomedner.PubTator
    │   │   ├── 7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3.PubTator.biomedner.PubTator
    │   │   ├── b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c.PubTator.biomedner.PubTator
    │   │   ├── 1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42.PubTator.gner.PubTator
    │   │   ├── 1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42.PubTator.biomedner.PubTator
    │   │   ├── 6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28.PubTator.gner.PubTator
    │   │   └── 6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28.PubTator.biomedner.PubTator
    │   ├── output
    │   │   ├── 525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e.PubTator.biomedner.json
    │   │   ├── 89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace.PubTator.biomedner.json
    │   │   ├── 611019227d9ea65a714df4e8c4498bfc13e21c67879025c3b03d8637.PubTator.biomedner.json
    │   │   ├── 6db4a18d11e3899c21b4cc11489cf3c8b457a40273c48ecd39ab4377.PubTator.biomedner.json
    │   │   ├── 89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace.PubTator.gner.json
    │   │   ├── 427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453.PubTator.gner.json
    │   │   ├── 525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e.PubTator.gner.json
    │   │   ├── 61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc.PubTator.gner.json
    │   │   ├── 7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3.PubTator.gner.json
    │   │   ├── b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c.PubTator.gner.json
    │   │   ├── 1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42.PubTator.gner.json
    │   │   ├── 427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453.PubTator.biomedner.json
    │   │   ├── 61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc.PubTator.biomedner.json
    │   │   ├── 7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3.PubTator.biomedner.json
    │   │   ├── b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c.PubTator.biomedner.json
    │   │   ├── 1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42.PubTator.biomedner.json
    │   │   ├── 6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28.PubTator.gner.json
    │   │   └── 6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28.PubTator.biomedner.json
    │   ├── scripts
    │   │   ├── stop_biomedner.sh
    │   │   └── run_biomedner.sh
    │   └── biomedner_server.py
    ├── biomedner_services
    │   ├── stop_biomedner.sh
    │   └── run_biomedner.sh
    └── regex
    │   └── exception_regex_patterns.json
├── utils
    ├── DataLoader
    │   ├── __init__.py
    │   ├── test
    │   │   └── __init__.py
    │   └── nct_ids.txt
    ├── Indexer
    │   ├── __init__.py
    │   ├── config.json
    │   ├── zipper.sh
    │   ├── flatten.py
    │   ├── index_trials.py
    │   ├── prepare_trials.py
    │   └── prepare_criteria.py
    ├── Preprocessor
    │   ├── __init__.py
    │   ├── test
    │   │   └── __init__.py
    │   ├── preprocessing.py
    │   └── jsonify.py
    ├── finetuning
    │   └── finetune_instruct
    │   │   ├── __init__.py
    │   │   ├── split_train_test.py
    │   │   ├── finetune.sh
    │   │   ├── trainer.py
    │   │   ├── load_model.py
    │   │   ├── modeling.py
    │   │   ├── run.py
    │   │   ├── evaluate_CoT.py
    │   │   ├── arguments.py
    │   │   └── data_llama.py
    └── gpt
    │   └── gpt_generate_reranking_data.py
├── elasticsearch
    ├── certs
    │   ├── es01
    │   │   ├── ca.crt
    │   │   ├── es01.crt
    │   │   └── es01.key
    │   ├── ca.zip
    │   ├── certs.zip
    │   ├── instances.yml
    │   ├── es02
    │   │   ├── es02.crt
    │   │   └── es02.key
    │   ├── es03
    │   │   ├── es03.crt
    │   │   └── es03.key
    │   ├── ca.crt
    │   └── ca
    │   │   ├── ca.crt
    │   │   └── ca.key
    ├── config
    │   ├── es01
    │   │   ├── users
    │   │   ├── users_roles
    │   │   ├── elasticsearch.keystore
    │   │   ├── roles.yml
    │   │   ├── role_mapping.yml
    │   │   ├── elasticsearch.yml
    │   │   ├── elasticsearch-plugins.example.yml
    │   │   ├── es01.crt
    │   │   ├── ca.crt
    │   │   ├── es01.key
    │   │   └── jvm.options
    │   ├── es02
    │   │   ├── users
    │   │   ├── users_roles
    │   │   ├── elasticsearch.keystore
    │   │   ├── roles.yml
    │   │   ├── role_mapping.yml
    │   │   ├── elasticsearch.yml
    │   │   ├── elasticsearch-plugins.example.yml
    │   │   ├── es02.crt
    │   │   ├── ca.crt
    │   │   ├── es02.key
    │   │   └── jvm.options
    │   └── es03
    │   │   ├── users
    │   │   ├── users_roles
    │   │   ├── elasticsearch.keystore
    │   │   ├── roles.yml
    │   │   ├── role_mapping.yml
    │   │   ├── elasticsearch.yml
    │   │   ├── elasticsearch-plugins.example.yml
    │   │   ├── es03.crt
    │   │   ├── ca.crt
    │   │   ├── es03.key
    │   │   └── jvm.options
    ├── tmp-config
    │   ├── users
    │   ├── users_roles
    │   ├── elasticsearch.yml
    │   ├── roles.yml
    │   ├── role_mapping.yml
    │   ├── elasticsearch-plugins.example.yml
    │   └── jvm.options
    ├── .env
    └── apptainer-run-es.sh
├── .gitattributes
├── img
    └── logo.png
├── requirements.txt
├── LICENSE
├── .gitignore
├── example
    └── phenopacket
    │   └── keywords.json
├── README.md
└── setup.sh
/source/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /source/Matcher/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /source/Parser/__init__.py:
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/DataLoader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/Indexer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elasticsearch/certs/es01/ca.crt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elasticsearch/config/es01/users: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elasticsearch/config/es02/users: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elasticsearch/config/es03/users: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elasticsearch/tmp-config/users: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /source/Matcher/config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /source/Matcher/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /source/Matcher/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/DataLoader/test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/Preprocessor/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elasticsearch/config/es01/users_roles: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elasticsearch/config/es02/users_roles: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elasticsearch/config/es03/users_roles: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elasticsearch/tmp-config/users_roles: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /source/Matcher/models/llm/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /source/Matcher/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /source/Matcher/services/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /source/Parser/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/Preprocessor/test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /source/Matcher/models/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/DataLoader/nct_ids.txt: -------------------------------------------------------------------------------- 1 | NCT04127110 -------------------------------------------------------------------------------- /source/Matcher/pipeline/trial_search/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /source/Parser/lib/libcrfpp.so: -------------------------------------------------------------------------------- 1 | libcrfpp.so.0.0.0 -------------------------------------------------------------------------------- /source/Parser/lib/libcrfpp.so.0: -------------------------------------------------------------------------------- 1 | libcrfpp.so.0.0.0 -------------------------------------------------------------------------------- /utils/finetuning/finetune_instruct/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | data/* filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbib/TrialMatchAI/HEAD/img/logo.png -------------------------------------------------------------------------------- /elasticsearch/tmp-config/elasticsearch.yml: -------------------------------------------------------------------------------- 1 | cluster.name: "docker-cluster" 2 | network.host: 0.0.0.0 3 | -------------------------------------------------------------------------------- /elasticsearch/certs/ca.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbib/TrialMatchAI/HEAD/elasticsearch/certs/ca.zip -------------------------------------------------------------------------------- /source/Parser/bin/crf_learn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbib/TrialMatchAI/HEAD/source/Parser/bin/crf_learn -------------------------------------------------------------------------------- /source/Parser/bin/crf_test: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbib/TrialMatchAI/HEAD/source/Parser/bin/crf_test -------------------------------------------------------------------------------- /elasticsearch/certs/certs.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbib/TrialMatchAI/HEAD/elasticsearch/certs/certs.zip -------------------------------------------------------------------------------- /source/Parser/lib/libcrfpp.a: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbib/TrialMatchAI/HEAD/source/Parser/lib/libcrfpp.a -------------------------------------------------------------------------------- /source/Parser/lib/libcrfpp.so.0.0.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbib/TrialMatchAI/HEAD/source/Parser/lib/libcrfpp.so.0.0.0 -------------------------------------------------------------------------------- /elasticsearch/config/es01/elasticsearch.keystore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbib/TrialMatchAI/HEAD/elasticsearch/config/es01/elasticsearch.keystore -------------------------------------------------------------------------------- /elasticsearch/config/es02/elasticsearch.keystore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbib/TrialMatchAI/HEAD/elasticsearch/config/es02/elasticsearch.keystore -------------------------------------------------------------------------------- /elasticsearch/config/es03/elasticsearch.keystore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbib/TrialMatchAI/HEAD/elasticsearch/config/es03/elasticsearch.keystore -------------------------------------------------------------------------------- /elasticsearch/tmp-config/roles.yml: -------------------------------------------------------------------------------- 1 | # The default roles file is empty as the preferred method of defining roles is 2 | # through the API/UI. File based roles are useful in error scenarios when the 3 | # API based roles may not be available. 4 | -------------------------------------------------------------------------------- /elasticsearch/config/es01/roles.yml: -------------------------------------------------------------------------------- 1 | # The default roles file is empty as the preferred method of defining roles is 2 | # through the API/UI. File based roles are useful in error scenarios when the 3 | # API based roles may not be available. 4 | -------------------------------------------------------------------------------- /elasticsearch/config/es02/roles.yml: -------------------------------------------------------------------------------- 1 | # The default roles file is empty as the preferred method of defining roles is 2 | # through the API/UI. File based roles are useful in error scenarios when the 3 | # API based roles may not be available. 4 | -------------------------------------------------------------------------------- /elasticsearch/config/es03/roles.yml: -------------------------------------------------------------------------------- 1 | # The default roles file is empty as the preferred method of defining roles is 2 | # through the API/UI. 
File based roles are useful in error scenarios when the 3 | # API based roles may not be available. 4 | -------------------------------------------------------------------------------- /source/Parser/input/89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace.PubTator.gner.PubTator: -------------------------------------------------------------------------------- 1 | 89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace|t| 2 | 89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace|a|type 1 diabetes 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e.PubTator.gner.PubTator: -------------------------------------------------------------------------------- 1 | 525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e|t| 2 | 525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e|a|chronic hepatitis c 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/611019227d9ea65a714df4e8c4498bfc13e21c67879025c3b03d8637.PubTator.gner.PubTator: -------------------------------------------------------------------------------- 1 | 611019227d9ea65a714df4e8c4498bfc13e21c67879025c3b03d8637|t| 2 | 611019227d9ea65a714df4e8c4498bfc13e21c67879025c3b03d8637|a|myocardial infarction 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/6db4a18d11e3899c21b4cc11489cf3c8b457a40273c48ecd39ab4377.PubTator.gner.PubTator: -------------------------------------------------------------------------------- 1 | 6db4a18d11e3899c21b4cc11489cf3c8b457a40273c48ecd39ab4377|t| 2 | 6db4a18d11e3899c21b4cc11489cf3c8b457a40273c48ecd39ab4377|a|myocardial infarction 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace.PubTator.biomedner.PubTator: -------------------------------------------------------------------------------- 1 | 89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace|t| 2 | 89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace|a|type 1 diabetes 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453.PubTator.gner.PubTator: -------------------------------------------------------------------------------- 1 | 427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453|t| 2 | 427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453|a|severe aortic stenosis 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e.PubTator.biomedner.PubTator: -------------------------------------------------------------------------------- 1 | 525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e|t| 2 | 525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e|a|chronic hepatitis c 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/611019227d9ea65a714df4e8c4498bfc13e21c67879025c3b03d8637.PubTator.biomedner.PubTator: -------------------------------------------------------------------------------- 1 | 611019227d9ea65a714df4e8c4498bfc13e21c67879025c3b03d8637|t| 2 | 611019227d9ea65a714df4e8c4498bfc13e21c67879025c3b03d8637|a|myocardial infarction 3 | 4 | -------------------------------------------------------------------------------- 
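A note on the .PubTator files in Parser/input (several appear above and below): each record is pipe-delimited, with an id|t|title line and an id|a|abstract line, optionally followed by tab-separated annotation rows and a terminating blank line. A minimal reader sketch, assuming only this two-field layout (parse_pubtator is an illustrative helper, not a file in this repository):

def parse_pubtator(path):
    """Collect {id, title, abstract} records from a PubTator file."""
    records = {}
    with open(path, encoding="utf-8") as f:
        for raw in f:
            line = raw.rstrip("\n")
            if not line or "\t" in line:
                continue  # blank separator or tab-delimited annotation row
            doc_id, field, text = line.split("|", 2)
            rec = records.setdefault(doc_id, {"id": doc_id, "title": "", "abstract": ""})
            if field == "t":
                rec["title"] = text
            elif field == "a":
                rec["abstract"] = text
    return list(records.values())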
/source/Parser/input/61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc.PubTator.gner.PubTator: -------------------------------------------------------------------------------- 1 | 61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc|t| 2 | 61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc|a|severe aortic stenosis 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/6db4a18d11e3899c21b4cc11489cf3c8b457a40273c48ecd39ab4377.PubTator.biomedner.PubTator: -------------------------------------------------------------------------------- 1 | 6db4a18d11e3899c21b4cc11489cf3c8b457a40273c48ecd39ab4377|t| 2 | 6db4a18d11e3899c21b4cc11489cf3c8b457a40273c48ecd39ab4377|a|myocardial infarction 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3.PubTator.gner.PubTator: -------------------------------------------------------------------------------- 1 | 7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3|t| 2 | 7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3|a|severe aortic stenosis 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c.PubTator.gner.PubTator: -------------------------------------------------------------------------------- 1 | b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c|t| 2 | b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c|a|severe aortic stenosis 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453.PubTator.biomedner.PubTator: -------------------------------------------------------------------------------- 1 | 427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453|t| 2 | 427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453|a|severe aortic stenosis 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc.PubTator.biomedner.PubTator: -------------------------------------------------------------------------------- 1 | 61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc|t| 2 | 61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc|a|severe aortic stenosis 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3.PubTator.biomedner.PubTator: -------------------------------------------------------------------------------- 1 | 7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3|t| 2 | 7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3|a|severe aortic stenosis 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c.PubTator.biomedner.PubTator: -------------------------------------------------------------------------------- 1 | b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c|t| 2 | b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c|a|severe aortic stenosis 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42.PubTator.gner.PubTator: 
-------------------------------------------------------------------------------- 1 | 1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42|t| 2 | 1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42|a|clostridioides difficile infection 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42.PubTator.biomedner.PubTator: -------------------------------------------------------------------------------- 1 | 1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42|t| 2 | 1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42|a|clostridioides difficile infection 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28.PubTator.gner.PubTator: -------------------------------------------------------------------------------- 1 | 6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28|t| 2 | 6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28|a|Deficiency, alpha-Lecithin:Cholesterol Acyltransferase 3 | 4 | -------------------------------------------------------------------------------- /source/Parser/input/6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28.PubTator.biomedner.PubTator: -------------------------------------------------------------------------------- 1 | 6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28|t| 2 | 6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28|a|Deficiency, alpha-Lecithin:Cholesterol Acyltransferase 3 | 4 | -------------------------------------------------------------------------------- /utils/Indexer/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "elasticsearch": { 3 | "hosts": ["https://localhost:9200"], 4 | "ca_certs": "../../elasticsearch/certs/ca.crt", 5 | "username": "elastic", 6 | "password": "QQ7wWoB_WnKe*L*X9tAW", 7 | "request_timeout": 300, 8 | "retry_on_timeout": true, 9 | "max_retries": 3 10 | } 11 | } -------------------------------------------------------------------------------- /elasticsearch/certs/instances.yml: -------------------------------------------------------------------------------- 1 | instances: 2 | - name: es01 3 | dns: 4 | - es01 5 | - localhost 6 | ip: 7 | - 127.0.0.1 8 | - name: es02 9 | dns: 10 | - es02 11 | - localhost 12 | ip: 13 | - 127.0.0.1 14 | - name: es03 15 | dns: 16 | - es03 17 | - localhost 18 | ip: 19 | - 127.0.0.1 20 | -------------------------------------------------------------------------------- /source/Matcher/utils/logging_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | 5 | def setup_logging(): 6 | """Configure logging for the application.""" 7 | logging.basicConfig( 8 | level=logging.INFO, 9 | format="[%(asctime)s] %(levelname)s: %(message)s", 10 | datefmt="%H:%M:%S", 11 | handlers=[logging.StreamHandler(sys.stdout)], 12 | ) 13 | return logging.getLogger(__name__) 14 | -------------------------------------------------------------------------------- /source/Matcher/config/config_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import os 5 | from typing import Any, Dict 6 | 7 | 8 | def load_config(config_path: str = "Matcher/config/config.json") -> Dict[str, Any]: 9 | """Load configuration from a JSON file.""" 10 | if not 
os.path.exists(config_path): 11 | raise FileNotFoundError(f"Configuration file not found: {config_path}") 12 | with open(config_path, "r", encoding="utf-8") as f: 13 | return json.load(f) 14 | -------------------------------------------------------------------------------- /source/Parser/output/525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e.PubTator.biomedner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e", "entities": {"disease": [{"start": 0, "end": 18}], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": []}, "title": "", "abstract": "chronic hepatitis c", "prob": {"disease": [[{"start": 0, "end": 18}, 0.9806779623031616]], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": []}, "num_entities": 1} -------------------------------------------------------------------------------- /source/Parser/output/89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace.PubTator.biomedner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace", "entities": {"disease": [{"start": 0, "end": 14}], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": []}, "title": "", "abstract": "type 1 diabetes", "prob": {"disease": [[{"start": 0, "end": 14}, 0.9999966621398926]], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": []}, "num_entities": 1} -------------------------------------------------------------------------------- /source/Parser/output/611019227d9ea65a714df4e8c4498bfc13e21c67879025c3b03d8637.PubTator.biomedner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "611019227d9ea65a714df4e8c4498bfc13e21c67879025c3b03d8637", "entities": {"disease": [{"start": 0, "end": 20}], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": []}, "title": "", "abstract": "myocardial infarction", "prob": {"disease": [[{"start": 0, "end": 20}, 0.9999976754188538]], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": []}, "num_entities": 1} -------------------------------------------------------------------------------- /source/Parser/output/6db4a18d11e3899c21b4cc11489cf3c8b457a40273c48ecd39ab4377.PubTator.biomedner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "6db4a18d11e3899c21b4cc11489cf3c8b457a40273c48ecd39ab4377", "entities": {"disease": [{"start": 0, "end": 20}], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": []}, "title": "", "abstract": "myocardial infarction", "prob": {"disease": [[{"start": 0, "end": 20}, 0.9999976754188538]], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": []}, "num_entities": 1} -------------------------------------------------------------------------------- /source/Parser/output/89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace.PubTator.gner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "89e92a6963f2179038668173907fc2a79ae48e7712fe3dddfe49dace", "entities": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic 
analysis technique": []}, "title": "", "abstract": "type 1 diabetes", "prob": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "num_entities": 0} -------------------------------------------------------------------------------- /source/Parser/output/427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453.PubTator.gner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453", "entities": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "title": "", "abstract": "severe aortic stenosis", "prob": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "num_entities": 0} -------------------------------------------------------------------------------- /source/Parser/output/525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e.PubTator.gner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "525bbc5e481bdaf825bf80725e1ac63c15786fa13120fe734d703c0e", "entities": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "title": "", "abstract": "chronic hepatitis c", "prob": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "num_entities": 0} -------------------------------------------------------------------------------- /source/Parser/output/61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc.PubTator.gner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc", "entities": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "title": "", "abstract": "severe aortic stenosis", "prob": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "num_entities": 0} -------------------------------------------------------------------------------- /source/Parser/output/7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3.PubTator.gner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3", "entities": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "title": "", "abstract": "severe aortic stenosis", "prob": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "num_entities": 0} -------------------------------------------------------------------------------- /source/Parser/output/b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c.PubTator.gner.json: -------------------------------------------------------------------------------- 1 | {"pmid": 
"b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c", "entities": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "title": "", "abstract": "severe aortic stenosis", "prob": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "num_entities": 0} -------------------------------------------------------------------------------- /elasticsearch/config/es01/role_mapping.yml: -------------------------------------------------------------------------------- 1 | # Role mapping configuration file which has elasticsearch roles as keys 2 | # that map to one or more user or group distinguished names 3 | 4 | #roleA: this is an elasticsearch role 5 | # - groupA-DN this is a group distinguished name 6 | # - groupB-DN 7 | # - user1-DN this is the full user distinguished name 8 | 9 | #power_user: 10 | # - "cn=admins,dc=example,dc=com" 11 | #user: 12 | # - "cn=users,dc=example,dc=com" 13 | # - "cn=admins,dc=example,dc=com" 14 | # - "cn=John Doe,cn=other users,dc=example,dc=com" 15 | -------------------------------------------------------------------------------- /elasticsearch/config/es02/role_mapping.yml: -------------------------------------------------------------------------------- 1 | # Role mapping configuration file which has elasticsearch roles as keys 2 | # that map to one or more user or group distinguished names 3 | 4 | #roleA: this is an elasticsearch role 5 | # - groupA-DN this is a group distinguished name 6 | # - groupB-DN 7 | # - user1-DN this is the full user distinguished name 8 | 9 | #power_user: 10 | # - "cn=admins,dc=example,dc=com" 11 | #user: 12 | # - "cn=users,dc=example,dc=com" 13 | # - "cn=admins,dc=example,dc=com" 14 | # - "cn=John Doe,cn=other users,dc=example,dc=com" 15 | -------------------------------------------------------------------------------- /elasticsearch/config/es03/role_mapping.yml: -------------------------------------------------------------------------------- 1 | # Role mapping configuration file which has elasticsearch roles as keys 2 | # that map to one or more user or group distinguished names 3 | 4 | #roleA: this is an elasticsearch role 5 | # - groupA-DN this is a group distinguished name 6 | # - groupB-DN 7 | # - user1-DN this is the full user distinguished name 8 | 9 | #power_user: 10 | # - "cn=admins,dc=example,dc=com" 11 | #user: 12 | # - "cn=users,dc=example,dc=com" 13 | # - "cn=admins,dc=example,dc=com" 14 | # - "cn=John Doe,cn=other users,dc=example,dc=com" 15 | -------------------------------------------------------------------------------- /elasticsearch/tmp-config/role_mapping.yml: -------------------------------------------------------------------------------- 1 | # Role mapping configuration file which has elasticsearch roles as keys 2 | # that map to one or more user or group distinguished names 3 | 4 | #roleA: this is an elasticsearch role 5 | # - groupA-DN this is a group distinguished name 6 | # - groupB-DN 7 | # - user1-DN this is the full user distinguished name 8 | 9 | #power_user: 10 | # - "cn=admins,dc=example,dc=com" 11 | #user: 12 | # - "cn=users,dc=example,dc=com" 13 | # - "cn=admins,dc=example,dc=com" 14 | # - "cn=John Doe,cn=other users,dc=example,dc=com" 15 | -------------------------------------------------------------------------------- 
/source/Parser/output/1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42.PubTator.gner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42", "entities": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "title": "", "abstract": "clostridioides difficile infection", "prob": {"diagnostic test": [], "treatment": [], "laboratory test": [], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "num_entities": 0} -------------------------------------------------------------------------------- /source/Parser/output/427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453.PubTator.biomedner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "427736011957cbb0ce549f492b0330b9d77ba984a2beb7b2abca8453", "entities": {"disease": [{"start": 7, "end": 21}], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": [{"start": 0, "end": 21}]}, "title": "", "abstract": "severe aortic stenosis", "prob": {"disease": [[{"start": 7, "end": 21}, 0.9999975562095642]], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": [[{"start": 0, "end": 21}, 0.41163790225982666]]}, "num_entities": 2} -------------------------------------------------------------------------------- /source/Parser/output/61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc.PubTator.biomedner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "61d5392634e0658327dab1a020c904645ec31ef6fbb49b058ee86cdc", "entities": {"disease": [{"start": 7, "end": 21}], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": [{"start": 0, "end": 21}]}, "title": "", "abstract": "severe aortic stenosis", "prob": {"disease": [[{"start": 7, "end": 21}, 0.9999975562095642]], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": [[{"start": 0, "end": 21}, 0.41163790225982666]]}, "num_entities": 2} -------------------------------------------------------------------------------- /source/Parser/output/7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3.PubTator.biomedner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "7aed0d619cfb7fb7edf932fc5aeff01e489c3be8482e4c08c26f4de3", "entities": {"disease": [{"start": 7, "end": 21}], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": [{"start": 0, "end": 21}]}, "title": "", "abstract": "severe aortic stenosis", "prob": {"disease": [[{"start": 7, "end": 21}, 0.9999975562095642]], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": [[{"start": 0, "end": 21}, 0.41163790225982666]]}, "num_entities": 2} -------------------------------------------------------------------------------- /source/Parser/output/b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c.PubTator.biomedner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "b5d32e3a13ff4d519235aa93bf6faeaa8e80439b77a08925ae9c617c", "entities": {"disease": [{"start": 7, "end": 21}], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": [{"start": 0, "end": 
21}]}, "title": "", "abstract": "severe aortic stenosis", "prob": {"disease": [[{"start": 7, "end": 21}, 0.9999975562095642]], "drug": [], "gene": [], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": [[{"start": 0, "end": 21}, 0.41163790225982666]]}, "num_entities": 2} -------------------------------------------------------------------------------- /source/Parser/output/1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42.PubTator.biomedner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "1cd9760fe682423e9e3c37d70f3e932d45259c7c9d4b6e911fbf9b42", "entities": {"disease": [{"start": 0, "end": 33}], "drug": [], "gene": [], "species": [{"start": 0, "end": 23}], "cell line": [], "DNA": [], "RNA": [], "cell type": []}, "title": "", "abstract": "clostridioides difficile infection", "prob": {"disease": [[{"start": 0, "end": 33}, 0.9938603043556213]], "drug": [], "gene": [], "species": [[{"start": 0, "end": 23}, 0.5144878625869751]], "cell line": [], "DNA": [], "RNA": [], "cell type": []}, "num_entities": 2} -------------------------------------------------------------------------------- /source/Parser/output/6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28.PubTator.gner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28", "entities": {"diagnostic test": [], "treatment": [], "laboratory test": [{"start": 12, "end": 54}], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "title": "", "abstract": "Deficiency, alpha-Lecithin:Cholesterol Acyltransferase", "prob": {"diagnostic test": [], "treatment": [], "laboratory test": [[{"start": 12, "end": 54}, 0.8846812844276428]], "surgical procedure": [], "sign symptom": [], "radiology": [], "genomic analysis technique": []}, "num_entities": 1} -------------------------------------------------------------------------------- /source/Parser/scripts/stop_biomedner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define a function to stop a process by its name 4 | stop_process() { 5 | process_name="$1" 6 | pid=$(ps auxww | grep "$process_name" | grep -v grep | awk '{print $2}' | sort -r) 7 | if [ "$pid" != "" ]; then 8 | # Kill each PID one by one 9 | for p in $pid; do 10 | kill -9 "$p" 11 | echo "Stopped $process_name (PID: $p)" 12 | done 13 | else 14 | echo "No $process_name found to stop." 15 | fi 16 | } 17 | 18 | # Call the function for each process 19 | stop_process "biomedner_server.py" 20 | stop_process "disease_normalizer_21.jar" 21 | stop_process "gnormplus-normalization_21.jar" 22 | stop_process "gner_server.py" 23 | 24 | -------------------------------------------------------------------------------- /source/biomedner_services/stop_biomedner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define a function to stop a process by its name 4 | stop_process() { 5 | process_name="$1" 6 | pid=$(ps auxww | grep "$process_name" | grep -v grep | awk '{print $2}' | sort -r) 7 | if [ "$pid" != "" ]; then 8 | # Kill each PID one by one 9 | for p in $pid; do 10 | kill -9 "$p" 11 | echo "Stopped $process_name (PID: $p)" 12 | done 13 | else 14 | echo "No $process_name found to stop." 
15 | fi 16 | } 17 | 18 | # Call the function for each process 19 | stop_process "biomedner_server.py" 20 | stop_process "disease_normalizer_21.jar" 21 | stop_process "gnormplus-normalization_21.jar" 22 | stop_process "gner_server.py" 23 | 24 | -------------------------------------------------------------------------------- /source/Parser/output/6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28.PubTator.biomedner.json: -------------------------------------------------------------------------------- 1 | {"pmid": "6d60d5573378fb2ca71a90099fe304ec4d5532cafebec4676f42fe28", "entities": {"disease": [{"start": 0, "end": 25}], "drug": [{"start": 12, "end": 25}, {"start": 27, "end": 37}], "gene": [{"start": 12, "end": 25}], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": [{"start": 0, "end": 9}]}, "title": "", "abstract": "Deficiency, alpha-Lecithin:Cholesterol Acyltransferase", "prob": {"disease": [[{"start": 0, "end": 25}, 0.9047816395759583]], "drug": [[{"start": 12, "end": 25}, 0.9823459982872009], [{"start": 27, "end": 37}, 0.9986056089401245]], "gene": [[{"start": 12, "end": 25}, 0.8241974711418152]], "species": [], "cell line": [], "DNA": [], "RNA": [], "cell type": [[{"start": 0, "end": 9}, 0.3515378534793854]]}, "num_entities": 5} -------------------------------------------------------------------------------- /utils/finetuning/finetune_instruct/split_train_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from sklearn.model_selection import train_test_split 4 | 5 | # File paths 6 | input_file = "medical_o1_reasoning.jsonl" 7 | train_file = "medical_o1_reasoning_train.jsonl" 8 | test_file = "medical_o1_reasoning_test.jsonl" 9 | 10 | # Load data 11 | with open(input_file, "r") as file: 12 | data = [json.loads(line) for line in file] 13 | 14 | # Split data 15 | train_data, test_data = train_test_split(data, test_size=0.1, random_state=42) 16 | 17 | # Save to JSONL 18 | with open(train_file, "w") as file: 19 | for item in train_data: 20 | file.write(json.dumps(item) + "\n") 21 | 22 | with open(test_file, "w") as file: 23 | for item in test_data: 24 | file.write(json.dumps(item) + "\n") 25 | 26 | print(f"Train set size: {len(train_data)}") 27 | print(f"Test set size: {len(test_data)}") 28 | -------------------------------------------------------------------------------- /elasticsearch/.env: -------------------------------------------------------------------------------- 1 | # Password for the 'elastic' user (at least 6 characters) 2 | ELASTIC_PASSWORD="QQ7wWoB_WnKe*L*X9tAW" 3 | 4 | # Password for the 'kibana_system' user (at least 6 characters) 5 | KIBANA_PASSWORD="QQ7wWoB_WnKe*L*X9tAW" 6 | 7 | # Version of Elastic products 8 | STACK_VERSION=8.13.4 9 | 10 | # Set the cluster name 11 | CLUSTER_NAME=docker-cluster 12 | 13 | # Set to 'basic' or 'trial' to automatically start the 30-day trial 14 | LICENSE=basic 15 | #LICENSE=trial 16 | 17 | # Port to expose Elasticsearch HTTP API to the host 18 | ES_PORT=9200 19 | #ES_PORT=127.0.0.1:9200 20 | 21 | # Port to expose Kibana to the host 22 | KIBANA_PORT=5601 23 | #KIBANA_PORT=80 24 | 25 | # Increase or decrease based on the available host memory (in bytes) 26 | MEM_LIMIT=1073741824 27 | 28 | # Project namespace (defaults to the current folder name if not set) 29 | #COMPOSE_PROJECT_NAME=trialmatchai 30 | -------------------------------------------------------------------------------- /source/Parser/normalizers/cellline_normalizer.py: 
-------------------------------------------------------------------------------- 1 | class CellLineNormalizer(object): 2 | def __init__(self, dict_path): 3 | self.NO_ENTITY_ID = "CUI-less" 4 | 5 | # Create dictionary for exact match 6 | self.cl2oid = dict() 7 | with open(dict_path, "r", encoding="utf-8") as f: 8 | for line in f: 9 | oid, names = line[:-1].split("||") 10 | names = names.split("|") 11 | for name in names: 12 | self.cl2oid[name] = oid 13 | 14 | def normalize(self, names): 15 | oids = list() 16 | for name in names: 17 | if name in self.cl2oid: 18 | oids.append(self.cl2oid[name]) 19 | elif name.lower() in self.cl2oid: 20 | oids.append(self.cl2oid[name.lower()]) 21 | else: 22 | oids.append(self.NO_ENTITY_ID) 23 | 24 | return oids 25 | -------------------------------------------------------------------------------- /elasticsearch/config/es01/elasticsearch.yml: -------------------------------------------------------------------------------- 1 | node.name: es01 2 | path.data: /usr/share/elasticsearch/data 3 | path.logs: /usr/share/elasticsearch/logs 4 | network.host: 127.0.0.1 5 | cluster.name: docker-cluster 6 | discovery.seed_hosts: ["127.0.0.1:9300", "127.0.0.1:9301", "127.0.0.1:9302"] 7 | cluster.initial_master_nodes: ["es01", "es02", "es03"] 8 | bootstrap.memory_lock: false 9 | xpack.security.enabled: true 10 | xpack.security.http.ssl.enabled: true 11 | xpack.security.http.ssl.key: es01.key 12 | xpack.security.http.ssl.certificate: es01.crt 13 | xpack.security.http.ssl.certificate_authorities: ["ca.crt"] 14 | xpack.security.transport.ssl.enabled: true 15 | xpack.security.transport.ssl.verification_mode: certificate 16 | xpack.security.transport.ssl.key: es01.key 17 | xpack.security.transport.ssl.certificate: es01.crt 18 | xpack.security.transport.ssl.certificate_authorities: ["ca.crt"] 19 | xpack.license.self_generated.type: basic 20 | -------------------------------------------------------------------------------- /elasticsearch/config/es02/elasticsearch.yml: -------------------------------------------------------------------------------- 1 | node.name: es02 2 | path.data: /usr/share/elasticsearch/data 3 | path.logs: /usr/share/elasticsearch/logs 4 | network.host: 127.0.0.1 5 | cluster.name: docker-cluster 6 | discovery.seed_hosts: ["127.0.0.1:9300", "127.0.0.1:9301", "127.0.0.1:9302"] 7 | cluster.initial_master_nodes: ["es01", "es02", "es03"] 8 | bootstrap.memory_lock: false 9 | xpack.security.enabled: true 10 | xpack.security.http.ssl.enabled: true 11 | xpack.security.http.ssl.key: es02.key 12 | xpack.security.http.ssl.certificate: es02.crt 13 | xpack.security.http.ssl.certificate_authorities: ["ca.crt"] 14 | xpack.security.transport.ssl.enabled: true 15 | xpack.security.transport.ssl.verification_mode: certificate 16 | xpack.security.transport.ssl.key: es02.key 17 | xpack.security.transport.ssl.certificate: es02.crt 18 | xpack.security.transport.ssl.certificate_authorities: ["ca.crt"] 19 | xpack.license.self_generated.type: basic 20 | -------------------------------------------------------------------------------- /elasticsearch/config/es03/elasticsearch.yml: -------------------------------------------------------------------------------- 1 | node.name: es03 2 | path.data: /usr/share/elasticsearch/data 3 | path.logs: /usr/share/elasticsearch/logs 4 | network.host: 127.0.0.1 5 | cluster.name: docker-cluster 6 | discovery.seed_hosts: ["127.0.0.1:9300", "127.0.0.1:9301", "127.0.0.1:9302"] 7 | cluster.initial_master_nodes: ["es01", "es02", "es03"] 8 | bootstrap.memory_lock: 
false 9 | xpack.security.enabled: true 10 | xpack.security.http.ssl.enabled: true 11 | xpack.security.http.ssl.key: es03.key 12 | xpack.security.http.ssl.certificate: es03.crt 13 | xpack.security.http.ssl.certificate_authorities: ["ca.crt"] 14 | xpack.security.transport.ssl.enabled: true 15 | xpack.security.transport.ssl.verification_mode: certificate 16 | xpack.security.transport.ssl.key: es03.key 17 | xpack.security.transport.ssl.certificate: es03.crt 18 | xpack.security.transport.ssl.certificate_authorities: ["ca.crt"] 19 | xpack.license.self_generated.type: basic 20 | -------------------------------------------------------------------------------- /utils/finetuning/finetune_instruct/finetune.sh: -------------------------------------------------------------------------------- 1 | nohup torchrun --nproc_per_node 1 ./run.py \ 2 | --output_dir ./finetuned_phi_reasoning \ 3 | --model_name_or_path microsoft/phi-4 \ 4 | --train_data ./finetuning_data/medical_o1_reasoning_train.jsonl\ 5 | --learning_rate 5e-5 \ 6 | --num_train_epochs 2 \ 7 | --per_device_train_batch_size 3 \ 8 | --gradient_accumulation_steps 16 \ 9 | --dataloader_drop_last True \ 10 | --query_max_len 1024 \ 11 | --passage_max_len 1024 \ 12 | --logging_steps 10 \ 13 | --save_steps 1000 \ 14 | --save_total_limit 5 \ 15 | --ddp_find_unused_parameters False \ 16 | --warmup_ratio 0.1 \ 17 | --use_lora True \ 18 | --lora_rank 32 \ 19 | --lora_alpha 64 \ 20 | --lora_dropout 0.1 \ 21 | --use_flash_attn True \ 22 | --max_example_num_per_dataset 26000 \ 23 | --cache_dir scratch/huggingface_cache/hub \ 24 | --target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \ 25 | --bf16 > ./finetune_log.log 2>&1 & disown -------------------------------------------------------------------------------- /source/Parser/normalizers/species_normalizer.py: -------------------------------------------------------------------------------- 1 | class SpeciesNormalizer(object): 2 | def __init__(self, dict_path): 3 | self.NO_ENTITY_ID = "CUI-less" 4 | 5 | # Create dictionary for exact match 6 | self.species2oid = dict() 7 | with open(dict_path, "r", encoding="utf-8") as f: 8 | for line in f: 9 | oid, names = line[:-1].split("||") 10 | names = names.split("|") 11 | for name in names: 12 | # a part of tmChem normalization 13 | self.species2oid[name] = oid 14 | 15 | def normalize(self, names): 16 | oids = list() 17 | for name in names: 18 | if name in self.species2oid: 19 | oids.append(self.species2oid[name]) 20 | elif name.lower() in self.species2oid: 21 | oids.append(self.species2oid[name.lower()]) 22 | else: 23 | oids.append(self.NO_ENTITY_ID) 24 | 25 | return oids 26 | -------------------------------------------------------------------------------- /utils/Indexer/zipper.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -euo pipefail 3 | IFS=$'\n\t' 4 | 5 | SRC_DIR="processed_criteria" 6 | CHUNKS=6 7 | CHUNK_PREFIX="criteria_part" 8 | 9 | # Create a working directory for chunks 10 | mkdir -p zip_chunks 11 | 12 | # Step 1: List top-level folders and count them 13 | folders=($(find "$SRC_DIR" -mindepth 1 -maxdepth 1 -type d)) 14 | total=${#folders[@]} 15 | per_chunk=$(( (total + CHUNKS - 1) / CHUNKS )) # Round up 16 | 17 | # Step 2: Divide folders into 6 chunks and zip each 18 | info() { echo -e "\033[0;32m[INFO]\033[0m $*"; } 19 | 20 | info "Total subfolders: $total" 21 | info "Creating $CHUNKS zip chunks, ~${per_chunk} folders each..." 
22 | 23 | for ((i=0; i<CHUNKS; i++)); do 24 | start=$(( i * per_chunk )) 25 | chunk=("${folders[@]:start:per_chunk}") 26 | (( ${#chunk[@]} == 0 )) && continue 27 | zip_name="zip_chunks/${CHUNK_PREFIX}_$(( i + 1 )).zip" 28 | info "Zipping ${#chunk[@]} folders into $zip_name" 29 | zip -rq "$zip_name" "${chunk[@]}" 30 | done -------------------------------------------------------------------------------- /source/biomedner_services/run_biomedner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ! -d "logs" ]; then 4 | mkdir logs 5 | fi 6 | #################################### 7 | ##### NER ##### 8 | #################################### 9 | 10 | # run neural NER 11 | nohup python biomedner_server.py \ 12 | --model_name_or_path models/finetuned_model_roberta \ 13 | --biomedner_port 18894 >> logs/nohup_multi_ner.out 2>&1 & 14 | 15 | nohup python gner_server.py \ 16 | --model_name_or_path gliner-community/gliner_large-v2.5 \ 17 | --gner_port 18783 >> logs/nohup_gner.out 2>&1 & 18 | 19 | #################################### 20 | ##### Normalization ##### 21 | #################################### 22 | cd resources 23 | # Disease (working dir: normalization/) 24 | cd normalization 25 | nohup java -Xmx16G -jar normalizers/disease/disease_normalizer_21.jar \ 26 | "inputs/disease" \ 27 | "outputs/disease" \ 28 | "dictionary/dict_Disease.txt" \ 29 | "normalizers/disease/resources" \ 30 | 9 \ 31 | 18892 \ 32 | >> ../../logs/nohup_disease_normalize.out 2>&1 & 33 | 34 | # Gene (working dir: normalization/normalizers/gene/, port:18888) 35 | cd normalizers/gene 36 | nohup java -Xmx20G -jar gnormplus-normalization_21.jar \ 37 | 18888 \ 38 | >> ../../../../logs/nohup_gene_normalize.out 2>&1 & 39 | cd ../../../../.. -------------------------------------------------------------------------------- /source/Parser/scripts/run_biomedner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd .. 3 | #!/bin/bash 4 | 5 | if [ ! -d "logs" ]; then 6 | mkdir logs 7 | fi 8 | #################################### 9 | ##### NER ##### 10 | #################################### 11 | 12 | # run neural NER 13 | nohup python biomedner_server.py \ 14 | --model_name_or_path models/finetuned_model_roberta \ 15 | --biomedner_port 18894 >> logs/nohup_multi_ner.out 2>&1 & 16 | 17 | nohup python gner_server.py \ 18 | --model_name_or_path gliner-community/gliner_large-v2.5 \ 19 | --gner_port 18783 >> logs/nohup_gner.out 2>&1 & 20 | 21 | #################################### 22 | ##### Normalization ##### 23 | #################################### 24 | cd resources 25 | # Disease (working dir: normalization/) 26 | cd normalization 27 | nohup java -Xmx16G -jar normalizers/disease/disease_normalizer_21.jar \ 28 | "inputs/disease" \ 29 | "outputs/disease" \ 30 | "dictionary/dict_Disease.txt" \ 31 | "normalizers/disease/resources" \ 32 | 9 \ 33 | 18892 \ 34 | >> ../../logs/nohup_disease_normalize.out 2>&1 & 35 | 36 | # Gene (working dir: normalization/normalizers/gene/, port:18888) 37 | cd normalizers/gene 38 | nohup java -Xmx20G -jar gnormplus-normalization_21.jar \ 39 | 18888 \ 40 | >> ../../../../logs/nohup_gene_normalize.out 2>&1 & 41 | cd ../../../..
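The two run scripts above launch four background services: neural NER on ports 18894 (biomedner_server.py) and 18783 (gner_server.py), plus disease and gene normalizers on 18892 and 18888. Since nohup returns immediately, a caller may want to poll those ports before sending requests; a small sketch under that assumption (the helper and its 120-second timeout are illustrative; Matcher/services/biomedner_service.py below performs the inverse check with the same connect_ex call):

import socket
import time

def wait_for_ports(ports, host="127.0.0.1", timeout=120.0):
    """Block until every TCP port accepts a connection, or raise TimeoutError."""
    deadline = time.monotonic() + timeout
    pending = set(ports)
    while pending:
        for port in sorted(pending):
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.settimeout(1.0)
                if s.connect_ex((host, port)) == 0:
                    pending.discard(port)
        if pending:
            if time.monotonic() > deadline:
                raise TimeoutError(f"services not listening on ports {sorted(pending)}")
            time.sleep(2.0)

# e.g. wait_for_ports([18894, 18783, 18892, 18888])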
-------------------------------------------------------------------------------- /source/Matcher/services/biomedner_service.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import subprocess 3 | import time 4 | 5 | from Matcher.utils.logging_config import setup_logging 6 | 7 | logger = setup_logging() 8 | 9 | 10 | def is_port_in_use(port: int, host: str = "127.0.0.1") -> bool: 11 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 12 | return s.connect_ex((host, port)) == 0 13 | 14 | 15 | def check_ports_in_use(ports: list) -> bool: 16 | return any(is_port_in_use(port) for port in ports) 17 | 18 | 19 | def run_script(script_path: str): 20 | logger.info(f"Executing: {script_path}") 21 | subprocess.run(["bash", script_path], check=True) 22 | 23 | 24 | def initialize_biomedner_services(config: dict): 25 | ports_to_check = [ 26 | config["bio_med_ner"]["biomedner_port"], 27 | config["bio_med_ner"]["gner_port"], 28 | config["bio_med_ner"]["gene_norm_port"], 29 | config["bio_med_ner"]["disease_norm_port"], 30 | ] 31 | if check_ports_in_use(ports_to_check): 32 | logger.info("Detected active services. Stopping running instances...") 33 | run_script(config["services"]["stop_script"]) 34 | logger.info("Waiting for 10 seconds before restarting...") 35 | time.sleep(10) 36 | logger.info("Starting BioMedNER services...") 37 | run_script(config["services"]["run_script"]) 38 | logger.info("BioMedNER services started successfully.") 39 | -------------------------------------------------------------------------------- /source/Matcher/models/embedding/query_embedder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | from transformers import AutoModel, AutoTokenizer 5 | 6 | 7 | class QueryEmbedder: 8 | def __init__( 9 | self, 10 | model_name: str = "ncbi/MedCPT-Query-Encoder", 11 | max_length: int = 512, 12 | use_gpu: bool = True, 13 | ): 14 | self.device = torch.device( 15 | "cuda" if use_gpu and torch.cuda.is_available() else "cpu" 16 | ) 17 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 18 | self.model = AutoModel.from_pretrained(model_name).to(self.device) 19 | self.max_length = max_length 20 | self.model.eval() 21 | 22 | def get_embeddings(self, texts: List[str]) -> Dict[str, List[float]]: 23 | embeddings_dict = {} 24 | for text in texts: 25 | encoded = self.tokenizer( 26 | text, 27 | truncation=True, 28 | padding=True, 29 | return_tensors="pt", 30 | max_length=self.max_length, 31 | ).to(self.device) 32 | with torch.no_grad(): 33 | embeddings = self.model(**encoded).last_hidden_state[:, 0, :] 34 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) 35 | embeddings_dict[text] = embeddings.flatten().tolist() 36 | return embeddings_dict 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 4 | Majd Abdallah (1,2) (majd.abdallah@u-bordeaux.fr ) 5 | Macha Nikolski (1,2) (macha.nikolski@u-bordeaux.fr) 6 | 7 | (1) CBiB - University of Bordeaux, 8 | 146, rue Leo Saignat, 33076 Bordeaux, France 9 | 10 | (2) CNRS, IBGC - University of Bordeaux, 11 | 1, rue Camille Saint-Saens, 33077 Bordeaux, France 12 | 13 | 14 | Permission is hereby granted, free of charge, to any person obtaining a copy 15 | of this software and associated documentation files (the "Software"), 
to deal 16 | in the Software without restriction, including without limitation the rights 17 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 18 | copies of the Software, and to permit persons to whom the Software is 19 | furnished to do so, subject to the following conditions: 20 | 21 | The above copyright notice and this permission notice shall be included in all 22 | copies or substantial portions of the Software. 23 | 24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 25 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 26 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 27 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 28 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 29 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 30 | SOFTWARE. -------------------------------------------------------------------------------- /source/Parser/normalizers/celltype_normalizer.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | import re 3 | 4 | # Load the English pipeline once, disabling components not needed for lemmatization 5 | nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"]) 6 | 7 | 8 | class CellTypeNormalizer(object): 9 | def __init__(self, dict_path): 10 | self.NO_ENTITY_ID = "CUI-less" 11 | 12 | # Create dictionary for exact match 13 | self.ct2oid = dict() 14 | with open(dict_path, "r", encoding="utf-8") as f: 15 | for line in f: 16 | oid, names = line.strip().split("||") 17 | names = names.split("|") 18 | for name in names: 19 | normalized_name = self.get_tmchem_name(name) 20 | self.ct2oid[normalized_name] = oid 21 | 22 | def normalize(self, names): 23 | oids = [] 24 | for name in names: 25 | normalized_name = self.get_tmchem_name(name) 26 | if normalized_name in self.ct2oid: 27 | oids.append(self.ct2oid[normalized_name]) 28 | else: 29 | oids.append(self.NO_ENTITY_ID) 30 | return oids 31 | 32 | def get_tmchem_name(self, name): 33 | # Lowercase, strip punctuation (hyphens are kept), and remove spaces 34 | cleaned_name = re.sub(r"[^\w\s-]", "", name.lower()).replace(" ", "") 35 | # Use SpaCy to lemmatize the cleaned name 36 | doc = nlp(cleaned_name) 37 | lemmatized_name = "".join([token.lemma_ for token in doc]) 38 | return lemmatized_name 39 | -------------------------------------------------------------------------------- /utils/Indexer/flatten.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | def flatten_folder_structure( 6 | root_dir, output_dir, separator="_", handle_duplicates=True 7 | ): 8 | os.makedirs(output_dir, exist_ok=True) 9 | filename_set = set() 10 | 11 | for subdir, dirs, files in os.walk(root_dir): 12 | # Skip the root directory itself 13 | if subdir == root_dir: 14 | continue 15 | 16 | parent_folder = os.path.basename(subdir) 17 | 18 | for file in files: 19 | old_path = os.path.join(subdir, file) 20 | new_filename = f"{parent_folder}{separator}{file}" 21 | new_path = os.path.join(output_dir, new_filename) 22 | 23 | if handle_duplicates: 24 | base, ext = os.path.splitext(new_filename) 25 | counter = 1 26 | while new_path in filename_set or os.path.exists(new_path): 27 | new_filename = f"{base}_{counter}{ext}" 28 | new_path = os.path.join(output_dir, new_filename) 29 | counter += 1 30 | 31 | filename_set.add(new_path) 32 | shutil.copy2(old_path, new_path) 33 |
print(f"Copied: {old_path} → {new_path}") 34 | 35 | 36 | # Example usage: 37 | root_directory = "/home/mabdallah/scratch/TrialMatchAI/src/Indexer/processed_criteria" 38 | output_directory = ( 39 | "/home/mabdallah/scratch/TrialMatchAI/src/Indexer/processed_criteria_flattened" 40 | ) 41 | flatten_folder_structure(root_directory, output_directory) 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ============================ 2 | # General dataset / output dirs 3 | # ============================ 4 | data/ 5 | results/ 6 | src/ 7 | ablation/ 8 | logs/ 9 | *.log 10 | 11 | # ============================ 12 | # Finetuning ignore rules 13 | # ============================ 14 | 15 | # Ignore ANY directory under finetuning/ starting with checkpoint 16 | **/finetuning/**/checkpoint*/ 17 | 18 | # Ignore ANY directory under finetuning/ starting with finetuned 19 | **/finetuning/**/finetuned*/ 20 | 21 | # Ignore ANY directory under finetuning/ ending with "data" 22 | **/finetuning/**/*data/ 23 | 24 | # Ignore this explicit model folder 25 | utils/finetuning/finetune_ner/RoBERTa-large-PM-M3-Voc/ 26 | utils/finetuning/finetune_ner/output_eval 27 | utils/finetuning/finetune_ner 28 | 29 | # ============================ 30 | # ElasticSearch artifacts 31 | # ============================ 32 | elasticsearch/sif/ 33 | elasticsearch/data1/ 34 | elasticsearch/data/ 35 | elasticsearch/logs/ 36 | elasticsearch/sif/*.sif 37 | 38 | # ============================ 39 | # Parser artifacts 40 | # ============================ 41 | source/models/ 42 | source/Parser/resources/ 43 | source/Parser/models/ 44 | source/Parser/input/ 45 | source/Parser/output/ 46 | Parser/logs/ 47 | 48 | # ============================ 49 | # Python cache 50 | # ============================ 51 | __pycache__/ 52 | *.py[cod] 53 | 54 | # ============================ 55 | # Temporary / cache files 56 | # ============================ 57 | *.tmp 58 | *.bak 59 | *.swp 60 | *.slurm 61 | **/tmp/ 62 | .DS_Store 63 | Thumbs.db 64 | -------------------------------------------------------------------------------- /source/Matcher/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Dict, List 4 | 5 | 6 | def read_json_file(file_path: str) -> Dict: 7 | """Read a JSON file and return its contents.""" 8 | try: 9 | with open(file_path, "r", encoding="utf-8") as f: 10 | return json.load(f) 11 | except Exception as e: 12 | raise ValueError(f"Failed to read {file_path}: {str(e)}") 13 | 14 | 15 | def write_json_file(data: Dict, file_path: str): 16 | """Write data to a JSON file.""" 17 | try: 18 | with open(file_path, "w", encoding="utf-8") as f: 19 | json.dump(data, f, indent=4, ensure_ascii=False) 20 | except Exception as e: 21 | raise ValueError(f"Failed to write {file_path}: {str(e)}") 22 | 23 | 24 | def read_text_file(file_path: str) -> List[str]: 25 | """Read lines from a text file.""" 26 | try: 27 | with open(file_path, "r", encoding="utf-8") as f: 28 | return [line.strip() for line in f if line.strip()] 29 | except Exception as e: 30 | raise ValueError(f"Failed to read {file_path}: {str(e)}") 31 | 32 | 33 | def write_text_file(lines: List[str], file_path: str): 34 | """Write lines to a text file.""" 35 | try: 36 | with open(file_path, "w", encoding="utf-8") as f: 37 | f.write("\n".join(lines)) 38 | except Exception as e: 39 | 
raise ValueError(f"Failed to write {file_path}: {str(e)}") 40 | 41 | 42 | def create_directory(path: str): 43 | """Create a directory if it doesn't exist.""" 44 | os.makedirs(path, exist_ok=True) 45 | -------------------------------------------------------------------------------- /source/Matcher/models/embedding/sentence_embedder.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from transformers import AutoModel, AutoTokenizer 6 | 7 | 8 | class SecondLevelSentenceEmbedder: 9 | def __init__(self, model_name: str = "BAAI/bge-m3", use_gpu: bool = True): 10 | self.device = torch.device( 11 | "cuda" if use_gpu and torch.cuda.is_available() else "cpu" 12 | ) 13 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 14 | self.model = AutoModel.from_pretrained(model_name).to(self.device) 15 | self.model.eval() 16 | 17 | @staticmethod 18 | def mean_pooling( 19 | model_output: torch.Tensor, attention_mask: torch.Tensor 20 | ) -> torch.Tensor: 21 | token_embeddings = model_output[0] 22 | input_mask_expanded = ( 23 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 24 | ) 25 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( 26 | input_mask_expanded.sum(1), min=1e-9 27 | ) 28 | 29 | def get_embeddings(self, sentence: str) -> List[float]: 30 | encoded_input = self.tokenizer( 31 | [sentence], padding=True, truncation=True, return_tensors="pt" 32 | ).to(self.device) 33 | with torch.no_grad(): 34 | model_output = self.model(**encoded_input) 35 | sentence_embeddings = self.mean_pooling( 36 | model_output, encoded_input["attention_mask"] 37 | ) 38 | sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) 39 | return sentence_embeddings[0].tolist() 40 | -------------------------------------------------------------------------------- /source/Parser/normalizers/chemical_normalizer.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import nltk 4 | from nltk.stem import WordNetLemmatizer 5 | 6 | # Download NLTK resources quietly 7 | nltk.download("wordnet", quiet=True) 8 | nltk.download("omw-1.4", quiet=True) 9 | 10 | 11 | class ChemicalNormalizer(object): 12 | def __init__(self, dict_path): 13 | self.NO_ENTITY_ID = "CUI-less" 14 | self.lemmatizer = WordNetLemmatizer() 15 | 16 | # Create dictionary for exact match 17 | self.chem2oid = dict() 18 | with open(dict_path, "r", encoding="utf-8") as f: 19 | for line in f: 20 | oid, names = line[:-1].split("||") 21 | names = names.split("|") 22 | for name in names: 23 | # a part of tmChem normalization 24 | normalized_name = self.get_tmchem_name(name) 25 | self.chem2oid[normalized_name] = oid 26 | 27 | def normalize(self, names): 28 | oids = list() 29 | for name in names: 30 | # a part of tmChem normalization 31 | normalized_name = self.get_tmchem_name(name) 32 | 33 | if normalized_name in self.chem2oid: 34 | oids.append(self.chem2oid[normalized_name]) 35 | else: 36 | oids.append(self.NO_ENTITY_ID) 37 | 38 | return oids 39 | 40 | def get_tmchem_name(self, name): 41 | # 1. lowercase, 2. 
removes all whitespace and punctuation 42 | # https://jcheminf.biomedcentral.com/articles/10.1186/1758-2946-7-S1-S3 43 | cleaned_name = re.sub(r"[^\w\s]", "", name.lower()).replace(" ", "") 44 | lemmatized_name = self.lemmatizer.lemmatize(cleaned_name) 45 | return lemmatized_name 46 | -------------------------------------------------------------------------------- /elasticsearch/certs/ca/ca.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpAIBAAKCAQEAmreLYETx2FrGa8oAlDQEnHHOMBjjXDc0ZJ+y0Wxfo1lXMdxv 3 | EiytbmOcNI3QvpoqEXvLyFUdWpp4NSrSKAA8WmleN6wjUk3zzug7Az7vQBnaMZR6 4 | 1I5KX+Zj3MtYYUaVXx4tNcrakoqf3XgPto0oatp2jLoOKOmKKnT3Ylbzq8zL7lXR 5 | nhjTa5YoUx5hDbSOr2adVKv6cWWKh4apM5Qg67mb4Xumy6b228hlacenSQOdTLsW 6 | epushLc9Xzc9n4IftrZfdYV6e1mnjA6Rh2ns6nAm4h0jBvb14VPH7IE0wdjTjYho 7 | h/+t+VWXE220vXbQ3lHuGP7xGNLlEVNncgRdywIDAQABAoIBADTHPSv7iPbRzJNT 8 | pwvnjNUje39r1g+Mo3paAiGv0xZBsV2IgXlVNVqNb2l8IUQMCiLJtNQjuO5B+JTG 9 | hUdxASWkgSgDuE7o2a1xCkSKsQoQZ573NEmTOqrpSJK26XDRp735aNnLV/GaiXt+ 10 | 6/lNwQZmfP93rsHlHSVrnkJ1QA2QOJXwLiBl6FYDtSX/FMlIvDdEFQJw0ueizUpq 11 | jdjyI0yBDhCBtYbiiPaqqn/JdDLmJh6g5fOQrhZ7ugV28Ko8xiHN8n/X4Dgfil8m 12 | qOizYuwEfTuoTf5b0pB3c9MKiug2URinVIvlSSTmYU5cmMDwUI2OfKABwKxgrJ6g 13 | QcgdJUUCgYEAtxhP7R0CkXTziriw4WH5asqPeQsDaokfo2AGv+OgX2TzR9rYofWe 14 | gA+yv/JtV8N0We7lWEU3kwjMrjqPo+FVnMkJB9Ns+EpArYt6VI5jXTc9wv9+Ubg3 15 | bmqMZy4vK/84bukTJmP1kEkBC6uKxTa6b5agZ91ylmzp1EihgtoblRcCgYEA2FKJ 16 | O41p0Tg7DXskOMTalYmV4tuIPnbCQnL3zdYKkxpvP/DEsgFs2WuznTR7CA8UcQpn 17 | dZf7INPvXNrSAFzsQEFUgOMaeL6whzBD/X8uDzqO0ACJAgAT1ixsPNKD9dipOsjJ 18 | XL0JIu/WDnr+RewY/oYmxGBP/rJ3vh90D0QtFW0CgYA7HEhBfsojd6Rgtru0J9NE 19 | HN0w8NNLg7WJIylKrgxKf+bi3c5ui0N+iJLm0TdnzBw2JKA6XS5R89dQsGtLNyZS 20 | lbyqoCFgD2jOHmeCAO4nW/w+hgmcDGMo7JEjho+IHr9zXx/llwPibw89W0ZT4RVT 21 | jUeAAMhLtCHRfRlXi164RQKBgQCxXrt3QPRKwDhrTVTd10sC4dUsNaT40pdltK7K 22 | 732sMDiXzOr6qYB+pXiYpbdbXEH+jfFW0k7vE9dn42PdOFPeO7L2G0BRUMmdj02w 23 | RN+XTQRcOJQeN8IMElCNbm8U3ZmkwY2ZpMYhB1YUeXSUEIpN5+FRk0cEJ4FXOqMH 24 | AXeV4QKBgQCYKWjq4XRLege8h+Cvyi55Kaj12nDc90Oorywx75GrxixEWRbye4aF 25 | qCZg0D+C474MG2C/n58arGuT7jwgX7LDvZVDvHjNcAsLP8JmrvXgUMMHximjdVmP 26 | XokfEI8o2GjsCf0pbgp51jSDV07yBjRh3kxP70J20MC3jLpk+uWt4g== 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /elasticsearch/certs/es01/es01.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEogIBAAKCAQEAmpXRQ+THn+VpNZwD+7+9R7IOhisKjCEGKMqSARr5wQPlulxL 3 | DlHoO7pRLw6JcEXmLn4tx+QuLkkm9pFAADYnIcCljzRkFHADKuo+P7BG6AZ+6l1U 4 | VHwya6iUaERda3+VClTHPzRoPn8TRdX78+cY7dAzme0noosw8WCgpL2HgiYciFiw 5 | le80OKaeiVxvRNJ+6n36a/q4lgLx9wQoJvox/2OCP3318UY1SgA2AO9kFk0zDTvJ 6 | ji9ESSz6d+6LCB73Bk2cKWx6q2pVoSTNPhkC5mD9E/zhuXSACfY3H85GZksJ4jzt 7 | dF1UcsOsrjYsMwwxHH4BwVKak97ERJmICSLV/wIDAQABAoIBABDs2Vo1KNQtcoz/ 8 | lYIRVsCMUsHG4aNBFGMP9tdvJCxJaHQ0mbUqK6KqfiwIS+0CgjbR8uJBbfr8YGs7 9 | sQW06CjuZlIdGt4P+5DNz936R3EtEOVJLawIYx7deM5HufDEqcVVTfFyI/2/vRT2 10 | 3lywj06ubo/qYt4NnmC3Qy92XulVZknvjaolGhqVlYGbHQITSI47KNErFX0KiL77 11 | dYkO3dY75XIc4NsECk5UyFRyu1Mg6dkG+a8HahPVmX7ctq+bytQ3r9hpwXrY1MZV 12 | RXtEW81kwB0yJV3GmXJRrqKZR+vvpqMRIQhmLZgr0XkKIshBuBneAo4cUX7MVgAX 13 | JcdHsMUCgYEA0P7NgUL4xg9Xtpygf5DWldcCO/cy6yoOkXEJ6d9xOhQny24qi0MN 14 | uOuyCCF4UQnhdEYcR9aY6YOwviQluq6roQT7g1Jex20F2iaaLVNdxaGi5s9nmzgp 15 | CVfBdLK798GCu/yrRwhy1EW6kaQdIJ9D0K1Fkm8SFOAuQY0gKgbK6BsCgYEAvVpH 16 | 
AKtcUJiqyenkvyhnxmpZXhS93F2qSJnzBJ/2wAuQ7BCxw/vW7vw1XDDy+IKFXLQ4 17 | eRnbIs2hk+FBAgAQDJl26dUELaE+BKsmU1XACRBrnpLZXNd0ssLc1dMFewkxURmT 18 | krWCku32OjVVnPuIALE4Bv9hBSvotYFmf7DXL+0CgYBU05bFqFEg0olfbSMXo8n0 19 | 91fIzwSzvlY7Yg4MBs0GLbgZMZXDAGxJaiDQfAVBnykK8In5/ngCD5llE3bc1piC 20 | umr7Witt9iox6Qka7INa+8gKtpPuxFSjniK/Iux4GurdMiiypBM3ZTXcdyf7XalA 21 | wZNDZCGKp5MeuBEd/bPNkQKBgAlVzRBUYm26yRjBRjzCYjNfBN7liOK3X3DK3jdJ 22 | J6IaL9/jhtARt2v61SqhYykrTiXe4LXft3UEzEV9InZVyHTGkB1BGj6hp2wVgAM1 23 | xAzuWU/tD3hLSv6RKtAD4k5Jirvj1emytyhFQRFnlbvyjqbyFcAKkR7vJj7kjUgY 24 | UNOVAoGACkTgg+e0Cxu/rzPqJTcwB7gXikBK9gXegi9fXv2Q8q860Y37MXkTCStF 25 | elsIix7jrUokjXWG8Gc44GfJUCv+WLZCeqaL3oS+ZznfbESfiLwEecu/6jsb3SEk 26 | Swx3Bb5X8q0/S0kD3ShMVPFnFIO37P37efMFaYJ2Fs1C3FpWfwM= 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /elasticsearch/certs/es02/es02.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEowIBAAKCAQEAyaaCYcE+eDk4++1N2j0edhOQeJtRlVAY4okWx/GKZ8bUrFgW 3 | eXDmYnty8n2BjOsYy5AikELeZhsA6xWqV7/l8Tmhfv27e1qsxB1zNBurFQ4B2/O7 4 | YXHRXs2uKJqKb2Sj1+2GM6xQxZqZQS39MX2rKmENhud8kyDZUpQfqlUNE+QgRjMK 5 | MvB+7U9lrTfP/XUW2QosY+dzUub1K6pWf/avccf3WJd6R8aTYv2KDQjFsR6sib/R 6 | WHXTvtIDQlynQW4SFxSbhplXDNN8nkaHa1lygAIHcKu77KulFrbTsJZ+hJIMjKn4 7 | HYsMuCH2KYDK+1cZToJ1MDvoA9Xa2yWjbXWaiwIDAQABAoIBABGFNZ6FD0ab1i8m 8 | uFUbaq+gqzScw3ML3Ar8FUN3DEBQXr8WAjyNSinJXGc5mg133xYku6lkHfWqJ/jE 9 | +kwYgMPmRCQIALJK4mhBv3Y0vIxzmfFrP7R2PVZAManBJQELsflQtRB9Ssv4eZNX 10 | tGrd5obOmuxwkCV/vPzMoORdnRQas5a0i0ZtlCjv8jirX/t1VOLlsvoaIGs6TbTm 11 | op3s7zi9ETvnCnZ6VKzsSugwUGCRnFqsIgUt8PUpKPXfgvzdRc5/j1NNWocfwLlH 12 | lTSipk4n64KPevvG/deWw6AnnYLSosIQM4MCC9JrgZ1DhacGs23L7WWbz/1JC+oT 13 | AgEP9FkCgYEAz3pj3PdibwhM/VWuOxwWWR4YS8UJZx1lcFPXzgE4t7V+/YSKuUtL 14 | ro2plGUG1Ji5SgsS0exH/mzaF8OhHX1Jt+eGEVJyzdfJN8S1lomVtOoMyjkyBgxM 15 | I4bajIB1Sz9aipKwVAfise9f6Ga3E0/ET9iq6i0/R1wK5EwgdOCxkBMCgYEA+M84 16 | W4S5P4EyfaeLOw+puyZcgG0ftVVqlAHdkdQlQLXCt3ijI008A0f5GVqf2tTLVrAj 17 | Ub39aDOrM1wA6FWCqsa8LEbtdD4d1xNDi8DegDvjjjmU1z8ULjxSu8iSikhOCw2G 18 | vyXHsr04DYBg4OURD9Pg8hHGtANZSNtc6CwISqkCgYBogbqZi9Z+HQ1Csgy/42by 19 | XrFYQRh6Yxk8Wk8iigT6rCYaJtAFg4LMmrincbfeEEuMm0VQjha5djTosXaPNxOR 20 | 2cHzKbeALchCGghpmkXZSedFWUf0Oe+EGaIuEWqDi5bcpATDXvF2NR/3HP3scUpt 21 | +bIloML1+8vUsO/MT33BFwKBgQCiXR+K4Wa94UKgqwf5p7P8VAFDMXLis3XUVg9Q 22 | DZ+txa7maYwUCl+iSIJuoCv28qwqytCRlCjcqfMLlftlof+eEAhV4IcuNybj5kdK 23 | 2LaZ+fr6IetWN2yk62qV7kJqiNqc7dvDuxTBOCdu8BrIR9NFf9+oOB9x80l9eOD9 24 | BVb32QKBgAjHylTABCzS+2zOwUQeMpXQnYeNhMvvvYug/mDh+GQQL7zVeSJXlNGS 25 | PkJqE3xEoVPDOPJ+Kr8nl5jS3YGCWcWerYURd9XiX7lHx7LlOeb6LusNIm/mO563 26 | y2U3WPIDTdy36vnvz3FZKYIionGFYc5hRGI54GfWKbCdr9ug2Xaw 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /elasticsearch/certs/es03/es03.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEowIBAAKCAQEAkPwxA6T6MKFD+6Q40onk2WYgy/EDfSIQXQODWtpFa7omS1eC 3 | XQ9xByopJsKBDxX/6Ooiwq/HNRG4Yfo82eh4TsHs9INbloXxI6Ha8hUJy7KpgJN+ 4 | 6y/0+TDR1dvSSQUXunzDEbvFLJCdxnGX2oSwQrewq5HvWsecEXivt/G5X5BXdkvP 5 | DMnjFMYVF3vCCt/83v+INcKWs7OyVM6PoHEgxBPTgmLObn9LFyn72CPz0ujGzcuz 6 | dxM2Fpac2XggNC1oDwuUrZnyYLr5SAgPc/HRl8xyntCWbyXJf6WgU4gUqUGbdEn+ 7 | O7nrppxczSuvEwJoRHHNWHrD9EVD6z1kRyiuuQIDAQABAoIBAAeRot1rLMSCfBIU 8 | WK26aoAO8iyzawM1VrVZtu8fRfssafi08NFQOmOtXR3w5DFCwha0DRcs0LGB41D6 9 | 
HK3OPQ701hyjCZJUxmnzIS2ne4Nu0vqAuLd+dsOA5kjL2QaG2t5EH8rlNT3xBCDq 10 | p+AjMoXJISf+eQhm3fbriKuJqwEaFoFs8gFOm7BKSzw4nFdwaBI0hO7QQng9Um8x 11 | g27R+mJAGG+fNFJuDFD6gy2iIN+xkrNP9WGlFmEV4OoCUi1hz5YfxNRs2xCxUA3E 12 | ggzcU0UvmLccHwz6PYK8EzTmsZUJsLRJTyG3WuNuaVTN67oMLYRv7GlGdZ998+HE 13 | 7m3KU6ECgYEAurfhPf8ACbMxR49iF62QJUzHM5xu0Qq85z6lWKHiCXWaIcqNv2mI 14 | qnao9a5wBqEsTNNLljDlr99u8stHt+Zx3gAA1F5Pew/3quwOtaFx+zvxqrQKs6pb 15 | cz3q9Ad8sJVthr06SW5W4OvdDGOF52EV+bcITZCetENmk9jvbN6z0ucCgYEAxsgi 16 | enAcripDIT8uxhiEWxEtnsZunOsGWzXKcKJeg7w7cErqHRBb4liPenq3B/6sIp3r 17 | AwPyqRIfVXZNbOpEPwoaF/1bQ9HujHuO0HEMsoeJh12KkBSBuHQXnA/X01Y6Nf0l 18 | ZijizBE/Qc5EkWyxZx+pSTf7RVGf+1DlmKBE3V8CgYBIWUurrA0ltQtZQROvPQ9n 19 | hJKDSxAda926dKm46DEfnTP19/hovMm503SwjcDWsMjrk8vsDFJTjW3+IgpOFbr1 20 | XGb14v1FH/DFh+ZDNqVlxdpkXJLw/wekZc+OcwA7pArmdJgLL/f1+y6RyFZwS0wq 21 | kGNlOq5kBuHOU/ah5sEi7QKBgQDGiT1mbHM4wJ0rp59f2zzWd+HIowf3UgWXM7Jt 22 | rL4ZdPcowKnzPVOITkt/WPFV2tax/GetK1RB6QfCo9XQ4monTD+jljiBFDvds8qA 23 | BWlZJmYF/TdXkCO/xrON+4TkX0rkgWHJFyzuBIvZfdqeJKFLDiRWLMOaCFxw9eta 24 | 9TfSoQKBgEhHjogylvtkkiDIg+ArNhezYCjZ2yVPPsM8Jejh3sFpmjHHLaiw86Np 25 | fRrxp3ealK1DXziRni+USCNM8iM6GFo2I6jBMWTJDTsc6wSyYLu4MlLTXyobZZnc 26 | SwtvyZip2+VL4qEeg9p2c5wxbDBPdxh/8vjW9QxCeAovIr/0lth4 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /elasticsearch/config/es01/es01.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEogIBAAKCAQEAmpXRQ+THn+VpNZwD+7+9R7IOhisKjCEGKMqSARr5wQPlulxL 3 | DlHoO7pRLw6JcEXmLn4tx+QuLkkm9pFAADYnIcCljzRkFHADKuo+P7BG6AZ+6l1U 4 | VHwya6iUaERda3+VClTHPzRoPn8TRdX78+cY7dAzme0noosw8WCgpL2HgiYciFiw 5 | le80OKaeiVxvRNJ+6n36a/q4lgLx9wQoJvox/2OCP3318UY1SgA2AO9kFk0zDTvJ 6 | ji9ESSz6d+6LCB73Bk2cKWx6q2pVoSTNPhkC5mD9E/zhuXSACfY3H85GZksJ4jzt 7 | dF1UcsOsrjYsMwwxHH4BwVKak97ERJmICSLV/wIDAQABAoIBABDs2Vo1KNQtcoz/ 8 | lYIRVsCMUsHG4aNBFGMP9tdvJCxJaHQ0mbUqK6KqfiwIS+0CgjbR8uJBbfr8YGs7 9 | sQW06CjuZlIdGt4P+5DNz936R3EtEOVJLawIYx7deM5HufDEqcVVTfFyI/2/vRT2 10 | 3lywj06ubo/qYt4NnmC3Qy92XulVZknvjaolGhqVlYGbHQITSI47KNErFX0KiL77 11 | dYkO3dY75XIc4NsECk5UyFRyu1Mg6dkG+a8HahPVmX7ctq+bytQ3r9hpwXrY1MZV 12 | RXtEW81kwB0yJV3GmXJRrqKZR+vvpqMRIQhmLZgr0XkKIshBuBneAo4cUX7MVgAX 13 | JcdHsMUCgYEA0P7NgUL4xg9Xtpygf5DWldcCO/cy6yoOkXEJ6d9xOhQny24qi0MN 14 | uOuyCCF4UQnhdEYcR9aY6YOwviQluq6roQT7g1Jex20F2iaaLVNdxaGi5s9nmzgp 15 | CVfBdLK798GCu/yrRwhy1EW6kaQdIJ9D0K1Fkm8SFOAuQY0gKgbK6BsCgYEAvVpH 16 | AKtcUJiqyenkvyhnxmpZXhS93F2qSJnzBJ/2wAuQ7BCxw/vW7vw1XDDy+IKFXLQ4 17 | eRnbIs2hk+FBAgAQDJl26dUELaE+BKsmU1XACRBrnpLZXNd0ssLc1dMFewkxURmT 18 | krWCku32OjVVnPuIALE4Bv9hBSvotYFmf7DXL+0CgYBU05bFqFEg0olfbSMXo8n0 19 | 91fIzwSzvlY7Yg4MBs0GLbgZMZXDAGxJaiDQfAVBnykK8In5/ngCD5llE3bc1piC 20 | umr7Witt9iox6Qka7INa+8gKtpPuxFSjniK/Iux4GurdMiiypBM3ZTXcdyf7XalA 21 | wZNDZCGKp5MeuBEd/bPNkQKBgAlVzRBUYm26yRjBRjzCYjNfBN7liOK3X3DK3jdJ 22 | J6IaL9/jhtARt2v61SqhYykrTiXe4LXft3UEzEV9InZVyHTGkB1BGj6hp2wVgAM1 23 | xAzuWU/tD3hLSv6RKtAD4k5Jirvj1emytyhFQRFnlbvyjqbyFcAKkR7vJj7kjUgY 24 | UNOVAoGACkTgg+e0Cxu/rzPqJTcwB7gXikBK9gXegi9fXv2Q8q860Y37MXkTCStF 25 | elsIix7jrUokjXWG8Gc44GfJUCv+WLZCeqaL3oS+ZznfbESfiLwEecu/6jsb3SEk 26 | Swx3Bb5X8q0/S0kD3ShMVPFnFIO37P37efMFaYJ2Fs1C3FpWfwM= 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /elasticsearch/config/es02/es02.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | 
MIIEowIBAAKCAQEAyaaCYcE+eDk4++1N2j0edhOQeJtRlVAY4okWx/GKZ8bUrFgW 3 | eXDmYnty8n2BjOsYy5AikELeZhsA6xWqV7/l8Tmhfv27e1qsxB1zNBurFQ4B2/O7 4 | YXHRXs2uKJqKb2Sj1+2GM6xQxZqZQS39MX2rKmENhud8kyDZUpQfqlUNE+QgRjMK 5 | MvB+7U9lrTfP/XUW2QosY+dzUub1K6pWf/avccf3WJd6R8aTYv2KDQjFsR6sib/R 6 | WHXTvtIDQlynQW4SFxSbhplXDNN8nkaHa1lygAIHcKu77KulFrbTsJZ+hJIMjKn4 7 | HYsMuCH2KYDK+1cZToJ1MDvoA9Xa2yWjbXWaiwIDAQABAoIBABGFNZ6FD0ab1i8m 8 | uFUbaq+gqzScw3ML3Ar8FUN3DEBQXr8WAjyNSinJXGc5mg133xYku6lkHfWqJ/jE 9 | +kwYgMPmRCQIALJK4mhBv3Y0vIxzmfFrP7R2PVZAManBJQELsflQtRB9Ssv4eZNX 10 | tGrd5obOmuxwkCV/vPzMoORdnRQas5a0i0ZtlCjv8jirX/t1VOLlsvoaIGs6TbTm 11 | op3s7zi9ETvnCnZ6VKzsSugwUGCRnFqsIgUt8PUpKPXfgvzdRc5/j1NNWocfwLlH 12 | lTSipk4n64KPevvG/deWw6AnnYLSosIQM4MCC9JrgZ1DhacGs23L7WWbz/1JC+oT 13 | AgEP9FkCgYEAz3pj3PdibwhM/VWuOxwWWR4YS8UJZx1lcFPXzgE4t7V+/YSKuUtL 14 | ro2plGUG1Ji5SgsS0exH/mzaF8OhHX1Jt+eGEVJyzdfJN8S1lomVtOoMyjkyBgxM 15 | I4bajIB1Sz9aipKwVAfise9f6Ga3E0/ET9iq6i0/R1wK5EwgdOCxkBMCgYEA+M84 16 | W4S5P4EyfaeLOw+puyZcgG0ftVVqlAHdkdQlQLXCt3ijI008A0f5GVqf2tTLVrAj 17 | Ub39aDOrM1wA6FWCqsa8LEbtdD4d1xNDi8DegDvjjjmU1z8ULjxSu8iSikhOCw2G 18 | vyXHsr04DYBg4OURD9Pg8hHGtANZSNtc6CwISqkCgYBogbqZi9Z+HQ1Csgy/42by 19 | XrFYQRh6Yxk8Wk8iigT6rCYaJtAFg4LMmrincbfeEEuMm0VQjha5djTosXaPNxOR 20 | 2cHzKbeALchCGghpmkXZSedFWUf0Oe+EGaIuEWqDi5bcpATDXvF2NR/3HP3scUpt 21 | +bIloML1+8vUsO/MT33BFwKBgQCiXR+K4Wa94UKgqwf5p7P8VAFDMXLis3XUVg9Q 22 | DZ+txa7maYwUCl+iSIJuoCv28qwqytCRlCjcqfMLlftlof+eEAhV4IcuNybj5kdK 23 | 2LaZ+fr6IetWN2yk62qV7kJqiNqc7dvDuxTBOCdu8BrIR9NFf9+oOB9x80l9eOD9 24 | BVb32QKBgAjHylTABCzS+2zOwUQeMpXQnYeNhMvvvYug/mDh+GQQL7zVeSJXlNGS 25 | PkJqE3xEoVPDOPJ+Kr8nl5jS3YGCWcWerYURd9XiX7lHx7LlOeb6LusNIm/mO563 26 | y2U3WPIDTdy36vnvz3FZKYIionGFYc5hRGI54GfWKbCdr9ug2Xaw 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /elasticsearch/config/es03/es03.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEowIBAAKCAQEAkPwxA6T6MKFD+6Q40onk2WYgy/EDfSIQXQODWtpFa7omS1eC 3 | XQ9xByopJsKBDxX/6Ooiwq/HNRG4Yfo82eh4TsHs9INbloXxI6Ha8hUJy7KpgJN+ 4 | 6y/0+TDR1dvSSQUXunzDEbvFLJCdxnGX2oSwQrewq5HvWsecEXivt/G5X5BXdkvP 5 | DMnjFMYVF3vCCt/83v+INcKWs7OyVM6PoHEgxBPTgmLObn9LFyn72CPz0ujGzcuz 6 | dxM2Fpac2XggNC1oDwuUrZnyYLr5SAgPc/HRl8xyntCWbyXJf6WgU4gUqUGbdEn+ 7 | O7nrppxczSuvEwJoRHHNWHrD9EVD6z1kRyiuuQIDAQABAoIBAAeRot1rLMSCfBIU 8 | WK26aoAO8iyzawM1VrVZtu8fRfssafi08NFQOmOtXR3w5DFCwha0DRcs0LGB41D6 9 | HK3OPQ701hyjCZJUxmnzIS2ne4Nu0vqAuLd+dsOA5kjL2QaG2t5EH8rlNT3xBCDq 10 | p+AjMoXJISf+eQhm3fbriKuJqwEaFoFs8gFOm7BKSzw4nFdwaBI0hO7QQng9Um8x 11 | g27R+mJAGG+fNFJuDFD6gy2iIN+xkrNP9WGlFmEV4OoCUi1hz5YfxNRs2xCxUA3E 12 | ggzcU0UvmLccHwz6PYK8EzTmsZUJsLRJTyG3WuNuaVTN67oMLYRv7GlGdZ998+HE 13 | 7m3KU6ECgYEAurfhPf8ACbMxR49iF62QJUzHM5xu0Qq85z6lWKHiCXWaIcqNv2mI 14 | qnao9a5wBqEsTNNLljDlr99u8stHt+Zx3gAA1F5Pew/3quwOtaFx+zvxqrQKs6pb 15 | cz3q9Ad8sJVthr06SW5W4OvdDGOF52EV+bcITZCetENmk9jvbN6z0ucCgYEAxsgi 16 | enAcripDIT8uxhiEWxEtnsZunOsGWzXKcKJeg7w7cErqHRBb4liPenq3B/6sIp3r 17 | AwPyqRIfVXZNbOpEPwoaF/1bQ9HujHuO0HEMsoeJh12KkBSBuHQXnA/X01Y6Nf0l 18 | ZijizBE/Qc5EkWyxZx+pSTf7RVGf+1DlmKBE3V8CgYBIWUurrA0ltQtZQROvPQ9n 19 | hJKDSxAda926dKm46DEfnTP19/hovMm503SwjcDWsMjrk8vsDFJTjW3+IgpOFbr1 20 | XGb14v1FH/DFh+ZDNqVlxdpkXJLw/wekZc+OcwA7pArmdJgLL/f1+y6RyFZwS0wq 21 | kGNlOq5kBuHOU/ah5sEi7QKBgQDGiT1mbHM4wJ0rp59f2zzWd+HIowf3UgWXM7Jt 22 | rL4ZdPcowKnzPVOITkt/WPFV2tax/GetK1RB6QfCo9XQ4monTD+jljiBFDvds8qA 23 | BWlZJmYF/TdXkCO/xrON+4TkX0rkgWHJFyzuBIvZfdqeJKFLDiRWLMOaCFxw9eta 24 | 
9TfSoQKBgEhHjogylvtkkiDIg+ArNhezYCjZ2yVPPsM8Jejh3sFpmjHHLaiw86Np 25 | fRrxp3ealK1DXziRni+USCNM8iM6GFo2I6jBMWTJDTsc6wSyYLu4MlLTXyobZZnc 26 | SwtvyZip2+VL4qEeg9p2c5wxbDBPdxh/8vjW9QxCeAovIr/0lth4 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /source/Matcher/utils/temporal_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from datetime import datetime 3 | from typing import Dict, Optional 4 | 5 | 6 | def parse_temporal(temporal_obj: Optional[Dict]) -> str: 7 | """Parse complex temporal elements with error handling.""" 8 | if not temporal_obj: 9 | return "Timing not specified" 10 | try: 11 | if "age" in temporal_obj: 12 | return parse_iso_duration(temporal_obj["age"].get("iso8601duration")) 13 | if "timestamp" in temporal_obj: 14 | return datetime.fromisoformat(temporal_obj["timestamp"]).strftime( 15 | "%Y-%m-%d" 16 | ) 17 | if "interval" in temporal_obj: 18 | start = temporal_obj["interval"].get("start", "unknown") 19 | end = temporal_obj["interval"].get("end", "unknown") 20 | return f"{start} to {end}" 21 | return "Timing information available" 22 | except Exception as e: 23 | return f"Timing information unavailable: {str(e)}" 24 | 25 | 26 | def parse_iso_duration(duration: Optional[str]) -> str: 27 | """Convert ISO8601 duration to a human-readable format.""" 28 | if not duration: 29 | return "Age unspecified" 30 | try: 31 | match = re.match(r"P(?:(\d+)Y)?(?:(\d+)M)?(?:(\d+)D)?", duration) 32 | parts = [] 33 | if match: 34 | if match.group(1): 35 | parts.append(f"{match.group(1)} years") 36 | if match.group(2): 37 | parts.append(f"{match.group(2)} months") 38 | if match.group(3): 39 | parts.append(f"{match.group(3)} days") 40 | return " ".join(parts) if parts else duration 41 | return duration 42 | except Exception as e: 43 | return f"Duration parsing failed: {str(e)}" 44 | -------------------------------------------------------------------------------- /source/Matcher/config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bio_med_ner": { 3 | "biomedner_port": 18894, 4 | "gner_port": 18783, 5 | "gene_norm_port": 18888, 6 | "disease_norm_port": 18892, 7 | "biomedner_home": "Parser/", 8 | "use_neural_normalizer": true, 9 | "no_cuda": false 10 | }, 11 | "services": { 12 | "stop_script": "biomedner_services/run_biomedner.sh", 13 | "run_script": "biomedner_services/run_biomedner.sh" 14 | }, 15 | "paths": { 16 | "patients_dir": "../example/", 17 | "output_dir": "../results/", 18 | "trials_json_folder": "../data/trials_jsons", 19 | "docker_certs": "../elasticsearch/certs/ca.crt" 20 | }, 21 | "model": { 22 | "base_model": "microsoft/phi-4", 23 | "quantization": { 24 | "load_in_4bit": true, 25 | "bnb_4bit_use_double_quant": true, 26 | "bnb_4bit_quant_type": "nf4", 27 | "bnb_4bit_compute_dtype": "float16" 28 | }, 29 | "cot_adapter_path": "models/finetuned_phi_reasoning", 30 | "reranker_model_path": "google/gemma-2-2b-it", 31 | "reranker_adapter_path": "models/finetuned_gemma2" 32 | }, 33 | "tokenizer": { 34 | "use_fast": true, 35 | "padding_side": "left" 36 | }, 37 | "global": { 38 | "device": 0 39 | }, 40 | "elasticsearch": { 41 | "host": "https://localhost:9200", 42 | "username": "elastic", 43 | "password": "QQ7wWoB_WnKe*L*X9tAW", 44 | "request_timeout": 600, 45 | "retry_on_timeout": true, 46 | "index_trials": "clinical_trials", 47 | "index_trials_eligibility": "eligibility_criteria" 48 | }, 49 | "embedder": { 50 
| "model_name": "BAAI/bge-m3" 51 | }, 52 | "cot": { 53 | "batch_size": 10 54 | }, 55 | "LLM_reranker": { 56 | "batch_size": 20 57 | }, 58 | "search": { 59 | "vector_score_threshold": 0.5, 60 | "max_trials_first_level": 300, 61 | "max_trials_second_level": 100 62 | }, 63 | "use_cot_reasoning": true, 64 | "cot_backend": "vllm", 65 | "rag": { 66 | "batch_size": 4, 67 | "max_trials_rag": 20 68 | }, 69 | "vllm": { 70 | "batch_size": 100, 71 | "max_new_tokens": 5000, 72 | "temperature": 0.0, 73 | "top_p": 1.0, 74 | "seed": 1234, 75 | "length_bucket": true, 76 | "gpu_memory_utilization": 0.5, 77 | "max_model_len": 8192, 78 | "tensor_parallel_size": 1 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /utils/Preprocessor/preprocessing.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | from tqdm.auto import tqdm 3 | from preprocessing_utils import eic_text_preprocessing 4 | from preprocess_clinical_notes import tokenize_clinical_note 5 | import pandas as pd 6 | import os 7 | 8 | memory = joblib.Memory(".") 9 | 10 | 11 | def ParallelExecutor(use_bar="tqdm", **joblib_args): 12 | """Utility for tqdm progress bar in joblib.Parallel""" 13 | all_bar_funcs = { 14 | "tqdm": lambda args: lambda x: tqdm(x, **args), 15 | "False": lambda args: iter, 16 | "None": lambda args: iter, 17 | } 18 | 19 | def aprun(bar=use_bar, **tq_args): 20 | def tmp(op_iter): 21 | if str(bar) in all_bar_funcs.keys(): 22 | bar_func = all_bar_funcs[str(bar)](tq_args) 23 | else: 24 | raise ValueError("Value %s not supported as bar type" % bar) 25 | # Pass n_jobs from joblib_args 26 | return joblib.Parallel(n_jobs=joblib_args.get("n_jobs", 10))( 27 | bar_func(op_iter) 28 | ) 29 | 30 | return tmp 31 | 32 | return aprun 33 | 34 | 35 | class Preprocessor: 36 | def __init__(self, id_list, n_jobs): 37 | self.id_list = id_list 38 | self.n_jobs = n_jobs 39 | 40 | def preprocess_clinical_trials_text(self): 41 | parallel_runner = ParallelExecutor(n_jobs=self.n_jobs)(total=len(self.id_list)) 42 | X = parallel_runner( 43 | joblib.delayed(eic_text_preprocessing)([_id]) for _id in self.id_list 44 | ) 45 | return pd.concat(X).reset_index(drop=True) 46 | 47 | def preprocess_patient_clinical_notes(self): 48 | parallel_runner = ParallelExecutor(n_jobs=self.n_jobs)(total=len(self.id_list)) 49 | X = parallel_runner( 50 | joblib.delayed(tokenize_clinical_note)([_id]) for _id in self.id_list 51 | ) 52 | return pd.concat(X).reset_index(drop=True) 53 | 54 | 55 | if __name__ == "__main__": 56 | # Load the list of NCT IDs 57 | folder_path = "../../data/trials_xmls" 58 | file_names = [] 59 | # List all files in the folder 60 | for file in os.listdir(folder_path): 61 | if os.path.isfile(os.path.join(folder_path, file)): 62 | file_name, file_extension = os.path.splitext(file) 63 | file_names.append(file_name) 64 | nct_ids = file_names 65 | n_jobs = 10 66 | preprocessor = Preprocessor(nct_ids, n_jobs) 67 | preprocessor.preprocess_clinical_trials_text() 68 | -------------------------------------------------------------------------------- /example/phenopacket/keywords.json: -------------------------------------------------------------------------------- 1 | { 2 | "main_conditions": [ 3 | "Myocardial infarction", 4 | "Heart attack", 5 | "Cardiac infarction", 6 | "Type 2 diabetes mellitus", 7 | "Diabetes", 8 | "Elevated hepatic transaminase level", 9 | "Liver enzyme elevation", 10 | "Hypercholesterolemia", 11 | "High cholesterol", 12 | "Coronary artery disease" 13 
| ], 14 | "other_conditions": [ 15 | "Recurrent myocardial infarction", 16 | "Emergency stenting", 17 | "Bypass surgery", 18 | "HbA1c levels between 6.5% and 7.2%", 19 | "Mild elevation of liver enzymes", 20 | "Statin therapy", 21 | "Non-alcoholic fatty liver disease", 22 | "LDL levels >190 mg/dL", 23 | "Lifestyle interventions", 24 | "Former smoker", 25 | "20 pack-years", 26 | "Quit smoking at age 45", 27 | "Stage III coronary artery disease", 28 | "Multivessel involvement", 29 | "Impaired left ventricular function", 30 | "Ischemic cardiomyopathy", 31 | "Genetic testing", 32 | "Heterozygous LDLR mutation", 33 | "Familial hypercholesterolemia", 34 | "Atorvastatin treatment", 35 | "Metformin treatment", 36 | "Coronary Artery Bypass Grafting", 37 | "Sudden cardiac death in father", 38 | "Father deceased at age 62" 39 | ], 40 | "expanded_sentences": [ 41 | "The patient is a 58-year-old male with a complex cardiac history, including multiple myocardial infarctions.", 42 | "He has experienced recurrent heart attacks, necessitating emergency stenting and later bypass surgery.", 43 | "The patient has type 2 diabetes mellitus, managed with metformin, maintaining HbA1c levels between 6.5% and 7.2%.", 44 | "There is a mild elevation in liver enzymes, suspected to be secondary to statin therapy or non-alcoholic fatty liver disease.", 45 | "He has hypercholesterolemia, with LDL levels exceeding 190 mg/dL despite lifestyle interventions.", 46 | "The patient is a former smoker, having quit at age 45 after 20 pack-years of smoking.", 47 | "He has advanced coronary artery disease, stage III, with multivessel involvement and impaired left ventricular function.", 48 | "A tissue biopsy taken during bypass surgery revealed ischemic cardiomyopathy.", 49 | "Genetic testing confirmed a heterozygous LDLR mutation, consistent with familial hypercholesterolemia.", 50 | "The patient is currently on atorvastatin and metformin treatments.", 51 | "His father experienced sudden cardiac death at age 62." 
52 | ] 53 | } -------------------------------------------------------------------------------- /utils/finetuning/finetune_instruct/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Optional 4 | 5 | import torch 6 | from peft import get_peft_model_state_dict 7 | from transformers.integrations import is_deepspeed_zero3_enabled 8 | from transformers.trainer import Trainer 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class SFTTrainer(Trainer): 14 | use_lora: bool 15 | 16 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 17 | # Custom saving logic depending on whether we're using LoRA or not 18 | if not self.use_lora: 19 | # If not using LoRA, just use the default save implementation 20 | super()._save(output_dir, state_dict) 21 | return 22 | 23 | # Using LoRA 24 | output_dir = output_dir if output_dir is not None else self.args.output_dir 25 | os.makedirs(output_dir, exist_ok=True) 26 | logger.info("Saving model checkpoint to %s", output_dir) 27 | 28 | # Ensure model has the `save` method implemented 29 | if not hasattr(self.model, "save"): 30 | raise NotImplementedError( 31 | f"MODEL {self.model.__class__.__name__} does not support save interface" 32 | ) 33 | else: 34 | self.model.save(output_dir) 35 | 36 | # Save training arguments 37 | torch.save(self.args, os.path.join(output_dir, "training_args.bin")) 38 | 39 | # If using DeepSpeed ZeRO-3, save LoRA adapters separately 40 | if is_deepspeed_zero3_enabled(): 41 | if state_dict is None: 42 | state_dict = self.model.state_dict() 43 | prefix = "model." 44 | assert all(k.startswith(prefix) for k in state_dict.keys()), list( 45 | state_dict.keys() 46 | ) 47 | state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()} 48 | lora_state_dict = get_peft_model_state_dict(self.model.model, state_dict) 49 | if self.args.process_index <= 0: 50 | torch.save( 51 | lora_state_dict, os.path.join(output_dir, "adapter_model.bin") 52 | ) 53 | logger.info(f"Saved LoRA adapter model at {output_dir}") 54 | 55 | def compute_loss( 56 | self, model, inputs, return_outputs=False, num_items_in_batch=None 57 | ): 58 | """ 59 | How the loss is computed by Trainer. 60 | For causal language modeling tasks, the model returns the loss directly if labels are provided. 
61 | """ 62 | outputs = model(**inputs) 63 | loss = outputs.loss 64 | 65 | # Optionally use num_items_in_batch if needed for custom logic 66 | if num_items_in_batch is not None: 67 | # Example: Adjust loss based on batch size if required 68 | pass 69 | 70 | return (loss, outputs) if return_outputs else loss 71 | -------------------------------------------------------------------------------- /utils/finetuning/finetune_instruct/load_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import LoraConfig, PeftModel, TaskType, get_peft_model 3 | from transformers import AutoModelForCausalLM, BitsAndBytesConfig 4 | 5 | # from local_gemma import LocalGemma2ForCausalLM 6 | 7 | 8 | def get_model(model_args, training_args): 9 | # model = LocalGemma2ForCausalLM.from_pretrained(model_args.model_name_or_path, 10 | # preset="auto", 11 | # attn_implementation='eager', 12 | # cache_dir=model_args.cache_dir, 13 | # torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16, 14 | # token=model_args.token, 15 | # from_tf=bool(".ckpt" in model_args.model_name_or_path), 16 | # trust_remote_code=True, 17 | # use_flash_attention_2=True if model_args.use_flash_attn else False) 18 | quantization_config = BitsAndBytesConfig( 19 | load_in_4bit=True, 20 | bnb_4bit_use_double_quant=True, 21 | bnb_4bit_quant_type="nf4", 22 | bnb_4bit_compute_dtype="float16", 23 | ) 24 | 25 | model = AutoModelForCausalLM.from_pretrained( 26 | model_args.model_name_or_path, 27 | torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16, 28 | token=model_args.token, 29 | cache_dir=model_args.cache_dir, 30 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 31 | trust_remote_code=True, 32 | attn_implementation="flash_attention_2", 33 | quantization_config=quantization_config, 34 | ) 35 | 36 | model = model.to("cuda") 37 | 38 | if torch.cuda.device_count() > 1: # If more than 1 GPU 39 | model.is_parallelizable = True 40 | model.model_parallel = True 41 | 42 | model.config.use_cache = False 43 | 44 | if model_args.from_peft is not None: 45 | model = PeftModel.from_pretrained( 46 | model, model_args.from_peft, is_trainable=True 47 | ) 48 | model.print_trainable_parameters() 49 | else: 50 | if model_args.use_lora: 51 | peft_config = LoraConfig( 52 | task_type=TaskType.CAUSAL_LM, 53 | inference_mode=False, 54 | r=model_args.lora_rank, 55 | target_modules=model_args.target_modules, 56 | lora_alpha=model_args.lora_alpha, 57 | lora_dropout=model_args.lora_dropout, 58 | modules_to_save=model_args.lora_extra_parameters, 59 | ) 60 | model = get_peft_model(model, peft_config) 61 | model.print_trainable_parameters() 62 | 63 | print(model) 64 | return model 65 | -------------------------------------------------------------------------------- /elasticsearch/config/es01/jvm.options: -------------------------------------------------------------------------------- 1 | ################################################################ 2 | ## 3 | ## JVM configuration 4 | ## 5 | ################################################################ 6 | ## 7 | ## WARNING: DO NOT EDIT THIS FILE. If you want to override the 8 | ## JVM options in this file, or set any additional options, you 9 | ## should create one or more files in the jvm.options.d 10 | ## directory containing your adjustments. 11 | ## 12 | ## See https://www.elastic.co/guide/en/elasticsearch/reference/8.13/jvm-options.html 13 | ## for more information. 
14 | ## 15 | ################################################################ 16 | 17 | 18 | 19 | ################################################################ 20 | ## IMPORTANT: JVM heap size 21 | ################################################################ 22 | ## 23 | ## The heap size is automatically configured by Elasticsearch 24 | ## based on the available memory in your system and the roles 25 | ## each node is configured to fulfill. If specifying heap is 26 | ## required, it should be done through a file in jvm.options.d, 27 | ## which should be named with .options suffix, and the min and 28 | ## max should be set to the same value. For example, to set the 29 | ## heap to 4 GB, create a new file in the jvm.options.d 30 | ## directory containing these lines: 31 | ## 32 | ## -Xms4g 33 | ## -Xmx4g 34 | ## 35 | ## See https://www.elastic.co/guide/en/elasticsearch/reference/8.13/heap-size.html 36 | ## for more information 37 | ## 38 | ################################################################ 39 | 40 | 41 | ################################################################ 42 | ## Expert settings 43 | ################################################################ 44 | ## 45 | ## All settings below here are considered expert settings. Do 46 | ## not adjust them unless you understand what you are doing. Do 47 | ## not edit them in this file; instead, create a new file in the 48 | ## jvm.options.d directory containing your adjustments. 49 | ## 50 | ################################################################ 51 | 52 | -XX:+UseG1GC 53 | 54 | ## JVM temporary directory 55 | -Djava.io.tmpdir=${ES_TMPDIR} 56 | 57 | # Leverages accelerated vector hardware instructions; removing this may 58 | # result in less optimal vector performance 59 | 20-:--add-modules=jdk.incubator.vector 60 | 61 | ## heap dumps 62 | 63 | # generate a heap dump when an allocation from the Java heap fails; heap dumps 64 | # are created in the working directory of the JVM unless an alternative path is 65 | # specified 66 | -XX:+HeapDumpOnOutOfMemoryError 67 | 68 | # exit right after heap dump on out of memory error 69 | -XX:+ExitOnOutOfMemoryError 70 | 71 | # specify an alternative path for heap dumps; ensure the directory exists and 72 | # has sufficient space 73 | -XX:HeapDumpPath=data 74 | 75 | # specify an alternative path for JVM fatal error logs 76 | -XX:ErrorFile=logs/hs_err_pid%p.log 77 | 78 | ## GC logging 79 | -Xlog:gc*,gc+age=trace,safepoint:file=logs/gc.log:utctime,level,pid,tags:filecount=32,filesize=64m 80 | -------------------------------------------------------------------------------- /elasticsearch/config/es02/jvm.options: -------------------------------------------------------------------------------- 1 | ################################################################ 2 | ## 3 | ## JVM configuration 4 | ## 5 | ################################################################ 6 | ## 7 | ## WARNING: DO NOT EDIT THIS FILE. If you want to override the 8 | ## JVM options in this file, or set any additional options, you 9 | ## should create one or more files in the jvm.options.d 10 | ## directory containing your adjustments. 11 | ## 12 | ## See https://www.elastic.co/guide/en/elasticsearch/reference/8.13/jvm-options.html 13 | ## for more information. 
14 | ## 15 | ################################################################ 16 | 17 | 18 | 19 | ################################################################ 20 | ## IMPORTANT: JVM heap size 21 | ################################################################ 22 | ## 23 | ## The heap size is automatically configured by Elasticsearch 24 | ## based on the available memory in your system and the roles 25 | ## each node is configured to fulfill. If specifying heap is 26 | ## required, it should be done through a file in jvm.options.d, 27 | ## which should be named with .options suffix, and the min and 28 | ## max should be set to the same value. For example, to set the 29 | ## heap to 4 GB, create a new file in the jvm.options.d 30 | ## directory containing these lines: 31 | ## 32 | ## -Xms4g 33 | ## -Xmx4g 34 | ## 35 | ## See https://www.elastic.co/guide/en/elasticsearch/reference/8.13/heap-size.html 36 | ## for more information 37 | ## 38 | ################################################################ 39 | 40 | 41 | ################################################################ 42 | ## Expert settings 43 | ################################################################ 44 | ## 45 | ## All settings below here are considered expert settings. Do 46 | ## not adjust them unless you understand what you are doing. Do 47 | ## not edit them in this file; instead, create a new file in the 48 | ## jvm.options.d directory containing your adjustments. 49 | ## 50 | ################################################################ 51 | 52 | -XX:+UseG1GC 53 | 54 | ## JVM temporary directory 55 | -Djava.io.tmpdir=${ES_TMPDIR} 56 | 57 | # Leverages accelerated vector hardware instructions; removing this may 58 | # result in less optimal vector performance 59 | 20-:--add-modules=jdk.incubator.vector 60 | 61 | ## heap dumps 62 | 63 | # generate a heap dump when an allocation from the Java heap fails; heap dumps 64 | # are created in the working directory of the JVM unless an alternative path is 65 | # specified 66 | -XX:+HeapDumpOnOutOfMemoryError 67 | 68 | # exit right after heap dump on out of memory error 69 | -XX:+ExitOnOutOfMemoryError 70 | 71 | # specify an alternative path for heap dumps; ensure the directory exists and 72 | # has sufficient space 73 | -XX:HeapDumpPath=data 74 | 75 | # specify an alternative path for JVM fatal error logs 76 | -XX:ErrorFile=logs/hs_err_pid%p.log 77 | 78 | ## GC logging 79 | -Xlog:gc*,gc+age=trace,safepoint:file=logs/gc.log:utctime,level,pid,tags:filecount=32,filesize=64m 80 | -------------------------------------------------------------------------------- /elasticsearch/config/es03/jvm.options: -------------------------------------------------------------------------------- 1 | ################################################################ 2 | ## 3 | ## JVM configuration 4 | ## 5 | ################################################################ 6 | ## 7 | ## WARNING: DO NOT EDIT THIS FILE. If you want to override the 8 | ## JVM options in this file, or set any additional options, you 9 | ## should create one or more files in the jvm.options.d 10 | ## directory containing your adjustments. 11 | ## 12 | ## See https://www.elastic.co/guide/en/elasticsearch/reference/8.13/jvm-options.html 13 | ## for more information. 
14 | ## 15 | ################################################################ 16 | 17 | 18 | 19 | ################################################################ 20 | ## IMPORTANT: JVM heap size 21 | ################################################################ 22 | ## 23 | ## The heap size is automatically configured by Elasticsearch 24 | ## based on the available memory in your system and the roles 25 | ## each node is configured to fulfill. If specifying heap is 26 | ## required, it should be done through a file in jvm.options.d, 27 | ## which should be named with .options suffix, and the min and 28 | ## max should be set to the same value. For example, to set the 29 | ## heap to 4 GB, create a new file in the jvm.options.d 30 | ## directory containing these lines: 31 | ## 32 | ## -Xms4g 33 | ## -Xmx4g 34 | ## 35 | ## See https://www.elastic.co/guide/en/elasticsearch/reference/8.13/heap-size.html 36 | ## for more information 37 | ## 38 | ################################################################ 39 | 40 | 41 | ################################################################ 42 | ## Expert settings 43 | ################################################################ 44 | ## 45 | ## All settings below here are considered expert settings. Do 46 | ## not adjust them unless you understand what you are doing. Do 47 | ## not edit them in this file; instead, create a new file in the 48 | ## jvm.options.d directory containing your adjustments. 49 | ## 50 | ################################################################ 51 | 52 | -XX:+UseG1GC 53 | 54 | ## JVM temporary directory 55 | -Djava.io.tmpdir=${ES_TMPDIR} 56 | 57 | # Leverages accelerated vector hardware instructions; removing this may 58 | # result in less optimal vector performance 59 | 20-:--add-modules=jdk.incubator.vector 60 | 61 | ## heap dumps 62 | 63 | # generate a heap dump when an allocation from the Java heap fails; heap dumps 64 | # are created in the working directory of the JVM unless an alternative path is 65 | # specified 66 | -XX:+HeapDumpOnOutOfMemoryError 67 | 68 | # exit right after heap dump on out of memory error 69 | -XX:+ExitOnOutOfMemoryError 70 | 71 | # specify an alternative path for heap dumps; ensure the directory exists and 72 | # has sufficient space 73 | -XX:HeapDumpPath=data 74 | 75 | # specify an alternative path for JVM fatal error logs 76 | -XX:ErrorFile=logs/hs_err_pid%p.log 77 | 78 | ## GC logging 79 | -Xlog:gc*,gc+age=trace,safepoint:file=logs/gc.log:utctime,level,pid,tags:filecount=32,filesize=64m 80 | -------------------------------------------------------------------------------- /elasticsearch/tmp-config/jvm.options: -------------------------------------------------------------------------------- 1 | ################################################################ 2 | ## 3 | ## JVM configuration 4 | ## 5 | ################################################################ 6 | ## 7 | ## WARNING: DO NOT EDIT THIS FILE. If you want to override the 8 | ## JVM options in this file, or set any additional options, you 9 | ## should create one or more files in the jvm.options.d 10 | ## directory containing your adjustments. 11 | ## 12 | ## See https://www.elastic.co/guide/en/elasticsearch/reference/8.13/jvm-options.html 13 | ## for more information. 
14 | ## 15 | ################################################################ 16 | 17 | 18 | 19 | ################################################################ 20 | ## IMPORTANT: JVM heap size 21 | ################################################################ 22 | ## 23 | ## The heap size is automatically configured by Elasticsearch 24 | ## based on the available memory in your system and the roles 25 | ## each node is configured to fulfill. If specifying heap is 26 | ## required, it should be done through a file in jvm.options.d, 27 | ## which should be named with .options suffix, and the min and 28 | ## max should be set to the same value. For example, to set the 29 | ## heap to 4 GB, create a new file in the jvm.options.d 30 | ## directory containing these lines: 31 | ## 32 | ## -Xms4g 33 | ## -Xmx4g 34 | ## 35 | ## See https://www.elastic.co/guide/en/elasticsearch/reference/8.13/heap-size.html 36 | ## for more information 37 | ## 38 | ################################################################ 39 | 40 | 41 | ################################################################ 42 | ## Expert settings 43 | ################################################################ 44 | ## 45 | ## All settings below here are considered expert settings. Do 46 | ## not adjust them unless you understand what you are doing. Do 47 | ## not edit them in this file; instead, create a new file in the 48 | ## jvm.options.d directory containing your adjustments. 49 | ## 50 | ################################################################ 51 | 52 | -XX:+UseG1GC 53 | 54 | ## JVM temporary directory 55 | -Djava.io.tmpdir=${ES_TMPDIR} 56 | 57 | # Leverages accelerated vector hardware instructions; removing this may 58 | # result in less optimal vector performance 59 | 20-:--add-modules=jdk.incubator.vector 60 | 61 | ## heap dumps 62 | 63 | # generate a heap dump when an allocation from the Java heap fails; heap dumps 64 | # are created in the working directory of the JVM unless an alternative path is 65 | # specified 66 | -XX:+HeapDumpOnOutOfMemoryError 67 | 68 | # exit right after heap dump on out of memory error 69 | -XX:+ExitOnOutOfMemoryError 70 | 71 | # specify an alternative path for heap dumps; ensure the directory exists and 72 | # has sufficient space 73 | -XX:HeapDumpPath=data 74 | 75 | # specify an alternative path for JVM fatal error logs 76 | -XX:ErrorFile=logs/hs_err_pid%p.log 77 | 78 | ## GC logging 79 | -Xlog:gc*,gc+age=trace,safepoint:file=logs/gc.log:utctime,level,pid,tags:filecount=32,filesize=64m 80 | -------------------------------------------------------------------------------- /utils/finetuning/finetune_instruct/modeling.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import torch 6 | from torch import Tensor, nn 7 | from transformers import AutoTokenizer 8 | from transformers.modeling_outputs import ModelOutput 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | @dataclass 14 | class LMOutput(ModelOutput): 15 | loss: Optional[Tensor] = None 16 | logits: Optional[Tensor] = None 17 | 18 | 19 | class LanguageModelFinetuner(nn.Module): 20 | def __init__( 21 | self, 22 | model: nn.Module, 23 | tokenizer: AutoTokenizer = None, 24 | train_batch_size: int = 4, 25 | enable_gradient_checkpointing: bool = False, 26 | ): 27 | super().__init__() 28 | self.model = model 29 | self.tokenizer = tokenizer 30 | self.train_batch_size = train_batch_size 31 | 32 
| if self.model.config.pad_token_id is None and self.tokenizer is not None: 33 | self.model.config.pad_token_id = self.tokenizer.pad_token_id 34 | self.config = self.model.config 35 | 36 | if enable_gradient_checkpointing and hasattr( 37 | self.model, "gradient_checkpointing_enable" 38 | ): 39 | self.model.gradient_checkpointing_enable() 40 | 41 | def gradient_checkpointing_enable(self, **kwargs): 42 | if hasattr(self.model, "gradient_checkpointing_enable"): 43 | self.model.gradient_checkpointing_enable(**kwargs) 44 | 45 | def enable_input_require_grads(self, **kwargs): 46 | if hasattr(self.model, "enable_input_require_grads"): 47 | self.model.enable_input_require_grads(**kwargs) 48 | 49 | def forward( 50 | self, 51 | input_ids: torch.Tensor = None, 52 | attention_mask: torch.Tensor = None, 53 | labels: torch.Tensor = None, 54 | ) -> LMOutput: 55 | device = next( 56 | self.model.parameters() 57 | ).device # Move inputs to the model's device 58 | input_ids = input_ids.to(device) 59 | attention_mask = attention_mask.to(device) 60 | if labels is not None: 61 | labels = labels.to(device) 62 | 63 | outputs = self.model( 64 | input_ids=input_ids, attention_mask=attention_mask, labels=labels 65 | ) 66 | 67 | return LMOutput( 68 | loss=outputs.loss if hasattr(outputs, "loss") else None, 69 | logits=outputs.logits if hasattr(outputs, "logits") else None, 70 | ) 71 | 72 | def save(self, output_dir: str): 73 | # Save the model (with weights) to output_dir 74 | state_dict = self.model.state_dict() 75 | state_dict = type(state_dict)( 76 | {k: v.clone().cpu() for k, v in state_dict.items()} 77 | ) 78 | self.model.save_pretrained(output_dir, state_dict=state_dict) 79 | 80 | def save_pretrained(self, **kwargs): 81 | if self.tokenizer is not None: 82 | self.tokenizer.save_pretrained(**kwargs) 83 | return self.model.save_pretrained(**kwargs) 84 | -------------------------------------------------------------------------------- /source/Matcher/pipeline/trial_ranker.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List 3 | 4 | from Matcher.utils.file_utils import read_json_file, write_json_file 5 | from Matcher.utils.logging_config import setup_logging 6 | 7 | logger = setup_logging() 8 | 9 | 10 | def load_trial_data(json_folder: str) -> List[Dict]: 11 | trial_data = [] 12 | for file_name in os.listdir(json_folder): 13 | if file_name.endswith(".json"): 14 | file_path = os.path.join(json_folder, file_name) 15 | trial_id = os.path.splitext(file_name)[0] 16 | try: 17 | trial = read_json_file(file_path) 18 | trial["TrialID"] = trial_id 19 | trial_data.append(trial) 20 | except Exception as e: 21 | logger.error(f"Failed to load {file_name}: {e}") 22 | return trial_data 23 | 24 | 25 | def score_trial(trial: Dict) -> float: 26 | def calculate_ratio( 27 | criteria_list, positive_classifications, negative_classifications 28 | ): 29 | criteria_to_exclude = ["Irrelevant", "Unclear"] 30 | criteria_list = [ 31 | c 32 | for c in criteria_list 33 | if c.get("Classification") not in criteria_to_exclude 34 | ] 35 | total_criteria = len(criteria_list) 36 | if total_criteria == 0: 37 | return 0.0 38 | positive_count = sum( 39 | 1 40 | for c in criteria_list 41 | if c.get("Classification") in positive_classifications 42 | ) 43 | negative_count = sum( 44 | 1 45 | for c in criteria_list 46 | if c.get("Classification") in negative_classifications 47 | ) 48 | penalty_factor_negative = 1.0 49 | reward_factor_positive = 1.0 50 | score = ( 51 | 
reward_factor_positive * positive_count 52 | - penalty_factor_negative * negative_count 53 | ) / total_criteria 54 | return score 55 | 56 | inclusion_criteria = trial.get("Inclusion_Criteria_Evaluation", []) 57 | exclusion_criteria = trial.get("Exclusion_Criteria_Evaluation", []) 58 | inclusion_ratio = calculate_ratio( 59 | inclusion_criteria, ["Met", "Not Violated"], ["Violated", "Not Met"] 60 | ) 61 | exclusion_ratio = calculate_ratio( 62 | exclusion_criteria, ["Not Violated", "Met"], ["Violated"] 63 | ) 64 | return (inclusion_ratio + exclusion_ratio) / 2 65 | 66 | 67 | def rank_trials(trial_data: List[Dict]) -> List[Dict]: 68 | ranked_trials = [] 69 | for trial in trial_data: 70 | trial_id = trial.get("TrialID", "Unknown") 71 | score = score_trial(trial) 72 | ranked_trials.append({"TrialID": trial_id, "Score": score}) 73 | ranked_trials.sort(key=lambda x: x["Score"], reverse=True) 74 | return ranked_trials 75 | 76 | 77 | def save_ranked_trials(ranked_trials: List[Dict], output_file: str): 78 | try: 79 | write_json_file({"RankedTrials": ranked_trials}, output_file) 80 | logger.info(f"Ranked trials saved to {output_file}") 81 | except Exception as e: 82 | logger.error(f"Failed to save ranked trials: {e}") 83 | -------------------------------------------------------------------------------- /utils/Preprocessor/jsonify.py: -------------------------------------------------------------------------------- 1 | import xml.etree.ElementTree as ET 2 | import json 3 | import os 4 | 5 | 6 | def parse_xml(xml_file): 7 | tree = ET.parse(xml_file) 8 | root = tree.getroot() 9 | 10 | data = {} 11 | 12 | # Extracting required fields 13 | data["nct_id"] = root.findtext("id_info/nct_id") 14 | data["brief_title"] = root.findtext("brief_title") 15 | data["official_title"] = root.findtext("official_title") 16 | data["brief_summary"] = root.findtext("brief_summary/textblock") 17 | data["detailed_description"] = root.findtext("detailed_description/textblock") 18 | data["overall_status"] = root.findtext("overall_status") 19 | data["start_date"] = root.findtext("start_date") 20 | data["completion_date"] = root.findtext("completion_date") 21 | data["phase"] = root.findtext("phase") 22 | data["study_type"] = root.findtext("study_type") 23 | 24 | # Handle multiple conditions 25 | data["condition"] = [cond.text for cond in root.findall("condition")] 26 | 27 | # Handle multiple interventions 28 | data["intervention"] = [] 29 | for intervention in root.findall("intervention"): 30 | data["intervention"].append( 31 | { 32 | "intervention_type": intervention.findtext("intervention_type"), 33 | "intervention_name": intervention.findtext("intervention_name"), 34 | } 35 | ) 36 | 37 | data["gender"] = root.findtext("eligibility/gender") 38 | data["minimum_age"] = root.findtext("eligibility/minimum_age") 39 | data["maximum_age"] = root.findtext("eligibility/maximum_age") 40 | 41 | # Extract eligibility criteria as a single block 42 | data["eligibility_criteria"] = root.findtext("eligibility/criteria/textblock") 43 | 44 | # Handle multiple locations 45 | data["location"] = [] 46 | for location in root.findall("location"): 47 | city = location.findtext("facility/address/city") 48 | state = location.findtext("facility/address/state") 49 | country = location.findtext("facility/address/country") 50 | location_name = location.findtext("facility/name") 51 | location_address = ", ".join(filter(None, [city, state, country])) 52 | data["location"].append( 53 | {"location_name": location_name, "location_address": location_address} 54 | ) 
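# Note: filter(None, ...) above drops any city/state/country fields missing from the XML, so the joined address never contains empty segments.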
55 |  56 | # Handle multiple references 57 | data["reference"] = [] 58 | for ref in root.findall("reference"): 59 | data["reference"].append( 60 | {"citation": ref.findtext("citation"), "PMID": ref.findtext("PMID")} 61 | ) 62 |  63 | return data 64 |  65 |  66 | def convert_to_json(data): 67 | json_data = json.dumps(data, indent=4) 68 | return json_data 69 |  70 |  71 | def process_files(input_dir, output_dir): 72 | for filename in os.listdir(input_dir): 73 | if filename.endswith(".xml"): 74 | xml_file = os.path.join(input_dir, filename) 75 | data = parse_xml(xml_file) 76 | json_data = convert_to_json(data) 77 |  78 | # Save JSON to output directory 79 | json_file = os.path.join(output_dir, filename.replace(".xml", ".json")) 80 | with open(json_file, "w") as f: 81 | f.write(json_data) 82 |  83 |  84 | if __name__ == "__main__": 85 | input_dir = "../../data/trials_xmls/" 86 | output_dir = "../../data/trials_jsons/" 87 | process_files(input_dir, output_dir) 88 | -------------------------------------------------------------------------------- /source/Parser/biomedner_server.py: -------------------------------------------------------------------------------- 1 | # server_script.py 2 | import argparse 3 | import json 4 | import os 5 | import socket 6 | import struct 7 | from datetime import datetime 8 |  9 | from biomedner_init import BioMedNER 10 | from ops import filter_entities, pubtator2dict_list 11 |  12 |  13 | def count_entities(data): 14 | num_entities = 0 15 | for d in data: 16 | if "entities" not in d: 17 | continue 18 | for ent_type, entities in d["entities"].items(): 19 | num_entities += len(entities) 20 | return num_entities 21 |  22 |  23 | def biomedner_recognize(model, dict_path, base_name, biomedner_home, args): 24 | # Ensure input and output directories exist within biomedner_home 25 | input_dir = os.path.join(biomedner_home, "input") 26 | output_dir = os.path.join(biomedner_home, "output") 27 | os.makedirs(input_dir, exist_ok=True) 28 | os.makedirs(output_dir, exist_ok=True) 29 |  30 | input_mt_ner = os.path.join( 31 | biomedner_home, "input", f"{dict_path}.biomedner.PubTator" 32 | ) 33 | output_mt_ner = os.path.join( 34 | biomedner_home, "output", f"{dict_path}.biomedner.json" 35 | ) 36 |  37 | dict_list = pubtator2dict_list(input_mt_ner) 38 |  39 | res = model.recognize(input_dl=dict_list, base_name=base_name) 40 |  41 | if res is None: 42 | return None 43 |  44 | num_filtered_species_per_doc = filter_entities(res) 45 | for n_f_spcs in num_filtered_species_per_doc: 46 | if n_f_spcs[1] > 0: 47 | print( 48 | datetime.now().strftime(args.time_format), 49 | "[{}] Filtered {} species".format(base_name, n_f_spcs[1]), 50 | ) 51 | num_entities = count_entities(res) 52 |  53 | res[0]["num_entities"] = num_entities 54 | # Write the recognition result to a JSON file 55 | with open(output_mt_ner, "w", encoding="utf-8") as f: 56 | json.dump(res[0], f) 57 |  58 |  59 | def run_server(model, args): 60 | host = args.biomedner_host 61 | port = args.biomedner_port 62 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 63 | s.bind((host, port)) 64 | s.listen(600) 65 | print(f"Server listening on {host}:{port}") 66 | while True: 67 | conn, addr = s.accept() 68 | with conn: 69 | print(f"Connected by {addr}") 70 | message_length = struct.unpack(">H", conn.recv(2))[0] 71 | message = conn.recv(message_length).decode("utf-8") 72 | request = json.loads(message) 73 | biomedner_home = request["biomedner_home"] 74 | inputfile = request["inputfile"] 75 | base_name = inputfile.split(".")[0] 76 | base_name = 
base_name.replace("\x0A", "")  # strip stray newline characters from the file name 77 |  78 | biomedner_recognize(model, inputfile, base_name, biomedner_home, args) 79 |  80 | output_stream = struct.pack(">H", len(inputfile)) + inputfile.encode( 81 | "utf-8" 82 | ) 83 | conn.send(output_stream) 84 | print(f"Response sent for {inputfile}") 85 |  86 |  87 | if __name__ == "__main__": 88 | argparser = argparse.ArgumentParser() 89 | argparser.add_argument( 90 | "--seed", type=int, help="random seed for initialization", default=1 91 | ) 92 | argparser.add_argument("--model_name_or_path") 93 | argparser.add_argument( 94 | "--max_seq_length", 95 | type=int, 96 | help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.", 97 | default=512, 98 | ) 99 | argparser.add_argument( 100 | "--biomedner_host", help="biomedical language model host", default="localhost" 101 | ) 102 | argparser.add_argument( 103 | "--biomedner_port", 104 | type=int, 105 | help="biomedical language model port", 106 | default=18894, 107 | ) 108 | argparser.add_argument( 109 | "--time_format", help="time format", default="[%d/%b/%Y %H:%M:%S.%f]" 110 | ) 111 | argparser.add_argument( 112 | "--no_cuda", action="store_true", help="Avoid using CUDA when available" 113 | ) 114 | args = argparser.parse_args() 115 | mt_ner = BioMedNER(args) 116 | run_server(mt_ner, args) 117 | -------------------------------------------------------------------------------- /elasticsearch/apptainer-run-es.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -euo pipefail 3 |  4 | #=== LOAD .env ===# 5 | if [ ! -f .env ]; then 6 | echo "[ERROR] .env file not found." 7 | exit 1 8 | fi 9 |  10 | # shellcheck disable=SC1091 11 | source .env 12 |  13 | #=== CONFIGURATION FROM ENV ===# 14 | STACK_VERSION="${STACK_VERSION:-8.13.4}" 15 | CLUSTER_NAME="${CLUSTER_NAME:-apptainer-cluster}" 16 | ES_IMAGE="docker.elastic.co/elasticsearch/elasticsearch:${STACK_VERSION}" 17 |  18 | BASE_DIR="$(pwd)" 19 | CERTS_DIR="$BASE_DIR/certs" 20 | CONFIG_DIR="$BASE_DIR/config" 21 | DATA_DIR="$BASE_DIR/data" 22 | LOGS_DIR="$BASE_DIR/logs" 23 | SIF_DIR="$BASE_DIR/sif" 24 | TMP_CONFIG="$BASE_DIR/tmp-config" 25 |  26 | ES_PORT1="${ES_PORT:-9200}" 27 | ES_PORT2=$((ES_PORT1 + 1)) 28 | ES_PORT3=$((ES_PORT1 + 2)) 29 |  30 | ELASTIC_PASSWORD="${ELASTIC_PASSWORD:?ELASTIC_PASSWORD not set in .env}" 31 |  32 | #=== PREPARE FOLDERS ===# 33 | mkdir -p "$CONFIG_DIR/es01" "$CONFIG_DIR/es02" "$CONFIG_DIR/es03" 34 | mkdir -p "$DATA_DIR/es01" "$DATA_DIR/es02" "$DATA_DIR/es03" 35 | mkdir -p "$LOGS_DIR" "$SIF_DIR" "$TMP_CONFIG" 36 |  37 | #=== BUILD SIF IMAGE IF NEEDED ===# 38 | if [ ! -f "$SIF_DIR/es.sif" ]; then 39 | echo "[INFO] Building Elasticsearch SIF..." 40 | apptainer build "$SIF_DIR/es.sif" "docker://$ES_IMAGE" 41 | fi 42 |  43 | #=== CLEAN CONFIG EXTRACTION DIR ===# 44 | rm -rf "$TMP_CONFIG"/* 45 |  46 | #=== EXTRACT DEFAULT CONFIG FILES ===# 47 | echo "[INFO] Extracting default Elasticsearch config files..."
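# Copy the image's stock config out through a bind mount so each node can be seeded from a known-good baseline.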
48 | apptainer exec --bind "$TMP_CONFIG:/mnt/tmp" "$SIF_DIR/es.sif" \ 49 | bash -c 'cp -r /usr/share/elasticsearch/config/* /mnt/tmp/' 50 |  51 | #=== PREPARE NODE CONFIGS ===# 52 | for NODE in es01 es02 es03; do 53 | cp -r "$TMP_CONFIG"/* "$CONFIG_DIR/$NODE/" 54 | cp "$CERTS_DIR/ca.crt" "$CONFIG_DIR/$NODE/" 55 | cp "$CERTS_DIR/$NODE/$NODE.crt" "$CONFIG_DIR/$NODE/" 56 | cp "$CERTS_DIR/$NODE/$NODE.key" "$CONFIG_DIR/$NODE/" 57 |  58 | cat > "$CONFIG_DIR/$NODE/elasticsearch.yml" < "$LOGS_DIR/$NODE.log" 2>&1 & 108 | } 109 |  110 | launch_node es01 $ES_PORT1 9300 111 | sleep 10 112 | launch_node es02 $ES_PORT2 9301 113 | sleep 10 114 | launch_node es03 $ES_PORT3 9302 115 | sleep 10 116 |  117 | #=== WAIT FOR ES01 TO BE READY ===# 118 | echo "[INFO] Waiting for es01 to be ready..." 119 | until curl -s --cacert "$CERTS_DIR/ca.crt" -u elastic:"$ELASTIC_PASSWORD" \ 120 | https://localhost:$ES_PORT1/_cluster/health?pretty | grep -q '"status"'; do 121 | echo -n "." 122 | sleep 5 123 | done 124 | echo -e "\n[INFO] Elasticsearch cluster is up." 125 |  126 | echo "[INFO] Access Elasticsearch at: https://localhost:$ES_PORT1" 127 | -------------------------------------------------------------------------------- /source/regex/exception_regex_patterns.json: -------------------------------------------------------------------------------- 1 | { 2 | "patterns": { 3 | "pattern1": { 4 | "regex": "^\\d+\\s$", 5 | "comment": "Matches digits followed by a space." 6 | }, 7 | "pattern2": { 8 | "regex": "^\\d+(\\.|\\-)\\d+$", 9 | "comment": "Matches digits separated by a dot (.) or hyphen (-)." 10 | }, 11 | "pattern3": { 12 | "regex": "^\\d+(\\.\\d+)?\\s?[×x]\\s?\\d+(\\.\\d+)?/?[A-Za-z/]+?$", 13 | "comment": "Matches expressions like '3x 2.5mm' or '5.75x2.5mm/s'." 14 | }, 15 | "pattern4": { 16 | "regex": "^\\d+/[A-Za-z]$", 17 | "comment": "Matches fraction expressions like '2/A' or '5/m'." 18 | }, 19 | "pattern5": { 20 | "regex": "^\\d+(\\.\\d+)?%$", 21 | "comment": "Matches percentage values like '25%', '3.5%', etc." 22 | }, 23 | "pattern6": { 24 | "regex": "^\\d+(\\.\\d+)?\\s*°?\\s*[CF]$", 25 | "comment": "Matches temperature values like '25°C', '3.5°F', etc." 26 | }, 27 | "pattern7": { 28 | "regex": "^\\d+[\\/\\^]\\d+[\\/\\^]\\w$", 29 | "comment": "Matches expressions like '2/3/m' or '5^2/n'." 30 | }, 31 | "pattern8": { 32 | "regex": "\\b\\d+\\b(?![.\\-+*/\\\\()\\[\\]{}])", 33 | "comment": "Matches standalone whole numbers that are not part of larger expressions." 34 | }, 35 | "pattern9": { 36 | "regex": "\\b\\d+-\\w+\\b", 37 | "comment": "Matches expressions like '12-abc' or '5-xyz'." 38 | }, 39 | "pattern10": { 40 | "regex": "\\b\\d+\\/\\d+\\b", 41 | "comment": "Matches fraction expressions like '2/3' or '5/8'." 42 | }, 43 | "pattern11": { 44 | "regex": "\\b\\d+[A-Za-z]+\\b", 45 | "comment": "Matches expressions like '25kg' or '10m'." 46 | }, 47 | "pattern12": { 48 | "regex": "e\\.g\\.", 49 | "comment": "Matches 'e.g.' (for example)." 50 | }, 51 | "pattern13": { 52 | "regex": "i\\.e\\.", 53 | "comment": "Matches 'i.e.' (that is)." 54 | }, 55 | "pattern14": { 56 | "regex": "\\b\\d+\\.\\d+[A-Za-z]\\b", 57 | "comment": "Matches expressions like '3.14pi' or '2.75x'." 58 | }, 59 | "pattern15": { 60 | "regex": "\\b\\d+\\/\\w+\\b", 61 | "comment": "Matches patterns like '≥ 10/μL', '< 9/cl', '+ 12/hgmm', etc." 62 | }, 63 | "pattern16": { 64 | "regex": "\\b\\d+\\.\\d+[)]?", 65 | "comment": "Matches patterns like '2.0)', '3.0)', '15.0)', '16.1-', '2.1)', '9.6/'." 66 | }, 67 |  68 | "pattern17": { 69 | "regex": "(?
4 |  5 | An AI-driven tool designed to match patients with the most relevant clinical trials. Leveraging state-of-the-art Large Language Models (LLMs), Natural Language Processing (NLP), and Explainable AI (XAI), TrialMatchAI structures trial documentation and patient data to provide transparent, personalized recommendations. 6 |  7 | --- 8 |  9 | ## ⚠️ Disclaimer 10 | At this stage, TrialMatchAI is still under active development and largely a **prototype** provided for research and informational purposes only. It is **NOT** medical advice and should not replace consultation with qualified healthcare professionals. 11 |  12 | --- 13 |  14 | ## 🔍 Key Features 15 |  16 | - **AI-Powered Matching**: Utilizes advanced LLMs to parse complex eligibility criteria and patient records (including unstructured notes and genetic reports). 17 | - **Personalized Recommendations**: Tailors trial suggestions based on each patient’s unique clinical history and genomic profile. 18 | - **Explainable Insights**: Provides clear, chain-of-thought explanations for every recommended trial, enhancing trust and interpretability. 19 | - **Real-Time Updates**: Maintains an up-to-date database of recruiting trials. 20 | - **Scalable Architecture**: Dockerized components enable easy deployment of Elasticsearch indices and indexing pipelines. 21 |  22 | --- 23 |  24 | ## ⚙️ System Requirements 25 |  26 | - **OS**: Linux or macOS 27 | - **Docker & Docker Compose**: For running the Elasticsearch container 28 | - **Java**: For running the NER and normalization components 29 | - **Python**: ≥ 3.8 30 | - **GPU**: NVIDIA (e.g., H100) with ≥ 60 GB VRAM (recommended for large-scale processing) 31 | - **Disk Space**: ≥ 100 GB free (for data and indices) 32 |  33 | --- 34 |  35 | ## 🚀 Installation & Setup 36 |  37 | 1. **Clone the Repository** 38 | ```bash 39 | git clone https://github.com/cbib/TrialMatchAI.git 40 | cd TrialMatchAI 41 | ``` 42 |  43 | 2. **Ensure the Repository Is Up to Date** 44 | ```bash 45 | git pull origin main 46 | ``` 47 |  48 | 3. **Make the Setup Script Executable** 49 | ```bash 50 | chmod +x setup.sh 51 | ``` 52 |  53 | 4. **(Optional) Configure Elasticsearch Password** 54 | - Open the `.env` file located in the `docker/` folder. 55 | - Update the `ELASTIC_PASSWORD` variable to your desired secure password. 56 | ```dotenv 57 | # docker/.env 58 | ELASTIC_PASSWORD=YourNewPassword 59 | ``` 60 |  61 | 4a. **(Optional) Sync `config.json` Password** 62 | If you updated `ELASTIC_PASSWORD` above, open `config.json` in the repo root and update the Elasticsearch password field to match: 63 | ```json 64 | { 65 | "elasticsearch": { 66 | "host": "https://localhost:9200", 67 | "username": "elastic", 68 | "password": "YourNewPassword", 69 | . 70 | . 71 | }, 72 | ... 73 | } 74 | ``` 75 |  76 | 5. **Run the Setup Script** 77 | ```bash 78 | ./setup.sh 79 | ``` 80 | - Installs Python dependencies 81 | - Downloads datasets, resources, and model archives from Zenodo 82 | - Verifies GPU availability 83 | - Builds the Elasticsearch container via Docker Compose 84 | - Launches indexing pipelines in the background 85 | - **Estimated Time**: ~60–90 minutes (depending on hardware) 86 |  87 | 6. **(Optional) Using Flash-Attention 2** 88 | If you want to use Flash-Attention for faster and more memory-efficient attention, you can install it through pip, as described in the package’s [GitHub repository](https://github.com/Dao-AILab/flash-attention). 89 |  90 | On some systems, however, installing it with the standard method is either impossible or too slow. In that case, it is recommended to manually download the **wheel** file compatible with your torch, CUDA, and Python versions, listed in the package [releases](https://github.com/Dao-AILab/flash-attention/releases), and then pip install that file, as sketched below.
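A minimal sketch of that workflow, assuming Python 3.11, PyTorch 2.3, and CUDA 12.2 (the wheel filename below is illustrative only; copy the exact name matching your environment from the releases page):

```bash
# Print the versions the wheel must match
python -c "import torch, sys; print(torch.__version__, torch.version.cuda, sys.version_info[:2])"

# Download a matching prebuilt wheel (hypothetical filename; pick yours from the releases page)
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl

# Install the local wheel instead of building from source
pip install flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
```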
91 |  92 | --- 93 |  94 | ## 🎯 Usage Example 95 |  96 | Run the matcher on a sample input directory: 97 |  98 | ```bash 99 | python -m src.Matcher.main 100 | ``` 101 |  102 | Results are saved under `results/`, with detailed criterion-level explanations for each recommended trial. 103 |  104 | --- 105 |  106 | ## 🤝 Contributing 107 |  108 | We welcome community contributions! To contribute: 109 |  110 | 1. Fork the repository. 111 | 2. Create a feature branch: `git checkout -b feature/YourFeature`. 112 | 3. Commit your changes and push to your branch. 113 | 4. Open a Pull Request against `main`. 114 |  115 | Please follow our code style and include tests where applicable. 116 |  117 | --- 118 |  119 | ## 🙋 Support & Contact 120 |  121 | For questions, issues, or feature requests, open an issue on GitHub or reach out to: 122 |  123 | - **Email**: [abdallahmajd7@gmail.com](mailto:abdallahmajd7@gmail.com) 124 | -------------------------------------------------------------------------------- /source/Parser/normalizers/normalizer_all.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | from datetime import datetime 4 | from rapidfuzz import fuzz 5 |  6 |  7 | class BaseNormalizer: 8 | """ 9 | Base class for flexible normalization with exact and similarity-based matching. 10 | """ 11 |  12 | def __init__(self, dict_path, nlp=None, similarity_threshold=0.5): 13 | self.NO_ENTITY_ID = "CUI-less" 14 | self.similarity_threshold = similarity_threshold 15 | self.nlp = nlp 16 | self.entity2oid = dict() 17 |  18 | self.load_dictionary(dict_path) 19 |  20 | def load_dictionary(self, dict_path): 21 | """ 22 | Load dictionary from the given path and normalize keys for matching. 23 | """ 24 | print( 25 | datetime.now().strftime("[%d/%b/%Y %H:%M:%S.%f]"), "Loading dictionary..." 26 | ) 27 | with open(dict_path, "r", encoding="utf-8") as f: 28 | for line in f: 29 | oid, names = line.strip().split("||") 30 | names = names.split("|") 31 | for name in names: 32 | normalized_name = self.normalize_string(name) 33 | self.entity2oid[normalized_name] = oid 34 | print(datetime.now().strftime("[%d/%b/%Y %H:%M:%S.%f]"), "Dictionary loaded.") 35 |  36 | def normalize(self, names): 37 | """ 38 | Normalize a list of names and return their corresponding IDs. 39 | """ 40 | oids = list() 41 | for name in names: 42 | normalized_name = self.get_tmchem_name(name) 43 |  44 | # Exact match 45 | if normalized_name in self.entity2oid: 46 | oids.append(self.entity2oid[normalized_name]) 47 | else: 48 | # Flexible match based on similarity score 49 | best_match = self.find_best_match(normalized_name) 50 | if best_match: 51 | oids.append(self.entity2oid[best_match]) 52 | else: 53 | oids.append(self.NO_ENTITY_ID) 54 |  55 | return oids 56 |  57 | def get_tmchem_name(self, name): 58 | """ 59 | Normalize a name using lowercasing, punctuation removal, and lemmatization (if NLP is enabled).
60 | """ 61 | cleaned_name = re.sub( 62 | r"[^\w\s-]", "", name.lower() 63 | ) # Remove punctuation but keep hyphens 64 | if self.nlp: 65 | doc = self.nlp(cleaned_name) 66 | lemmatized_name = " ".join([token.lemma_ for token in doc]) 67 | else: 68 | lemmatized_name = cleaned_name 69 | return self.normalize_string(lemmatized_name) 70 | 71 | def normalize_string(self, text): 72 | """ 73 | Normalize a string for matching by removing spaces and hyphens. 74 | """ 75 | return re.sub(r"[\s-]", "", text.lower()) 76 | 77 | def find_best_match(self, normalized_name): 78 | """ 79 | Find the best match for a given normalized name based on a similarity score. 80 | """ 81 | best_match = None 82 | highest_score = 0 83 | 84 | for key in self.entity2oid.keys(): 85 | score = self.similarity_score(normalized_name, key) 86 | if score > self.similarity_threshold and score > highest_score: 87 | best_match = key 88 | highest_score = score 89 | 90 | return best_match 91 | 92 | def similarity_score(self, name1, name2): 93 | """ 94 | Use RapidFuzz's token_sort_ratio for fast and accurate similarity scoring. 95 | """ 96 | return fuzz.token_sort_ratio(name1, name2) / 100 # Normalize score to 0-1 range 97 | 98 | 99 | class CellTypeNormalizer(BaseNormalizer): 100 | pass 101 | 102 | 103 | class ProcedureNormalizer(BaseNormalizer): 104 | pass 105 | 106 | 107 | class ChemicalNormalizer(BaseNormalizer): 108 | def load_dictionary(self, dict_path): 109 | """ 110 | Specialized dictionary loader for chemical normalizer with file size logging. 111 | """ 112 | dict_size = os.path.getsize(dict_path) 113 | print( 114 | datetime.now().strftime("[%d/%b/%Y %H:%M:%S.%f]"), 115 | f"Chemical dictionary file size: {dict_size} bytes", 116 | ) 117 | super().load_dictionary(dict_path) 118 | 119 | 120 | class CellLineNormalizer(BaseNormalizer): 121 | def get_tmchem_name(self, name): 122 | """ 123 | Simplified name normalization for cell line data. 124 | """ 125 | return self.normalize_string(name) 126 | 127 | 128 | class SpeciesNormalizer(BaseNormalizer): 129 | def get_tmchem_name(self, name): 130 | """ 131 | Simplified name normalization for species data. 
132 | """ 133 | return self.normalize_string(name) 134 | 135 | 136 | class SignSymptomNormalizer(BaseNormalizer): 137 | pass 138 | -------------------------------------------------------------------------------- /source/Matcher/models/llm/llm_loader.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from Matcher.utils.logging_config import setup_logging 5 | from peft import PeftModel 6 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 7 | 8 | logger = setup_logging() 9 | 10 | 11 | def load_model_and_tokenizer( 12 | model_config: dict, device: int 13 | ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: 14 | """Load a model and tokenizer with safe device handling and optional 4-bit.""" 15 | use_cuda = torch.cuda.is_available() 16 | device_str = "cuda" if use_cuda else "cpu" 17 | quant_config = None 18 | attn_impl = None 19 | # Select best dtype 20 | compute_dtype = torch.float32 21 | if use_cuda and torch.cuda.is_bf16_supported(): 22 | compute_dtype = torch.bfloat16 23 | elif use_cuda: 24 | compute_dtype = torch.float16 25 | 26 | if use_cuda: 27 | cuda_count = torch.cuda.device_count() 28 | idx = int(device) if isinstance(device, int) else 0 29 | if idx < 0 or idx >= cuda_count: 30 | logger.warning( 31 | f"Requested CUDA device {device} invalid; using 0 (num_gpus={cuda_count})." 32 | ) 33 | idx = 0 34 | try: 35 | torch.cuda.set_device(idx) 36 | except Exception as e: 37 | logger.warning( 38 | f"torch.cuda.set_device({idx}) failed: {e}. Falling back to 0." 39 | ) 40 | idx = 0 41 | torch.cuda.set_device(idx) 42 | device_str = f"cuda:{idx}" 43 | 44 | # Prefer FlashAttention-2 if available, else SDPA 45 | attn_impl = "sdpa" 46 | try: 47 | import flash_attn # noqa: F401 48 | 49 | major, minor = torch.cuda.get_device_capability(idx) 50 | if (major * 10 + minor) >= 75: 51 | attn_impl = "flash_attention_2" 52 | logger.info("Using FlashAttention-2.") 53 | else: 54 | logger.info("FlashAttention-2 unsupported on this GPU; using SDPA.") 55 | except Exception: 56 | logger.info("flash-attn not available; using SDPA.") 57 | 58 | quant_config = BitsAndBytesConfig( 59 | load_in_4bit=bool(model_config["quantization"]["load_in_4bit"]), 60 | bnb_4bit_use_double_quant=bool( 61 | model_config["quantization"]["bnb_4bit_use_double_quant"] 62 | ), 63 | bnb_4bit_quant_type=str( 64 | model_config["quantization"]["bnb_4bit_quant_type"] 65 | ), 66 | bnb_4bit_compute_dtype=compute_dtype, 67 | ) 68 | logger.info(f"Loading model on {device_str} with 4-bit quantization.") 69 | else: 70 | logger.warning( 71 | "CUDA not available; loading model on CPU without 4-bit quantization." 72 | ) 73 | device_str = "cpu" 74 | quant_config = BitsAndBytesConfig(load_in_4bit=False) 75 | 76 | tokenizer = AutoTokenizer.from_pretrained( 77 | model_config["base_model"], 78 | use_fast=True, 79 | padding_side="left", 80 | trust_remote_code=True, 81 | ) 82 | # Always left-pad decoder-only models; keep most recent tokens if truncation occurs. 83 | tokenizer.padding_side = "left" 84 | tokenizer.truncation_side = "left" 85 | tokenizer.pad_token = tokenizer.eos_token 86 | if attn_impl == "flash_attention_2": 87 | logger.info( 88 | "Using FlashAttention-2; keeping padding_side='left' for decoder-only models." 
89 | ) 90 | 91 | model = AutoModelForCausalLM.from_pretrained( 92 | model_config["base_model"], 93 | trust_remote_code=True, 94 | torch_dtype=compute_dtype if use_cuda else torch.float32, 95 | device_map=device_str, 96 | attn_implementation=attn_impl, 97 | quantization_config=quant_config, 98 | low_cpu_mem_usage=True, 99 | ) 100 | # Ensure KV cache usage for faster generation 101 | try: 102 | model.config.use_cache = True 103 | except Exception: 104 | pass 105 | 106 | model = PeftModel.from_pretrained( 107 | model, model_config["cot_adapter_path"], device_map=device_str 108 | ) 109 | 110 | # Optional: compile for extra speed when supported 111 | if bool(model_config.get("compile", False)): 112 | try: 113 | model = torch.compile(model, mode="max-autotune", fullgraph=False) 114 | logger.info("Model compiled with torch.compile.") 115 | except Exception as e: 116 | logger.warning(f"torch.compile failed; continuing without it. Err: {e}") 117 | 118 | if isinstance(model, torch.nn.Module): 119 | model.eval() 120 | else: 121 | logger.warning("Model is not an instance of torch.nn.Module; skipping eval.") 122 | logger.info(f"Model loaded on {device_str}.") 123 | return model, tokenizer # type: ignore[return-value] 124 | -------------------------------------------------------------------------------- /utils/gpt/gpt_generate_reranking_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from typing import Dict, List, Optional 5 | 6 | from langchain.schema import HumanMessage 7 | from langchain_community.chat_models import ChatOpenAI 8 | from pydantic import BaseModel 9 | 10 | os.environ["OPENAI_API_KEY"] = "" 11 | 12 | INPUT_MEDNLI_FILE = "mednli_train.jsonl" 13 | OUTPUT_AUGMENTED_FILE = "mednli_yesno_aug.jsonl" 14 | 15 | N_VARIANTS_PER_SEED = 3 16 | MAX_SEED_EXAMPLES = 500 17 | 18 | MODEL_NAME = "gpt-4o-mini" 19 | TEMPERATURE = 0.7 20 | 21 | INSTRUCTION_TEXT = ( 22 | "You are a clinical assistant tasked with determining whether the patient " 23 | "information (Statement A) provides enough details to evaluate whether the " 24 | "patient satisfies or violates the clinical trial eligibility criterion " 25 | "(Statement B). Respond with 'Yes' if Statement A contains sufficient " 26 | "information to make this evaluation, or 'No' if it does not." 27 | ) 28 | 29 | 30 | class MedNLISeed(BaseModel): 31 | sentence1: str 32 | sentence2: str 33 | gold_label: str 34 | 35 | 36 | class YesNoExample(BaseModel): 37 | instruction: str 38 | sentence1: str 39 | sentence2: str 40 | gold_label: str 41 | 42 | 43 | llm = ChatOpenAI( 44 | model=MODEL_NAME, 45 | temperature=TEMPERATURE, 46 | top_p=0.9, 47 | ) 48 | 49 | 50 | def load_mednli_jsonl(path: str) -> List[MedNLISeed]: 51 | examples: List[MedNLISeed] = [] 52 | with open(path, "r") as f: 53 | for line in f: 54 | raw = json.loads(line.strip()) 55 | examples.append( 56 | MedNLISeed( 57 | sentence1=raw["sentence1"], 58 | sentence2=raw["sentence2"], 59 | gold_label=raw["gold_label"], 60 | ) 61 | ) 62 | return examples 63 | 64 | 65 | def write_jsonl(path: str, data: List[Dict]): 66 | with open(path, "a", encoding="utf-8") as f: 67 | for obj in data: 68 | f.write(json.dumps(obj, ensure_ascii=False) + "\n") 69 | 70 | 71 | def generate_yesno_variants( 72 | seed_example: MedNLISeed, n_variants: int 73 | ) -> List[YesNoExample]: 74 | prompt = f""" 75 | You are generating training data for a clinical trial matching system. 
76 | 77 | We want examples following this pattern: 78 | 79 | - Statement A: patient information. 80 | - Statement B: clinical trial eligibility criterion. 81 | - A fixed instruction string. 82 | - A label "Yes" or "No". 83 | 84 | Semantics: 85 | - "Yes": Statement A contains enough information to determine if the patient satisfies or violates Statement B. 86 | - "No": Statement A does not contain enough information. 87 | 88 | MedNLI mapping: 89 | - entailment/contradiction → Yes 90 | - neutral → No 91 | 92 | Seed example: 93 | {{ 94 | "sentence1": {json.dumps(seed_example.sentence1)}, 95 | "sentence2": {json.dumps(seed_example.sentence2)}, 96 | "gold_label": {json.dumps(seed_example.gold_label)} 97 | }} 98 | 99 | Generate {n_variants} new examples. Each example must have: 100 | - "instruction": {json.dumps(INSTRUCTION_TEXT)} 101 | - "sentence1": Statement A 102 | - "sentence2": Statement B 103 | - "gold_label": "Yes" or "No" 104 | 105 | Return only valid JSON: a list of objects with exactly these keys. 106 | """.strip() 107 | 108 | response = llm.invoke([HumanMessage(content=prompt)]) 109 | raw = response.content.strip() 110 | if raw.startswith("```"): 111 | raw = raw.strip("`").replace("json", "", 1).strip() 112 | 113 | try: 114 | parsed = json.loads(raw) 115 | out = [] 116 | for item in parsed: 117 | try: 118 | ex = YesNoExample(**item) 119 | if ex.gold_label in {"Yes", "No"}: 120 | ex.instruction = INSTRUCTION_TEXT 121 | out.append(ex) 122 | except Exception: 123 | pass 124 | return out 125 | except Exception: 126 | return [] 127 | 128 | 129 | def augment_mednli_yesno( 130 | input_path: str, 131 | output_path: str, 132 | n_variants_per_seed: int = 3, 133 | max_seed_examples: Optional[int] = None, 134 | ): 135 | all_seeds = load_mednli_jsonl(input_path) 136 | 137 | if max_seed_examples is not None and max_seed_examples < len(all_seeds): 138 | seed_examples = random.sample(all_seeds, max_seed_examples) 139 | else: 140 | seed_examples = all_seeds 141 | 142 | open(output_path, "w").close() 143 | 144 | for seed in seed_examples: 145 | generated = generate_yesno_variants(seed, n_variants_per_seed) 146 | if not generated: 147 | continue 148 | to_write = [ 149 | { 150 | "instruction": ex.instruction, 151 | "sentence1": ex.sentence1, 152 | "sentence2": ex.sentence2, 153 | "gold_label": ex.gold_label, 154 | } 155 | for ex in generated 156 | ] 157 | write_jsonl(output_path, to_write) 158 | 159 | 160 | if __name__ == "__main__": 161 | augment_mednli_yesno( 162 | input_path=INPUT_MEDNLI_FILE, 163 | output_path=OUTPUT_AUGMENTED_FILE, 164 | n_variants_per_seed=N_VARIANTS_PER_SEED, 165 | max_seed_examples=MAX_SEED_EXAMPLES, 166 | ) 167 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -euo pipefail 3 | IFS=$'\n\t' 4 | 5 | #=== CONFIGURATION ===# 6 | DATA_URL_1="https://zenodo.org/records/15516900/files/processed_trials.tar.gz?download=1" 7 | RESOURCES_URL="https://zenodo.org/records/15516900/files/resources.tar.gz?download=1" 8 | MODELS_URL="https://zenodo.org/records/15516900/files/models.tar.gz?download=1" 9 | CRITERIA_ZIP_BASE_URL="https://zenodo.org/records/15516900/files" 10 | CHUNK_PREFIX="criteria_part" 11 | CHUNK_COUNT=6 12 | 13 | ARCHIVE_1="processed_trials.tar.gz" 14 | RESOURCES_ARCHIVE="resources.tar.gz" 15 | MODELS_ARCHIVE="models.tar.gz" 16 | 17 | #=== COLORS ===# 18 | GREEN='\033[0;32m' 19 | NC='\033[0m' # No Color 20 | 21 | #=== 
HELPERS ===# 22 | info() { echo -e "${GREEN}[INFO]${NC} $*"; } 23 | error() { echo -e "[ERROR] $*" >&2; exit 1; } 24 | 25 | #=== MAIN SCRIPT ===# 26 | info "Starting TrialMatchAI setup..." 27 | 28 | # 0) Check for available GPUs 29 | info "Checking for available GPUs..." 30 | 31 | if command -v nvidia-smi &> /dev/null; then 32 | if nvidia-smi &> /dev/null; then 33 | info "NVIDIA GPUs detected:" 34 | nvidia-smi --query-gpu=index,name,memory.total --format=csv 35 | else 36 | info "nvidia-smi found, but no NVIDIA GPU detected or driver not loaded." 37 | fi 38 | else 39 | info "No NVIDIA GPUs detected." 40 | fi 41 | 42 | # 1) Install Python dependencies 43 | if ! command -v pip &> /dev/null; then 44 | error "pip not found. Please install Python and pip first." 45 | fi 46 | info "Installing Python requirements..." 47 | pip install --upgrade pip 48 | pip install -r requirements.txt 49 | 50 | # 2) Prepare data directory 51 | info "Preparing data directory..." 52 | mkdir -p data 53 | cd data 54 | 55 | # Download core archives 56 | if [ ! -f "$ARCHIVE_1" ]; then 57 | info "Downloading ${ARCHIVE_1}..." 58 | wget --quiet "$DATA_URL_1" -O "$ARCHIVE_1" 59 | else 60 | info "${ARCHIVE_1} already exists. Skipping download." 61 | fi 62 | 63 | if [ ! -f "$RESOURCES_ARCHIVE" ]; then 64 | info "Downloading ${RESOURCES_ARCHIVE}..." 65 | wget --quiet "$RESOURCES_URL" -O "$RESOURCES_ARCHIVE" 66 | else 67 | info "${RESOURCES_ARCHIVE} already exists. Skipping download." 68 | fi 69 | 70 | if [ ! -f "$MODELS_ARCHIVE" ]; then 71 | info "Downloading ${MODELS_ARCHIVE}..." 72 | wget --quiet "$MODELS_URL" -O "$MODELS_ARCHIVE" 73 | else 74 | info "${MODELS_ARCHIVE} already exists. Skipping download." 75 | fi 76 | 77 | # Download and extract processed_criteria ZIP chunks 78 | if [ ! -d "processed_criteria" ]; then 79 | info "Downloading and extracting processed_criteria chunks..." 80 | mkdir -p processed_criteria 81 | 82 | for i in $(seq 0 $((CHUNK_COUNT - 1))); do 83 | chunk_zip="${CHUNK_PREFIX}_${i}.zip" 84 | chunk_url="${CRITERIA_ZIP_BASE_URL}/${chunk_zip}?download=1" 85 | 86 | if [ ! -f "$chunk_zip" ]; then 87 | info "Downloading $chunk_zip..." 88 | wget --quiet "$chunk_url" -O "$chunk_zip" 89 | else 90 | info "$chunk_zip already exists. Skipping download." 91 | fi 92 | 93 | info "Extracting $chunk_zip into processed_criteria..." 94 | unzip -q "$chunk_zip" -d processed_criteria 95 | done 96 | else 97 | info "processed_criteria already exists. Skipping extraction." 98 | fi 99 | 100 | # Extract processed_trials 101 | if [ ! -d "processed_trials" ]; then 102 | info "Extracting $ARCHIVE_1..." 103 | tar -xzvf "$ARCHIVE_1" 104 | else 105 | info "processed_trials already exists. Skipping extraction of $ARCHIVE_1." 106 | fi 107 | 108 | cd .. 109 | 110 | # Extract resources 111 | info "Extracting resources into source/Parser..." 112 | mkdir -p source/Parser 113 | tar -xzvf data/"$RESOURCES_ARCHIVE" -C source/Parser 114 | 115 | info "Extracting models into models/..." 116 | mkdir -p models 117 | tar -xzvf data/"$MODELS_ARCHIVE" -C models 118 | 119 | info "Cleaning up archives..." 120 | rm -f data/"$ARCHIVE_1" data/"$RESOURCES_ARCHIVE" data/"$MODELS_ARCHIVE" 121 | 122 | for i in $(seq 0 $((CHUNK_COUNT - 1))); do 123 | rm -f data/"${CHUNK_PREFIX}_${i}.zip" 124 | done 125 | 126 | # 3) Launch Elasticsearch: Try Docker first, then Apptainer fallback 127 | cd elasticsearch 128 | if command -v docker &> /dev/null && docker info &> /dev/null; then 129 | info "Docker is available. Setting up Elasticsearch with Docker Compose..." 
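# NOTE: newer Docker installs ship Compose v2, which is invoked as "docker compose"; adjust the next line if the standalone docker-compose binary is unavailable.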
130 | docker-compose up -d --build 131 |  132 | elif command -v apptainer &> /dev/null; then 133 | info "Docker not found or not running. Falling back to Apptainer..." 134 | if [ ! -f "./apptainer-run-es.sh" ]; then 135 | error "Apptainer script not found at ./elasticsearch/apptainer-run-es.sh" 136 | fi 137 | bash ./apptainer-run-es.sh 138 | else 139 | error "Neither Docker nor Apptainer is available. Cannot continue." 140 | fi 141 | cd .. 142 |  143 | # 4) Launch indexers in background 144 | cd utils/Indexer 145 | info "Starting index_criteria.py (trials_eligibility) ..." 146 | nohup python index_criteria.py \ 147 | --config config.json \ 148 | --processed-folder ../../data/processed_criteria \ 149 | --index-name trials_eligibility \ 150 | --batch-size 100 \ 151 | --max-workers 100 \ 152 | > criteria.log 2>&1 & 153 |  154 | info "Starting index_trials.py (clinical_trials) ..." 155 | nohup python index_trials.py \ 156 | --config config.json \ 157 | --processed-folder ../../data/processed_trials \ 158 | --index-name clinical_trials \ 159 | --batch-size 100 \ 160 | > trials.log 2>&1 & 161 |  162 | info "Waiting for indexing jobs to complete..." 163 | wait 164 |  165 | info "✅ TrialMatchAI setup is complete!" 166 | -------------------------------------------------------------------------------- /utils/finetuning/finetune_instruct/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 |  5 | import torch 6 | import torch.distributed as dist 7 |  8 | from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, set_seed 9 |  10 | from arguments import ( 11 | ModelArguments, 12 | DataArguments, 13 | SFTTrainingArguments as TrainingArguments, 14 | ) 15 | from data import TrainDataset, DataCollatorForFinetuning 16 | from modeling import LanguageModelFinetuner 17 | from trainer import SFTTrainer 18 | from load_model import get_model 19 |  20 |  21 | logger = logging.getLogger(__name__) 22 |  23 | # Initialize the distributed process group (expects the torchrun/torch.distributed launcher env vars) 24 | dist.init_process_group(backend="nccl") 25 |  26 | # Get the rank of the current process 27 | rank = dist.get_rank() 28 |  29 | # Map the rank to a specific GPU 30 | device_id = rank  # Assumes one process per GPU, with the global rank equal to the local GPU index 31 | torch.cuda.set_device(device_id) 32 |  33 | print( 34 | f"Rank {rank} using device {device_id} on {torch.cuda.get_device_name(device_id)}" 35 | ) 36 |  37 |  38 | def main(): 39 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 40 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 41 | model_args: ModelArguments 42 | data_args: DataArguments 43 | training_args: TrainingArguments 44 |  45 | if ( 46 | os.path.exists(training_args.output_dir) 47 | and os.listdir(training_args.output_dir) 48 | and training_args.do_train 49 | and not training_args.overwrite_output_dir 50 | ): 51 | raise ValueError( 52 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 53 | f"Use --overwrite_output_dir to overcome."
54 | ) 55 | 56 | # Setup logging 57 | logging.basicConfig( 58 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 59 | datefmt="%m/%d/%Y %H:%M:%S", 60 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 61 | ) 62 | logger.warning( 63 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 64 | training_args.local_rank, 65 | training_args.device, 66 | training_args.n_gpu, 67 | bool(training_args.local_rank != -1), 68 | training_args.fp16, 69 | ) 70 | logger.info("Training/evaluation parameters %s", training_args) 71 | logger.info("Model parameters %s", model_args) 72 | logger.info("Data parameters %s", data_args) 73 | 74 | # Set seed 75 | set_seed(training_args.seed) 76 | 77 | base_model = get_model(model_args, training_args) 78 | 79 | # Load tokenizer 80 | tokenizer = AutoTokenizer.from_pretrained( 81 | model_args.tokenizer_name 82 | if model_args.tokenizer_name 83 | else model_args.model_name_or_path, 84 | cache_dir=model_args.cache_dir, 85 | use_fast=not model_args.use_slow_tokenizer, 86 | trust_remote_code=True, 87 | token=model_args.token, 88 | ) 89 | 90 | # Ensure pad_token_id is defined 91 | if tokenizer.pad_token_id is None: 92 | if tokenizer.unk_token_id is not None: 93 | tokenizer.pad_token_id = tokenizer.unk_token_id 94 | else: 95 | # As a fallback if the tokenizer doesn't have unk_token_id, set pad_token_id to a known token 96 | # If using a special tokenizer, make sure to adapt accordingly. 97 | tokenizer.pad_token_id = tokenizer.eos_token_id 98 | 99 | config = AutoConfig.from_pretrained( 100 | model_args.config_name 101 | if model_args.config_name 102 | else model_args.model_name_or_path, 103 | cache_dir=model_args.cache_dir, 104 | trust_remote_code=True, 105 | ) 106 | logger.info("Config: %s", config) 107 | 108 | model = LanguageModelFinetuner( 109 | model=base_model, 110 | tokenizer=tokenizer, 111 | train_batch_size=training_args.per_device_train_batch_size, 112 | ) 113 | 114 | if training_args.gradient_checkpointing: 115 | model.enable_input_require_grads() 116 | 117 | # Load the training dataset 118 | train_dataset = TrainDataset(args=data_args, tokenizer=tokenizer) 119 | 120 | # Setup data collator 121 | data_collator = DataCollatorForFinetuning( 122 | tokenizer=tokenizer, 123 | query_max_len=data_args.query_max_len, 124 | passage_max_len=data_args.passage_max_len, 125 | pad_to_multiple_of=8, 126 | return_tensors="pt", 127 | padding=True, 128 | ) 129 | 130 | trainer = SFTTrainer( 131 | model=model, 132 | args=training_args, 133 | train_dataset=train_dataset, 134 | data_collator=data_collator, 135 | tokenizer=tokenizer, 136 | ) 137 | trainer.use_lora = model_args.use_lora 138 | 139 | Path(training_args.output_dir).mkdir(parents=True, exist_ok=True) 140 | 141 | # Training 142 | trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 143 | trainer.save_model() 144 | 145 | # If not using LoRA, you can save a final checkpoint if desired 146 | if not model_args.use_lora and trainer.deepspeed is not None: 147 | checkpoint_dir = os.path.join(training_args.output_dir, "checkpoint-final") 148 | trainer.deepspeed.save_checkpoint(checkpoint_dir) 149 | 150 | # If world process zero, save tokenizer 151 | if trainer.is_world_process_zero(): 152 | tokenizer.save_pretrained(training_args.output_dir) 153 | 154 | 155 | if __name__ == "__main__": 156 | try: 157 | main() 158 | finally: 159 | # Ensure all processes finalize 160 | dist.barrier() # Synchronize all processes 161 | 
dist.destroy_process_group() # Clean up 162 | -------------------------------------------------------------------------------- /utils/Indexer/index_trials.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import json 4 | import argparse 5 | from pathlib import Path 6 | 7 | from elasticsearch import Elasticsearch 8 | from elasticsearch.helpers import bulk 9 | 10 | 11 | def load_config(path: str) -> dict: 12 | return json.loads(Path(path).read_text()) 13 | 14 | 15 | def make_es_client(cfg: dict) -> Elasticsearch: 16 | es_conf = cfg["elasticsearch"] 17 | return Elasticsearch( 18 | hosts=es_conf["hosts"], 19 | basic_auth=(es_conf["username"], es_conf["password"]), 20 | ca_certs=es_conf["ca_certs"], 21 | verify_certs=True, 22 | ) 23 | 24 | 25 | def detect_vector_dim(sample: dict) -> int: 26 | for k, v in sample.items(): 27 | if k.endswith("_vector") and isinstance(v, list): 28 | return len(v) 29 | raise ValueError("No vector field found in sample") 30 | 31 | 32 | def load_processed(folder: Path) -> list[dict]: 33 | docs = [] 34 | for fn in os.listdir(folder): 35 | if fn.endswith(".json"): 36 | docs.append(json.loads((folder / fn).read_text())) 37 | return docs 38 | 39 | 40 | def create_index(es: Elasticsearch, name: str, dims: int): 41 | body = { 42 | "settings": { 43 | "analysis": { 44 | "analyzer": { 45 | "standard_lowercase": { 46 | "type": "custom", 47 | "tokenizer": "standard", 48 | "filter": ["lowercase"], 49 | } 50 | } 51 | } 52 | }, 53 | "mappings": { 54 | "properties": { 55 | "nct_id": {"type": "keyword"}, 56 | "brief_title": {"type": "text", "analyzer": "standard_lowercase"}, 57 | "brief_title_vector": {"type": "dense_vector", "dims": dims}, 58 | "brief_summary": {"type": "text", "analyzer": "standard_lowercase"}, 59 | "brief_summary_vector": {"type": "dense_vector", "dims": dims}, 60 | "condition": {"type": "text", "analyzer": "standard_lowercase"}, 61 | "condition_vector": {"type": "dense_vector", "dims": dims}, 62 | "overall_status": {"type": "keyword"}, 63 | "start_date": {"type": "date", "format": "yyyy-MM-dd"}, 64 | "completion_date": {"type": "date", "format": "yyyy-MM-dd"}, 65 | "phase": {"type": "keyword"}, 66 | "study_type": {"type": "keyword"}, 67 | "intervention": { 68 | "properties": { 69 | "intervention_type": {"type": "keyword"}, 70 | "intervention_name": {"type": "text"}, 71 | } 72 | }, 73 | "gender": {"type": "keyword"}, 74 | "minimum_age": {"type": "float"}, 75 | "maximum_age": {"type": "float"}, 76 | "location": { 77 | "properties": { 78 | "location_name": {"type": "text"}, 79 | "location_address": {"type": "text"}, 80 | } 81 | }, 82 | "reference": { 83 | "type": "nested", 84 | "properties": { 85 | "citation": {"type": "text"}, 86 | "PMID": {"type": "keyword"}, 87 | }, 88 | }, 89 | "eligibility_criteria": { 90 | "type": "text", 91 | "analyzer": "standard_lowercase", 92 | }, 93 | "eligibility_criteria_vector": {"type": "dense_vector", "dims": dims}, 94 | } 95 | }, 96 | } 97 | es.indices.create(index=name, body=body) 98 | print(f"Created index `{name}` with vector dims={dims}") 99 | 100 | 101 | def main(): 102 | parser = argparse.ArgumentParser(description="Bulk‑index processed trial JSONs") 103 | parser.add_argument( 104 | "--config", 105 | required=True, 106 | help="Path to JSON config file with Elasticsearch credentials", 107 | ) 108 | parser.add_argument( 109 | "--processed-folder", required=True, help="Folder of processed JSONs to index" 110 | ) 111 | parser.add_argument( 112 | 
"--index-name", 113 | default="clinical_trials", 114 | help="Target Elasticsearch index name", 115 | ) 116 | parser.add_argument( 117 | "--batch-size", type=int, default=100, help="Number of docs per bulk request" 118 | ) 119 | args = parser.parse_args() 120 | 121 | cfg = load_config(args.config) 122 | es = make_es_client(cfg) 123 | 124 | processed_path = Path(args.processed_folder) 125 | docs = load_processed(processed_path) 126 | if not docs: 127 | print("❌ No JSONs found to index.") 128 | return 129 | 130 | dims = detect_vector_dim(docs[0]) 131 | 132 | # <-- FIXED: use keyword arg `index=` 133 | if not es.indices.exists(index=args.index_name): 134 | create_index(es, args.index_name, dims) 135 | else: 136 | print(f"Index `{args.index_name}` already exists; skipping creation.") 137 | 138 | actions = [ 139 | { 140 | "_op_type": "index", 141 | "_index": args.index_name, 142 | "_id": doc["nct_id"], 143 | "_source": doc, 144 | } 145 | for doc in docs 146 | ] 147 | 148 | success, failures = bulk( 149 | client=es, 150 | actions=actions, 151 | chunk_size=args.batch_size, 152 | stats_only=True, 153 | raise_on_error=False, 154 | ) 155 | es.indices.refresh(index=args.index_name) 156 | print(f"✅ Indexed {success} documents; {failures} failures.") 157 | 158 | 159 | if __name__ == "__main__": 160 | main() 161 | -------------------------------------------------------------------------------- /utils/finetuning/finetune_instruct/evaluate_CoT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | from bert_score import score 4 | from transformers import ( 5 | AutoModelForCausalLM, 6 | AutoTokenizer, 7 | BitsAndBytesConfig, 8 | pipeline, 9 | ) 10 | from peft import PeftModel 11 | import random 12 | 13 | 14 | def load_model(device, model_path, adapter_path): 15 | """ 16 | Loads the model on the specified device using 4-bit quantization. 17 | """ 18 | print(f"Loading model on device cuda:{device}...") 19 | tokenizer = AutoTokenizer.from_pretrained(model_path) 20 | tokenizer.pad_token = tokenizer.eos_token 21 | quant_config = BitsAndBytesConfig( 22 | load_in_4bit=True, 23 | bnb_4bit_use_double_quant=True, 24 | bnb_4bit_quant_type="nf4", 25 | bnb_4bit_compute_dtype=torch.float16, 26 | ) 27 | base_model = AutoModelForCausalLM.from_pretrained( 28 | model_path, 29 | torch_dtype=torch.float16, 30 | device_map=f"cuda:{device}", 31 | attn_implementation="flash_attention_2", 32 | trust_remote_code=True, 33 | quantization_config=quant_config, 34 | ) 35 | 36 | pipe = pipeline( 37 | "text-generation", 38 | model=base_model, 39 | tokenizer=tokenizer, 40 | max_new_tokens=1024, 41 | do_sample=False, 42 | repetition_penalty=1.05, 43 | ) 44 | 45 | # Load the fine-tuned adapter model 46 | pipe.model = PeftModel.from_pretrained(base_model, adapter_path) 47 | 48 | return pipe, tokenizer 49 | 50 | 51 | def generate_output(pipe, tokenizer, instruction, input_text): 52 | """ 53 | Generates model output using the instruction as system prompt and input as user prompt. 54 | This version instructs the model to include its internal chain-of-thought exactly once, 55 | followed by the final answer. 56 | """ 57 | # Instruct the model to reveal its chain-of-thought once. 58 | system_msg = ( 59 | instruction + "\nPlease include your internal chain-of-thought exactly once, " 60 | "followed by the final answer. Do not repeat the chain-of-thought." 
61 | ) 62 | messages = [ 63 | {"role": "system", "content": system_msg}, 64 | {"role": "user", "content": input_text}, 65 | ] 66 | prompt = tokenizer.apply_chat_template( 67 | messages, tokenize=False, add_generation_prompt=True 68 | ) 69 | generated = pipe(prompt)[0]["generated_text"].strip() 70 |  71 | # Optional post-processing: If the chain-of-thought is repeated, keep only the first occurrence. 72 | # This snippet assumes the model labels its reasoning with "Chain-of-thought:" and final answer with "Final Answer:". 73 | if generated.count("Chain-of-thought:") > 1: 74 | # Split on the label and reconstruct output using only the first instance. 75 | parts = generated.split("Chain-of-thought:") 76 | first_cot = ( 77 | parts[1].split("Final Answer:")[0] 78 | if "Final Answer:" in parts[1] 79 | else parts[1] 80 | ) 81 | final_answer = "" 82 | if "Final Answer:" in generated: 83 | final_answer = "Final Answer:" + generated.split("Final Answer:")[-1] 84 | # Rebuild the output with a single chain-of-thought section. 85 | generated = parts[0] + "Chain-of-thought:" + first_cot + "\n" + final_answer 86 | generated = generated.strip() 87 |  88 | return generated 89 |  90 |  91 | def compute_bertscore( 92 | model_outputs, 93 | reference_outputs, 94 | lang="en", 95 | model_type="allenai/longformer-base-4096", 96 | ): 97 | """ 98 | Computes BERTScore for evaluating model-generated outputs against reference texts. 99 | """ 100 | assert len(model_outputs) == len(reference_outputs), ( 101 | "Mismatch in number of model and reference outputs" 102 | ) 103 |  104 | print("Computing BERTScore...") 105 | precision, recall, f1 = score( 106 | model_outputs, 107 | reference_outputs, 108 | lang=lang, 109 | model_type=model_type, 110 | device="cuda" if torch.cuda.is_available() else "cpu", 111 | ) 112 |  113 | return { 114 | "precision": precision.mean().item(), 115 | "recall": recall.mean().item(), 116 | "f1": f1.mean().item(), 117 | } 118 |  119 |  120 | def main(): 121 | device = 0 122 | model_path = "microsoft/phi-4" 123 | adapter_path = "finetuned_phi_reasoning/" 124 | pipe, tokenizer = load_model(device, model_path, adapter_path) 125 |  126 | file_path = "finetuning_data/medical_o1_reasoning_test.jsonl" 127 | model_outputs = [] 128 | reference_outputs = [] 129 |  130 | with open(file_path, "r") as f: 131 | lines = f.readlines() 132 |  133 | # Randomly select 500 test cases 134 | selected_lines = random.sample(lines, 500) 135 |  136 | for idx, line in enumerate(selected_lines, 1): 137 | print(f"Processing randomly selected line {idx}") 138 | data = json.loads(line) 139 | generated_text = generate_output( 140 | pipe, tokenizer, data["instruction"], data["input"] 141 | ) 142 | print("Generated Output:\n", generated_text, "\n") 143 | model_outputs.append(generated_text) 144 | reference_outputs.append(data["output"]) 145 |  146 | results = compute_bertscore(model_outputs, reference_outputs) 147 |  148 | print("\nBERTScore Results:") 149 | print(f"Precision: {results['precision']:.4f}") 150 | print(f"Recall: {results['recall']:.4f}") 151 | print(f"F1 Score: {results['f1']:.4f}") 152 |  153 | with open("bertscore_results.txt", "w") as f: 154 | f.write(f"Precision: {results['precision']:.4f}\n") 155 | f.write(f"Recall: {results['recall']:.4f}\n") 156 | f.write(f"F1 Score: {results['f1']:.4f}\n") 157 |  158 |  159 | if __name__ == "__main__": 160 | main() 161 | -------------------------------------------------------------------------------- /utils/Indexer/prepare_trials.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import json 4 | import os 5 | import re 6 | import warnings 7 | 8 | import dateutil.parser 9 | import torch 10 | import torch.nn.functional as F 11 | from transformers import AutoModel, AutoTokenizer 12 | 13 | warnings.filterwarnings( 14 | "ignore", category=UserWarning, message="TypedStorage is deprecated" 15 | ) 16 | 17 | 18 | class SentenceEmbedder: 19 | def __init__(self, model_name: str = "BAAI/bge-m3"): 20 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | print(f"Embedding on device: {self.device}") 22 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 23 | self.model = AutoModel.from_pretrained(model_name).to(self.device) 24 | 25 | def mean_pooling(self, model_output, attention_mask): 26 | tokens = model_output[0] 27 | mask = attention_mask.unsqueeze(-1).expand(tokens.size()).float() 28 | return torch.sum(tokens * mask, 1) / torch.clamp(mask.sum(1), min=1e-9) 29 | 30 | def get_embeddings(self, text: str): 31 | if not text: 32 | return None 33 | enc = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt") 34 | enc = {k: v.to(self.device) for k, v in enc.items()} 35 | with torch.no_grad(): 36 | out = self.model(**enc) 37 | emb = self.mean_pooling(out, enc["attention_mask"]) 38 | emb = F.normalize(emb, p=2, dim=1) 39 | return emb.squeeze().cpu().tolist() 40 | 41 | def preprocess_text(self, t: str) -> str: 42 | return re.sub(r"\s+", " ", t).strip() if t else t 43 | 44 | def to_iso(self, s: str): 45 | try: 46 | return dateutil.parser.parse(s).date().isoformat() if s else None 47 | except Exception: 48 | return None 49 | 50 | def age_to_years(self, s: str): 51 | if not s: 52 | return None 53 | m = re.search(r"([\d\.]+)", s) 54 | if not m: 55 | return None 56 | v = float(m.group(1)) 57 | u = s.lower() 58 | if "year" in u: 59 | y = v 60 | elif "month" in u: 61 | y = v / 12 62 | elif "week" in u: 63 | y = v / 52 64 | elif "day" in u: 65 | y = v / 365 66 | else: 67 | return None 68 | return round(y, 2) 69 | 70 | 71 | def embed_and_prepare(doc: dict, embedder: SentenceEmbedder): 72 | out = {"nct_id": doc["nct_id"]} 73 | # text fields → clean + vector 74 | for field, vec_name in [ 75 | ("brief_title", "brief_title_vector"), 76 | ("brief_summary", "brief_summary_vector"), 77 | ("condition", "condition_vector"), 78 | ("eligibility_criteria", "eligibility_criteria_vector"), 79 | ]: 80 | if field in doc: 81 | txt = doc[field] 82 | if isinstance(txt, list): 83 | txt = " ".join( 84 | [str(t) for t in txt if isinstance(t, str) and t.strip()] 85 | ) 86 | txt = embedder.preprocess_text(txt) 87 | emb = embedder.get_embeddings(txt) or [0.0] * len( 88 | embedder.get_embeddings("test") or [0.0] 89 | ) 90 | out[field] = txt 91 | out[vec_name] = emb 92 | 93 | # passthroughs 94 | for simple in ["overall_status", "phase", "study_type", "gender"]: 95 | if simple in doc: 96 | out[simple] = doc[simple] 97 | 98 | # dates 99 | for d in ["start_date", "completion_date"]: 100 | if d in doc: 101 | iso = embedder.to_iso(doc[d]) 102 | if iso: 103 | out[d] = iso 104 | 105 | # ages 106 | for a in ["minimum_age", "maximum_age"]: 107 | if a in doc: 108 | yrs = embedder.age_to_years(doc[a]) 109 | if yrs is not None: 110 | out[a] = yrs 111 | 112 | # nested passthroughs 113 | for nest in ["intervention", "location", "reference"]: 114 | if nest in doc: 115 | out[nest] = doc[nest] 116 | 117 | return out 118 | 119 | 120 | if __name__ == 
"__main__": 121 | p = argparse.ArgumentParser(description="Prepare & embed clinical trial JSONs") 122 | p.add_argument("--ids-file", required=True, help="One NCT ID per line") 123 | p.add_argument("--source-folder", required=True, help="Raw JSONs dir") 124 | p.add_argument( 125 | "--processed-folder", 126 | default="processed_docs", 127 | help="Where to write embedded JSONs", 128 | ) 129 | p.add_argument( 130 | "--model-name", default="BAAI/bge-m3", help="Sentence embedding model" 131 | ) 132 | args = p.parse_args() 133 | 134 | os.makedirs(args.processed_folder, exist_ok=True) 135 | embedder = SentenceEmbedder(model_name=args.model_name) 136 | 137 | with open(args.ids_file) as f: 138 | ids = [line.strip() for line in f if line.strip()] 139 | 140 | processed = 0 141 | skipped = 0 142 | 143 | for nct in ids: 144 | out_path = os.path.join(args.processed_folder, f"{nct}.json") 145 | if os.path.exists(out_path): 146 | print(f"🟡 Skipping {nct}: already processed") 147 | skipped += 1 148 | continue 149 | 150 | in_path = os.path.join(args.source_folder, f"{nct}.json") 151 | if not os.path.exists(in_path): 152 | print(f"⚠️ Missing raw JSON for {nct}") 153 | continue 154 | 155 | doc = json.load(open(in_path)) 156 | doc["nct_id"] = nct 157 | proc = embed_and_prepare(doc, embedder) 158 | 159 | with open(out_path, "w") as wf: 160 | json.dump(proc, wf, indent=2) 161 | processed += 1 162 | print(f"✅ Processed {nct}") 163 | 164 | print( 165 | f"\nSummary: {processed} processed, {skipped} skipped, {len(ids) - processed - skipped} missing." 166 | ) 167 | -------------------------------------------------------------------------------- /utils/Indexer/prepare_criteria.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import hashlib 4 | import json 5 | import logging 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from transformers import AutoModel, AutoTokenizer 11 | 12 | logging.basicConfig( 13 | format="%(asctime)s %(levelname)s %(name)s - %(message)s", 14 | level=logging.INFO, 15 | ) 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class SentenceEmbedder: 20 | def __init__(self, model_name: str = "BAAI/bge-m3", use_gpu: bool = True): 21 | self.device = torch.device( 22 | "cuda" if use_gpu and torch.cuda.is_available() else "cpu" 23 | ) 24 | logger.info(f"Loading model {model_name} on device {self.device}") 25 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 26 | self.model = AutoModel.from_pretrained(model_name).to(self.device) 27 | if self.device.type == "cuda": 28 | self.model = self.model.half() 29 | logger.info("Converted model to FP16") 30 | 31 | def mean_pool( 32 | self, token_embeds: torch.Tensor, attention_mask: torch.Tensor 33 | ) -> torch.Tensor: 34 | mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float() 35 | summed = torch.sum(token_embeds * mask, dim=1) 36 | counts = torch.clamp(mask.sum(dim=1), min=1e-9) 37 | return summed / counts 38 | 39 | def embed(self, texts: list[str]) -> list[list[float]]: 40 | """ 41 | Batch‑embed a list of strings; returns a list of float vectors. 
42 | """ 43 | enc = self.tokenizer( 44 | texts, padding=True, truncation=True, return_tensors="pt" 45 | ).to(self.device) 46 | with torch.no_grad(): 47 | outputs = self.model(**enc) 48 | vecs = self.mean_pool(outputs.last_hidden_state, enc.attention_mask) 49 | vecs = F.normalize(vecs, p=2, dim=1) 50 | return vecs.cpu().tolist() 51 | 52 | 53 | def compute_criteria_id(nct_id: str, criterion: str) -> str: 54 | """ 55 | Deterministically hash a trial‑criterion pair to a 64‑hex string. 56 | """ 57 | return hashlib.sha256(f"{nct_id}:{criterion}".encode("utf-8")).hexdigest() 58 | 59 | 60 | def load_raw_trial(path: Path) -> dict: 61 | return json.loads(path.read_text()) 62 | 63 | 64 | def process_trial( 65 | nct_id: str, 66 | source_folder: Path, 67 | processed_folder: Path, 68 | embedder: SentenceEmbedder, 69 | ) -> int: 70 | raw_path = source_folder / f"{nct_id}.json" 71 | if not raw_path.exists(): 72 | logger.warning(f"Missing raw JSON for {nct_id}, skipping.") 73 | return 0 74 | 75 | data = load_raw_trial(raw_path) 76 | criteria = data.get("criteria", []) 77 | if not criteria: 78 | logger.info(f"No criteria found for {nct_id}.") 79 | return 0 80 | 81 | # collect texts 82 | entries = [] 83 | texts = [] 84 | for crit in criteria: 85 | text = crit.get("criterion") or crit.get("sentence") 86 | if not text: 87 | continue 88 | entries.append( 89 | { 90 | "nct_id": nct_id, 91 | "criterion": text, 92 | "entities": crit.get("entities", []), 93 | "eligibility_type": crit.get("type"), 94 | } 95 | ) 96 | texts.append(text) 97 | 98 | if not entries: 99 | return 0 100 | 101 | # embed all at once 102 | vectors = embedder.embed(texts) 103 | 104 | # write out 105 | trial_folder = processed_folder / nct_id 106 | trial_folder.mkdir(parents=True, exist_ok=True) 107 | 108 | for entry, vec in zip(entries, vectors): 109 | crit_id = compute_criteria_id(entry["nct_id"], entry["criterion"]) 110 | out = { 111 | "criteria_id": crit_id, 112 | "nct_id": entry["nct_id"], 113 | "criterion": entry["criterion"], 114 | "entities": entry["entities"], 115 | "eligibility_type": entry["eligibility_type"], 116 | "criterion_vector": vec, 117 | } 118 | (trial_folder / f"{crit_id}.json").write_text(json.dumps(out, indent=2)) 119 | 120 | logger.info(f"Processed {len(entries)} criteria for {nct_id}") 121 | return len(entries) 122 | 123 | 124 | def main(): 125 | p = argparse.ArgumentParser( 126 | description="Prepare & embed eligibility criteria per trial" 127 | ) 128 | p.add_argument( 129 | "--ids-file", required=True, help="Path to nct_ids.txt (one NCT ID per line)" 130 | ) 131 | p.add_argument( 132 | "--source-folder", 133 | required=True, 134 | help="Folder containing raw trial JSONs named .json", 135 | ) 136 | p.add_argument( 137 | "--processed-folder", 138 | default="processed_criteria", 139 | help="Output root; will contain one subfolder per trial", 140 | ) 141 | p.add_argument( 142 | "--model-name", default="BAAI/bge-m3", help="Sentence embedding model name" 143 | ) 144 | p.add_argument("--use-gpu", action="store_true", help="Enable GPU iff available") 145 | args = p.parse_args() 146 | 147 | ids = [line.strip() for line in open(args.ids_file) if line.strip()] 148 | source_folder = Path(args.source_folder) 149 | processed_folder = Path(args.processed_folder) 150 | processed_folder.mkdir(parents=True, exist_ok=True) 151 | 152 | embedder = SentenceEmbedder(model_name=args.model_name, use_gpu=args.use_gpu) 153 | 154 | total = 0 155 | skipped = 0 156 | 157 | for nct in ids: 158 | trial_folder = processed_folder / nct 159 | # Skip if 
already processed (i.e. folder exists and contains at least one .json) 161 | if trial_folder.exists() and any(trial_folder.glob("*.json")): 162 | logger.info(f"Skipping {nct}: already processed") 163 | skipped += 1 164 | continue 165 | 166 | processed_count = process_trial(nct, source_folder, processed_folder, embedder) 167 | total += processed_count 168 | 169 | logger.info( 170 | f"✅ Finished embedding. Total criteria written: {total}. Trials skipped: {skipped}." 171 | ) 172 | 173 | 174 | if __name__ == "__main__": 175 | main() 176 | -------------------------------------------------------------------------------- /utils/finetuning/finetune_instruct/arguments.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional, List 4 | 5 | from transformers import TrainingArguments 6 | 7 | # Set desired GPU indices if needed 8 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,2,3,4,5,6,7" 9 | 10 | 11 | def default_list() -> List[str]: 12 | return ["q_proj", "v_proj", "o_proj", "down_proj", "up_proj", "gate_proj"] 13 | 14 | 15 | @dataclass 16 | class ModelArguments: 17 | """ 18 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 19 | """ 20 | 21 | model_name_or_path: str = field( 22 | metadata={ 23 | "help": "Path to pretrained model or model identifier from huggingface.co/models" 24 | } 25 | ) 26 | peft_model_path: str = field(default="") 27 | config_name: Optional[str] = field( 28 | default=None, 29 | metadata={ 30 | "help": "Pretrained config name or path if not the same as model_name" 31 | }, 32 | ) 33 | tokenizer_name: Optional[str] = field( 34 | default=None, 35 | metadata={ 36 | "help": "Pretrained tokenizer name or path if not the same as model_name" 37 | }, 38 | ) 39 | use_lora: bool = field( 40 | default=True, 41 | metadata={ 42 | "help": "If True, use LoRA (low-rank adaptation) for parameter-efficient training of the model." 43 | }, 44 | ) 45 | lora_rank: int = field( 46 | default=64, metadata={"help": "The rank dimension for LoRA."} 47 | ) 48 | lora_alpha: float = field( 49 | default=16, metadata={"help": "The scaling factor (alpha) for LoRA."} 50 | ) 51 | lora_dropout: float = field( 52 | default=0.05, metadata={"help": "The dropout rate for LoRA layers."} 53 | ) 54 | target_modules: List[str] = field( 55 | default_factory=default_list, 56 | metadata={"help": "List of modules to apply LoRA to."}, 57 | ) 58 | save_merged_lora_model: bool = field( 59 | default=False, 60 | metadata={ 61 | "help": "If True, merges the LoRA parameters into the base model before saving." 62 | }, 63 | ) 64 | use_flash_attn: bool = field( 65 | default=True, 66 | metadata={ 67 | "help": "If True, use flash attention during training (if supported)." 68 | }, 69 | ) 70 | use_slow_tokenizer: bool = field( 71 | default=False, 72 | metadata={ 73 | "help": "If True, use a slow (Python-based) tokenizer instead of a fast (C++/Rust) one." 74 | }, 75 | ) 76 | low_cpu_mem_usage: bool = field( 77 | default=False, 78 | metadata={ 79 | "help": "If True, create the model as an empty shell and then load weights, reducing RAM usage." 80 | }, 81 | ) 82 | cache_dir: str = field( 83 | default="tmp", 84 | metadata={ 85 | "help": "Path to the directory where models and tokenizers are cached.
86 | }, 87 | ) 88 | token: Optional[str] = field( 89 | default=None, metadata={"help": "HuggingFace hub token for private models."} 90 | ) 91 | from_peft: Optional[str] = field( 92 | default=None, 93 | metadata={"help": "Path to a PEFT checkpoint from which to load a model."}, 94 | ) 95 | lora_extra_parameters: Optional[str] = field( 96 | default=None, metadata={"help": "Additional modules to save when using LoRA."} 97 | ) 98 | 99 | 100 | @dataclass 101 | class DataArguments: 102 | train_data: str = field( 103 | default="toy_finetune_data.jsonl", 104 | metadata={"help": "Path to the training data file (in JSONL format)."}, 105 | ) 106 | 107 | query_max_len: int = field( 108 | default=32, 109 | metadata={ 110 | "help": "Max length of the input sequence for the instruction/input portion." 111 | }, 112 | ) 113 | passage_max_len: int = field( 114 | default=128, 115 | metadata={ 116 | "help": "Max length of the entire sequence (instruction + input + output)." 117 | }, 118 | ) 119 | 120 | max_example_num_per_dataset: int = field( 121 | default=10, 122 | metadata={"help": "Maximum number of examples to load from the dataset."}, 123 | ) 124 | 125 | cache_path: str = field( 126 | default="./data_dir", 127 | metadata={"help": "Directory for caching processed datasets."}, 128 | ) 129 | 130 | load_from_disk: bool = field( 131 | default=False, 132 | metadata={ 133 | "help": "If True, load a previously saved dataset from disk instead of processing from scratch." 134 | }, 135 | ) 136 | 137 | load_disk_path: Optional[str] = field( 138 | default=None, 139 | metadata={ 140 | "help": "Path to the saved dataset on disk if load_from_disk is True." 141 | }, 142 | ) 143 | 144 | save_to_disk: bool = field( 145 | default=False, metadata={"help": "If True, save the processed dataset to disk."} 146 | ) 147 | 148 | save_disk_path: Optional[str] = field( 149 | default=None, 150 | metadata={ 151 | "help": "Path to save the processed dataset if save_to_disk is True." 152 | }, 153 | ) 154 | 155 | num_shards: int = field( 156 | default=0, 157 | metadata={"help": "Number of shards to split the dataset into when saving."}, 158 | ) 159 | 160 | save_max_shard_size: str = field( 161 | default="50GB", 162 | metadata={"help": "Maximum size of each shard when saving the dataset."}, 163 | ) 164 | 165 | exit_after_save: bool = field( 166 | default=False, 167 | metadata={"help": "If True, exit the program after saving the dataset."}, 168 | ) 169 | 170 | def __post_init__(self): 171 | if not os.path.exists(self.train_data): 172 | raise FileNotFoundError( 173 | f"Cannot find file: {self.train_data}. Please provide a valid path." 174 | ) 175 | 176 | 177 | @dataclass 178 | class SFTTrainingArguments(TrainingArguments): 179 | """ 180 | Training arguments specifically for supervised fine-tuning of a causal language model. 181 | """ 182 | 183 | # Additional arguments can be added if needed.
184 | pass 185 | -------------------------------------------------------------------------------- /source/Matcher/models/llm/llm_reranker.py: -------------------------------------------------------------------------------- 1 | import re 2 | import threading 3 | import unicodedata 4 | from concurrent.futures import ThreadPoolExecutor 5 | from typing import Dict, List, Optional 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from Matcher.utils.logging_config import setup_logging 10 | from peft import PeftModel 11 | from tqdm import tqdm 12 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 13 | 14 | logger = setup_logging() 15 | 16 | 17 | class LLMReranker: 18 | def __init__( 19 | self, 20 | model_path: str, 21 | adapter_path: Optional[str] = None, 22 | device: int = 0, 23 | torch_dtype=torch.float16, 24 | batch_size: int = 8, 25 | ): 26 | self.model_path = model_path 27 | self.adapter_path = adapter_path 28 | self.torch_dtype = torch_dtype 29 | self.batch_size = batch_size 30 | # Resolve device string 31 | if torch.cuda.is_available(): 32 | cuda_count = torch.cuda.device_count() 33 | idx = int(device) if isinstance(device, int) else 0 34 | if idx < 0 or idx >= cuda_count: 35 | logger.warning( 36 | f"LLMReranker: requested CUDA device {device} invalid; using 0 (num_gpus={cuda_count})." 37 | ) 38 | idx = 0 39 | self.device_str = f"cuda:{idx}" 40 | # Ensure Accelerate/HF loaders use the selected GPU when device_map='auto' 41 | try: 42 | torch.cuda.set_device(idx) 43 | except Exception as e: 44 | logger.warning(f"Could not set CUDA device to {idx}: {e}") 45 | else: 46 | logger.warning("LLMReranker: CUDA not available; using CPU.") 47 | self.device_str = "cpu" 48 | # Left-pad batched prompts so logits[:, -1, :] is each sequence's next-token position 49 | self.tokenizer = AutoTokenizer.from_pretrained( 50 | self.model_path, trust_remote_code=True, padding_side="left" 51 | ) 52 | self._initialize_token_ids() 53 | self.model = self.load_model() 54 | self.model_lock = threading.Lock() 55 | 56 | def _initialize_token_ids(self): 57 | responses = ["Yes", "No"] 58 | token_ids = [ 59 | self.tokenizer(response, add_special_tokens=False)["input_ids"] 60 | for response in responses 61 | ] 62 | self.applicable_token_id, self.not_applicable_token_id = [ 63 | ids[0] for ids in token_ids 64 | ] 65 | 66 | def load_model(self): 67 | use_cuda = self.device_str.startswith("cuda") 68 | quant_config = ( 69 | BitsAndBytesConfig( 70 | load_in_4bit=True, 71 | bnb_4bit_use_double_quant=True, 72 | bnb_4bit_quant_type="nf4", 73 | bnb_4bit_compute_dtype=self.torch_dtype, 74 | ) 75 | if use_cuda 76 | else None 77 | ) 78 | model = AutoModelForCausalLM.from_pretrained( 79 | self.model_path, 80 | torch_dtype=self.torch_dtype if use_cuda else torch.float32, 81 | quantization_config=quant_config, 82 | device_map="auto" if use_cuda else None, 83 | attn_implementation="flash_attention_2" if use_cuda else None, 84 | trust_remote_code=True, 85 | ) 86 | if self.adapter_path: 87 | model = PeftModel.from_pretrained(model, self.adapter_path) 88 | model.eval() 89 | return model 90 | 91 | def preprocess_text(self, text: str) -> str: 92 | text = unicodedata.normalize("NFKD", text) 93 | text = re.sub(r"\s+", " ", text) 94 | return text.strip() 95 | 96 | def create_messages(self, patient_text: str, trial_text: str) -> List[Dict]: 97 | system_prompt = ( 98 | "You are a clinical assistant tasked with determining whether the patient information (Statement A) " 99 | "provides enough details to evaluate whether the patient satisfies or violates the clinical " 100 | "trial eligibility criterion
(Statement B). Respond with 'Yes' if Statement A contains sufficient " 101 | "information to make this evaluation, or 'No' if it does not." 102 | ) 103 | return [ 104 | {"role": "user", "content": system_prompt}, 105 | {"role": "assistant", "content": " "}, 106 | { 107 | "role": "user", 108 | "content": f"Statement A: {patient_text}\nStatement B: {trial_text}\n\n", 109 | }, 110 | ] 111 | 112 | def process_batch(self, batch: List[tuple]) -> List[Dict]: 113 | batch_prompts = [] 114 | for patient_text, trial_text in batch: 115 | messages = self.create_messages( 116 | self.preprocess_text(patient_text), self.preprocess_text(trial_text) 117 | ) 118 | prompt = self.tokenizer.apply_chat_template( 119 | messages, tokenize=False, add_generation_prompt=True 120 | ) 121 | batch_prompts.append(prompt) 122 | inputs = self.tokenizer(batch_prompts, return_tensors="pt", padding=True) 123 | inputs = {k: v.to(self.device_str) for k, v in inputs.items()} 124 | with self.model_lock: 125 | with torch.no_grad(): 126 | outputs = self.model(**inputs) 127 | logits = outputs.logits[:, -1, :] 128 | probabilities = F.softmax(logits, dim=-1) 129 | applicable_probs = probabilities[:, self.applicable_token_id].tolist() 130 | return [ 131 | {"llm_score": prob, "answer": "Yes" if prob > 0.5 else "No"} 132 | for prob in applicable_probs 133 | ] 134 | 135 | def rank_pairs(self, patient_trial_pairs: List[tuple]) -> List[Dict]: 136 | batches = [ 137 | patient_trial_pairs[i : i + self.batch_size] 138 | for i in range(0, len(patient_trial_pairs), self.batch_size) 139 | ] 140 | results = [] 141 | with ThreadPoolExecutor(max_workers=4) as executor: 142 | futures = [executor.submit(self.process_batch, batch) for batch in batches] 143 | # Collect results in submission order so scores stay aligned with patient_trial_pairs 144 | for future in tqdm(futures, desc="Processing batches"): 145 | results.extend(future.result()) 146 | 147 | return results 148 | -------------------------------------------------------------------------------- /utils/finetuning/finetune_instruct/data_llama.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import random 4 | from typing import List, Any, Dict, Optional 5 | from dataclasses import dataclass 6 | 7 | import datasets 8 | from torch.utils.data import Dataset 9 | from transformers import ( 10 | PreTrainedTokenizer, 11 | PreTrainedTokenizerBase, 12 | DataCollatorForSeq2Seq, 13 | ) 14 | 15 | 16 | class TrainDataset(Dataset): 17 | def __init__(self, args, tokenizer: PreTrainedTokenizer): 18 | # Load dataset (assuming args.train_data points to a directory or a JSON file) 19 | if os.path.isdir(args.train_data): 20 | train_datasets = [] 21 | for file in os.listdir(args.train_data): 22 | try: 23 | temp_dataset = datasets.load_dataset( 24 | "json", 25 | data_files=os.path.join(args.train_data, file), 26 | split="train", 27 | cache_dir=args.cache_path, 28 | ) 29 | except Exception as e: 30 | print(f"Failed to load {file}: {e}") 31 | sys.exit(1) 32 | if len(temp_dataset) > args.max_example_num_per_dataset: 33 | temp_dataset = temp_dataset.select( 34 | random.sample( 35 | range(len(temp_dataset)), args.max_example_num_per_dataset 36 | ) 37 | ) 38 | train_datasets.append(temp_dataset) 39 | 40 | self.dataset = datasets.concatenate_datasets(train_datasets) 41 | else: 42 | self.dataset = datasets.load_dataset( 43 | "json", 44 | data_files=args.train_data, 45 | split="train", 46 | cache_dir=args.cache_path, 47 | ) 48 | 49 | self.tokenizer = tokenizer 50 | self.args = args 51 | self.total_len = len(self.dataset) 52 | 53 | #
Define prompt format 54 | messages = [ 55 | { 56 | "role": "system", 57 | "content": "You are an expert trained on healthcare and biomedical domain!{instruction}", 58 | }, 59 | {"role": "user", "content": "{input}\n"}, 60 | ] 61 | self.prompt_format = self.tokenizer.apply_chat_template( 62 | messages, tokenize=False, add_generation_prompt=True 63 | ) 64 | 65 | # Maximum length for the model input 66 | self.max_length = self.args.query_max_len + self.args.passage_max_len 67 | 68 | def __len__(self): 69 | return self.total_len 70 | 71 | def __getitem__(self, index): 72 | example = self.dataset[index] 73 | instruction = example["instruction"] 74 | input_text = example.get("input", "") 75 | output_text = example["output"] 76 | 77 | # Create the prompt 78 | prompt = self.prompt_format.format( 79 | instruction=instruction.strip(), input=input_text.strip() 80 | ) 81 | 82 | # Tokenize the prompt and output together; append EOS so the model learns to stop 83 | full_input = prompt + output_text + (self.tokenizer.eos_token or "") 84 | tokenized = self.tokenizer( 85 | full_input, 86 | max_length=self.max_length, 87 | truncation=True, 88 | return_tensors=None, 89 | add_special_tokens=True, 90 | ) 91 | 92 | input_ids = tokenized["input_ids"] 93 | attention_mask = tokenized["attention_mask"] 94 | 95 | # Determine where output starts 96 | prompt_tokenized = self.tokenizer( 97 | prompt, 98 | max_length=self.max_length, 99 | truncation=True, 100 | return_tensors=None, 101 | add_special_tokens=True, 102 | ) 103 | prompt_len = len(prompt_tokenized["input_ids"]) 104 | 105 | # Create labels 106 | labels = [-100] * len(input_ids) 107 | if prompt_len < len(input_ids): 108 | labels[prompt_len:] = input_ids[prompt_len:] 109 | 110 | return { 111 | "input_ids": input_ids, 112 | "attention_mask": attention_mask, 113 | "labels": labels, 114 | } 115 | 116 | 117 | @dataclass 118 | class DataCollatorForFinetuning(DataCollatorForSeq2Seq): 119 | """ 120 | Collator that pads input_ids, attention_masks, and labels for LLaMA fine-tuning. 121 | """ 122 | 123 | query_max_len: int = 32 124 | passage_max_len: int = 128 125 | label_pad_token_id: int = -100 126 | tokenizer: PreTrainedTokenizerBase = None 127 | padding: bool = True 128 | pad_to_multiple_of: Optional[int] = None 129 | return_tensors: str = "pt" 130 | 131 | def __post_init__(self): 132 | if self.tokenizer is None: 133 | raise ValueError("Tokenizer must be provided to the DataCollator.") 134 | if self.tokenizer.pad_token_id != 128009: 135 | raise ValueError( 136 | "The tokenizer pad_token_id must be set to 128009 for LLaMA."
137 | ) 138 | 139 | def __call__( 140 | self, features: List[Dict[str, Any]], return_tensors: Optional[str] = None 141 | ) -> Dict[str, Any]: 142 | return_tensors = return_tensors or self.return_tensors 143 | 144 | # Separate labels and record their unpadded lengths 145 | labels = [feature.pop("labels") for feature in features] 146 | label_lens = [len(label) for label in labels] 147 | 148 | # Pad labels 149 | label_features = [{"input_ids": label} for label in labels] 150 | padded_labels = self.tokenizer.pad( 151 | label_features, 152 | padding="longest", 153 | return_tensors=return_tensors, 154 | pad_to_multiple_of=self.pad_to_multiple_of, 155 | ) 156 | labels_tensor = padded_labels["input_ids"] 157 | # Mask only the padded tail (right padding assumed): matching on pad_token_id would also erase genuine <|eot_id|> targets, since this collator reuses 128009 as the pad token 158 | for i, n in enumerate(label_lens): 159 | labels_tensor[i, n:] = self.label_pad_token_id 160 | 161 | # Pad input_ids and attention_mask 162 | max_length = self.query_max_len + self.passage_max_len 163 | padded_features = self.tokenizer.pad( 164 | features, 165 | padding=self.padding, 166 | max_length=max_length, 167 | pad_to_multiple_of=self.pad_to_multiple_of, 168 | return_tensors=return_tensors, 169 | ) 170 | 171 | # Add labels to the padded features 172 | padded_features["labels"] = labels_tensor 173 | return padded_features 174 | --------------------------------------------------------------------------------