├── .circleci └── config.yml ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── general-question.md └── workflows │ ├── codeql-analysis.yml │ ├── conda-tests.yml │ └── pre-commit.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CODE-OF-CONDUCT.md ├── CONTRIBUTE.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── Makefile ├── conf.py ├── index.rst ├── install.rst ├── make.bat ├── notebooks.rst ├── requirements.txt ├── tdc.base_dataset.rst ├── tdc.benchmark_group.rst ├── tdc.chem_utils.rst ├── tdc.evaluator.rst ├── tdc.generation.rst ├── tdc.metadata.rst ├── tdc.multi_pred.rst ├── tdc.oracles.rst ├── tdc.single_pred.rst └── tdc.utils.rst ├── environment.yml ├── examples ├── generation │ └── docking_generation │ │ ├── GCPN │ │ ├── LICENSE.txt │ │ ├── README.md │ │ ├── docking_iter.png │ │ ├── download.sh │ │ ├── gcpn_docking.png │ │ ├── gym-molecule │ │ │ ├── README.md │ │ │ ├── gym_molecule │ │ │ │ ├── __init__.py │ │ │ │ └── envs │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── molecule.py │ │ │ │ │ ├── molecule_combine_substructures.ipynb │ │ │ │ │ ├── molecule_evaluation.ipynb │ │ │ │ │ ├── opt.test.logP-SA │ │ │ │ │ └── sascorer.py │ │ │ └── setup.py │ │ ├── result_analysis.py │ │ ├── run_molecule.py │ │ └── upload.sh │ │ ├── MARS │ │ ├── LICENSE │ │ ├── README.md │ │ ├── __init__.py │ │ ├── common │ │ │ ├── chem.py │ │ │ ├── nn.py │ │ │ ├── train.py │ │ │ └── utils.py │ │ ├── datasets │ │ │ ├── datasets.py │ │ │ ├── prepro_vocab.py │ │ │ └── utils.py │ │ ├── docking_iter.png │ │ ├── download.sh │ │ ├── environment.yml │ │ ├── estimator │ │ │ ├── estimator.py │ │ │ ├── models.py │ │ │ └── scorer │ │ │ │ ├── chemprop_scorer.py │ │ │ │ ├── drd2_scorer.py │ │ │ │ ├── eval_scorer.py │ │ │ │ ├── kinase_scorer.py │ │ │ │ ├── sa_scorer.py │ │ │ │ ├── scorer.py │ │ │ │ └── utils.py │ │ ├── evaluate.py │ │ ├── main.py │ │ ├── mars │ │ │ ├── mars_1_100.txt │ │ │ ├── mars_1_1000.txt │ │ │ ├── mars_1_500.txt │ │ │ ├── 
mars_1_5000.txt │ │ │ ├── mars_2_100.txt │ │ │ ├── mars_2_1000.txt │ │ │ ├── mars_2_500.txt │ │ │ ├── mars_2_5000.txt │ │ │ ├── mars_3_100.txt │ │ │ ├── mars_3_1000.txt │ │ │ ├── mars_3_500.txt │ │ │ └── mars_3_5000.txt │ │ ├── preprocess_zinc.py │ │ ├── proposal │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ ├── editor.py │ │ │ │ └── editor_basic.py │ │ │ └── proposal.py │ │ ├── result │ │ │ ├── 100.txt │ │ │ ├── 1000.txt │ │ │ ├── 1050.txt │ │ │ ├── 1100.txt │ │ │ ├── 1150.txt │ │ │ ├── 1200.txt │ │ │ ├── 1250.txt │ │ │ ├── 1300.txt │ │ │ ├── 1350.txt │ │ │ ├── 1400.txt │ │ │ ├── 1450.txt │ │ │ ├── 150.txt │ │ │ ├── 1500.txt │ │ │ ├── 1550.txt │ │ │ ├── 1600.txt │ │ │ ├── 1650.txt │ │ │ ├── 1700.txt │ │ │ ├── 1750.txt │ │ │ ├── 1800.txt │ │ │ ├── 1850.txt │ │ │ ├── 1900.txt │ │ │ ├── 1950.txt │ │ │ ├── 200.txt │ │ │ ├── 2000.txt │ │ │ ├── 2050.txt │ │ │ ├── 2100.txt │ │ │ ├── 2150.txt │ │ │ ├── 2200.txt │ │ │ ├── 2250.txt │ │ │ ├── 2300.txt │ │ │ ├── 2350.txt │ │ │ ├── 2400.txt │ │ │ ├── 2450.txt │ │ │ ├── 250.txt │ │ │ ├── 2500.txt │ │ │ ├── 2550.txt │ │ │ ├── 2600.txt │ │ │ ├── 2650.txt │ │ │ ├── 2700.txt │ │ │ ├── 2750.txt │ │ │ ├── 2800.txt │ │ │ ├── 2850.txt │ │ │ ├── 2900.txt │ │ │ ├── 2950.txt │ │ │ ├── 300.txt │ │ │ ├── 3000.txt │ │ │ ├── 3050.txt │ │ │ ├── 3100.txt │ │ │ ├── 3150.txt │ │ │ ├── 3200.txt │ │ │ ├── 3250.txt │ │ │ ├── 3300.txt │ │ │ ├── 3350.txt │ │ │ ├── 3400.txt │ │ │ ├── 3450.txt │ │ │ ├── 350.txt │ │ │ ├── 3500.txt │ │ │ ├── 3550.txt │ │ │ ├── 3600.txt │ │ │ ├── 3650.txt │ │ │ ├── 3700.txt │ │ │ ├── 3750.txt │ │ │ ├── 3800.txt │ │ │ ├── 3850.txt │ │ │ ├── 3900.txt │ │ │ ├── 3950.txt │ │ │ ├── 400.txt │ │ │ ├── 4000.txt │ │ │ ├── 4050.txt │ │ │ ├── 4100.txt │ │ │ ├── 4150.txt │ │ │ ├── 4200.txt │ │ │ ├── 4250.txt │ │ │ ├── 4300.txt │ │ │ ├── 4350.txt │ │ │ ├── 4400.txt │ │ │ ├── 4450.txt │ │ │ ├── 450.txt │ │ │ ├── 4500.txt │ │ │ ├── 4550.txt │ │ │ ├── 4600.txt │ │ │ ├── 4650.txt │ │ │ ├── 4700.txt │ │ │ ├── 4750.txt │ │ │ ├── 4800.txt 
│ │ │ ├── 4850.txt │ │ │ ├── 4900.txt │ │ │ ├── 4950.txt │ │ │ ├── 500.txt │ │ │ ├── 5000.txt │ │ │ ├── 5050.txt │ │ │ ├── 5100.txt │ │ │ ├── 5150.txt │ │ │ ├── 5200.txt │ │ │ ├── 5250.txt │ │ │ ├── 5300.txt │ │ │ ├── 5350.txt │ │ │ ├── 5400.txt │ │ │ ├── 5450.txt │ │ │ ├── 550.txt │ │ │ ├── 5500.txt │ │ │ ├── 5550.txt │ │ │ ├── 5600.txt │ │ │ ├── 5650.txt │ │ │ ├── 5700.txt │ │ │ ├── 5750.txt │ │ │ ├── 5800.txt │ │ │ ├── 5850.txt │ │ │ ├── 5900.txt │ │ │ ├── 5950.txt │ │ │ ├── 600.txt │ │ │ ├── 6000.txt │ │ │ ├── 6050.txt │ │ │ ├── 6100.txt │ │ │ ├── 6150.txt │ │ │ ├── 6200.txt │ │ │ ├── 6250.txt │ │ │ ├── 6300.txt │ │ │ ├── 6350.txt │ │ │ ├── 6400.txt │ │ │ ├── 6450.txt │ │ │ ├── 650.txt │ │ │ ├── 700.txt │ │ │ ├── 750.txt │ │ │ ├── 800.txt │ │ │ ├── 850.txt │ │ │ ├── 900.txt │ │ │ └── 950.txt │ │ ├── result_analysis.py │ │ ├── sampler.py │ │ └── upload.sh │ │ ├── guacamol_tdc │ │ ├── README.md │ │ ├── download.sh │ │ ├── fix_99_to_100.py │ │ ├── guacamol │ │ │ ├── .flake8 │ │ │ ├── .gitignore │ │ │ ├── .travis.yml │ │ │ ├── LICENSE │ │ │ ├── MANIFEST.in │ │ │ ├── README.md │ │ │ ├── dockers │ │ │ │ └── Dockerfile │ │ │ ├── guacamol │ │ │ │ ├── __init__.py │ │ │ │ ├── assess_distribution_learning.py │ │ │ │ ├── assess_goal_directed_generation.py │ │ │ │ ├── benchmark_suites.py │ │ │ │ ├── common_scoring_functions.py │ │ │ │ ├── distribution_learning_benchmark.py │ │ │ │ ├── distribution_matching_generator.py │ │ │ │ ├── frechet_benchmark.py │ │ │ │ ├── goal_directed_benchmark.py │ │ │ │ ├── goal_directed_generator.py │ │ │ │ ├── goal_directed_score_contributions.py │ │ │ │ ├── py.typed │ │ │ │ ├── score_modifier.py │ │ │ │ ├── scoring_function.py │ │ │ │ ├── standard_benchmarks.py │ │ │ │ └── utils │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── chemistry.py │ │ │ │ │ ├── data.py │ │ │ │ │ ├── descriptors.py │ │ │ │ │ ├── fingerprints.py │ │ │ │ │ ├── helpers.py │ │ │ │ │ ├── math.py │ │ │ │ │ └── sampling_helpers.py │ │ │ ├── mypy.ini │ │ │ ├── requirements.txt │ │ │ 
├── setup.py │ │ │ └── upload.sh │ │ ├── guacamol_baselines │ │ │ ├── .flake8 │ │ │ ├── .gitignore │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── best_from_chembl │ │ │ │ ├── chembl_file_reader.py │ │ │ │ ├── goal_directed_generation.py │ │ │ │ └── optimizer.py │ │ │ ├── dockers │ │ │ │ ├── Dockerfile │ │ │ │ └── requirements.txt │ │ │ ├── fetch_guacamol_dataset.sh │ │ │ ├── graph_ga │ │ │ │ ├── crossover.py │ │ │ │ ├── goal_directed_generation.py │ │ │ │ ├── graph_ga_run.py │ │ │ │ └── mutate.py │ │ │ ├── graph_mcts │ │ │ │ ├── analyze_dataset.py │ │ │ │ ├── distribution_learning.py │ │ │ │ ├── goal_directed_generation.py │ │ │ │ ├── p1.p │ │ │ │ ├── p_ring.p │ │ │ │ ├── r_s1.p │ │ │ │ ├── rs_make_ring.p │ │ │ │ ├── rs_ring.p │ │ │ │ ├── size_stats.p │ │ │ │ └── stats.py │ │ │ ├── moses_baselines │ │ │ │ ├── README.md │ │ │ │ ├── aae_distribution_learning.py │ │ │ │ ├── aae_train.py │ │ │ │ ├── common.py │ │ │ │ ├── organ_distribution_learning.py │ │ │ │ ├── organ_train.py │ │ │ │ ├── vae_distribution_learning.py │ │ │ │ └── vae_train.py │ │ │ ├── random_smiles_sampler │ │ │ │ ├── distribution_learning.py │ │ │ │ ├── generator.py │ │ │ │ ├── goal_directed_generation.py │ │ │ │ └── optimizer.py │ │ │ ├── requirements.txt │ │ │ ├── smiles_ga │ │ │ │ ├── cfg_util.py │ │ │ │ ├── goal_directed_generation.py │ │ │ │ └── smiles_grammar.py │ │ │ ├── smiles_lstm_hc │ │ │ │ ├── action_sampler.py │ │ │ │ ├── distribution_learning.py │ │ │ │ ├── goal_directed_generation.py │ │ │ │ ├── rnn_generator.py │ │ │ │ ├── rnn_model.py │ │ │ │ ├── rnn_sampler.py │ │ │ │ ├── rnn_trainer.py │ │ │ │ ├── rnn_utils.py │ │ │ │ ├── run_smiles_lstm_hc.py │ │ │ │ ├── smiles_char_dict.py │ │ │ │ ├── smiles_rnn_directed_generator.py │ │ │ │ ├── smiles_rnn_distribution_learner.py │ │ │ │ ├── smiles_rnn_generator.py │ │ │ │ └── train_smiles_lstm_model.py │ │ │ ├── smiles_lstm_ppo │ │ │ │ ├── action_replay.py │ │ │ │ ├── goal_directed_generation.py │ │ │ │ ├── molecule_batch.py │ │ │ │ ├── 
ppo_directed_generator.py │ │ │ │ ├── ppo_generator.py │ │ │ │ ├── ppo_trainer.py │ │ │ │ ├── rnn_model.py │ │ │ │ └── running_reward.py │ │ │ └── upload.sh │ │ ├── results0 │ │ │ ├── graph_ga │ │ │ │ └── goal_directed_params.json │ │ │ ├── random_smiles_sampler │ │ │ │ └── goal_directed_results.json │ │ │ ├── smiles_ga │ │ │ │ └── goal_directed_params.json │ │ │ └── smiles_lstm_ppo │ │ │ │ ├── goal_directed_params.json │ │ │ │ └── pretrained_model │ │ │ │ └── model_final_0.473.json │ │ ├── results_limit_oracle │ │ │ ├── best_from_chembl.json │ │ │ ├── graph_ga.json │ │ │ ├── random_smiles_sampler.json │ │ │ └── smiles_ga.json │ │ └── upload.sh │ │ └── moldqn │ │ ├── README.md │ │ ├── chemgraph │ │ ├── __init__.py │ │ ├── all_800_mols.json │ │ ├── chemutil.py │ │ ├── configs │ │ │ ├── bootstrap_dqn.json │ │ │ ├── bootstrap_dqn_opt_800.json │ │ │ ├── bootstrap_dqn_step1.json │ │ │ ├── bootstrap_dqn_step2.json │ │ │ ├── multi_obj_dqn.json │ │ │ ├── naive_dqn.json │ │ │ ├── naive_dqn_opt_800.json │ │ │ ├── qed_logp_jnk_gsk.json │ │ │ └── target_sas.json │ │ ├── denovo.py │ │ ├── docking_eval.py │ │ ├── docking_iter.png │ │ ├── docking_smilesvaluelst_1.pkl │ │ ├── docking_smilesvaluelst_10.pkl │ │ ├── docking_smilesvaluelst_11.pkl │ │ ├── docking_smilesvaluelst_12.pkl │ │ ├── docking_smilesvaluelst_3.pkl │ │ ├── docking_smilesvaluelst_9.pkl │ │ ├── download.sh │ │ ├── dqn │ │ │ ├── __init__.py │ │ │ ├── deep_q_networks.py │ │ │ ├── molecules.py │ │ │ ├── py │ │ │ │ ├── SA_Score │ │ │ │ │ ├── README │ │ │ │ │ ├── UnitTestSAScore.py │ │ │ │ │ └── sascorer.py │ │ │ │ ├── __init__.py │ │ │ │ └── molecules.py │ │ │ ├── run_dqn.py │ │ │ └── tensorflow_core │ │ │ │ ├── __init__.py │ │ │ │ └── core.py │ │ ├── logp_similar_train.py │ │ ├── multi_obj_opt.py │ │ ├── optimize_docking.py │ │ ├── optimize_logp.py │ │ ├── optimize_logp_of_800_molecules.py │ │ ├── optimize_qed.py │ │ ├── qed_similar_train.py │ │ ├── result │ │ │ ├── ckpt-104000 │ │ │ ├── ckpt-112000 │ │ │ ├── 
ckpt-120000 │ │ │ ├── ckpt-128000 │ │ │ ├── ckpt-136000 │ │ │ ├── ckpt-144000 │ │ │ ├── ckpt-152000 │ │ │ ├── ckpt-160000 │ │ │ ├── ckpt-168000 │ │ │ ├── ckpt-176000 │ │ │ ├── ckpt-184000 │ │ │ ├── ckpt-192000 │ │ │ ├── ckpt-200000 │ │ │ ├── ckpt-40000 │ │ │ ├── ckpt-48000 │ │ │ ├── ckpt-56000 │ │ │ ├── ckpt-64000 │ │ │ ├── ckpt-72000 │ │ │ ├── ckpt-80000 │ │ │ ├── ckpt-88000 │ │ │ └── ckpt-96000 │ │ ├── result_analysis.py │ │ ├── result_analysis0.py │ │ ├── target_sas.py │ │ ├── target_sas_eval.ipynb │ │ └── try.py │ │ ├── experimental │ │ ├── deep_q_networks_noise.py │ │ ├── eval_800_mols.py │ │ ├── max_qed_with_sim.py │ │ ├── multi_obj.py │ │ ├── multi_obj_gen.py │ │ ├── multi_obj_opt.py │ │ ├── optimize_800_mols.py │ │ ├── optimize_logp.py │ │ ├── optimize_qed.py │ │ ├── optimize_qed_final_reward.py │ │ ├── optimize_qed_max_steps.py │ │ ├── optimize_qed_noise.py │ │ ├── optimize_qed_t.py │ │ ├── optimize_weight_noise.py │ │ └── target_logp.py │ │ ├── mol_dqn.yml │ │ ├── plot │ │ ├── plot.py │ │ └── target_sas_results.csv │ │ ├── requirements.txt │ │ ├── requirements2.txt │ │ └── upload.sh ├── huggingface_examples │ └── herg │ │ ├── .gitignore │ │ ├── environment.yml │ │ ├── hERG_Karim_raytune_HF.ipynb │ │ ├── hERG_inhib.ipynb │ │ ├── hERG_inhib.py │ │ └── hyperopt.py ├── multi_pred │ ├── drugcombo │ │ ├── README.md │ │ ├── model_classes.py │ │ └── train_MLP.py │ ├── dti_dg │ │ ├── README.md │ │ ├── domainbed │ │ │ ├── __init__.py │ │ │ ├── algorithms.py │ │ │ ├── datasets.py │ │ │ ├── hparams_registry.py │ │ │ ├── lib │ │ │ │ ├── fast_data_loader.py │ │ │ │ ├── misc.py │ │ │ │ ├── query.py │ │ │ │ └── reporting.py │ │ │ ├── model_selection.py │ │ │ └── networks.py │ │ └── train.py │ └── geneperturb │ │ ├── prepare_benchmark_dataset.py │ │ └── run_gears.py └── single_pred │ └── admet │ ├── README.md │ └── run.py ├── fig ├── TDCneurips.pptx(1).png ├── logo.png ├── tdc_overview.png └── tdc_problems.png ├── requirements.txt ├── run_tests.py ├── setup.cfg ├── 
setup.py ├── tdc ├── __init__.py ├── base_dataset.py ├── benchmark_deprecated.py ├── benchmark_group │ ├── __init__.py │ ├── admet_group.py │ ├── base_group.py │ ├── counterfactual_group.py │ ├── docking_group.py │ ├── drugcombo_group.py │ ├── dti_dg_group.py │ ├── geneperturb_group.py │ ├── protein_peptide_group.py │ ├── scdti_group.py │ └── tcrepitope_group.py ├── chem_utils │ ├── __init__.py │ ├── evaluator.py │ ├── featurize │ │ ├── __init__.py │ │ ├── _smartsPatts.py │ │ ├── _smiles2pubchem.py │ │ ├── _xyz2mol.py │ │ └── molconvert.py │ └── oracle │ │ ├── __init__.py │ │ ├── docking.py │ │ ├── filter.py │ │ └── oracle.py ├── dataset_configs │ ├── __init__.py │ ├── brown_mdm2_ace2_12ca5_config.py │ ├── cellxgene_config.py │ ├── config.py │ ├── config_map.py │ ├── opentargets_dti.py │ └── scperturb_config.py ├── evaluator.py ├── feature_generators │ ├── __init__.py │ ├── anndata_to_dataframe.py │ ├── base.py │ ├── cellxgene_generator.py │ ├── data_feature_generator.py │ ├── protein_feature_generator.py │ └── resource.py ├── generation │ ├── __init__.py │ ├── bi_generation_dataset.py │ ├── generation_dataset.py │ ├── ligandmolgen.py │ ├── molgen.py │ ├── reaction.py │ ├── retrosyn.py │ └── sbdd.py ├── metadata.py ├── model_server │ ├── __init__.py │ ├── model_loaders │ │ └── scvi_loader.py │ ├── models │ │ ├── __init__.py │ │ ├── scgpt.py │ │ └── scvi.py │ ├── tdc_hf.py │ └── tokenizers │ │ ├── __init__.py │ │ ├── geneformer.py │ │ └── scgpt.py ├── multi_pred │ ├── __init__.py │ ├── anndata_dataset.py │ ├── antibodyaff.py │ ├── bi_pred_dataset.py │ ├── catalyst.py │ ├── ddi.py │ ├── drugres.py │ ├── drugsyn.py │ ├── dti.py │ ├── gda.py │ ├── mti.py │ ├── multi_pred_dataset.py │ ├── peptidemhc.py │ ├── perturboutcome.py │ ├── ppi.py │ ├── proteinpeptide.py │ ├── single_cell.py │ ├── tcr_epi.py │ ├── test_multi_pred.py │ └── trialoutcome.py ├── oracles.py ├── resource │ ├── __init__.py │ ├── cellxgene_census.py │ ├── dataloader.py │ ├── pharmone.py │ ├── 
pinnacle.py │ └── primekg.py ├── single_pred │ ├── __init__.py │ ├── adme.py │ ├── crispr_outcome.py │ ├── develop.py │ ├── epitope.py │ ├── hts.py │ ├── mpc.py │ ├── paratope.py │ ├── qm.py │ ├── single_pred_dataset.py │ ├── test_single_pred.py │ ├── tox.py │ └── yields.py ├── test │ ├── __init__.py │ ├── dev_tests │ │ ├── __init__.py │ │ ├── chem_utils_test │ │ │ ├── __init__.py │ │ │ ├── test_molconverter.py │ │ │ ├── test_molfilter.py │ │ │ └── test_oracles.py │ │ └── utils_tests │ │ │ ├── __init__.py │ │ │ ├── test_misc_utils.py │ │ │ └── test_splits.py │ ├── test_benchmark.py │ ├── test_data_process.py │ ├── test_dataloaders.py │ ├── test_functions.py │ ├── test_hf.py │ ├── test_model_server.py │ ├── test_oracles.py │ └── test_resources.py ├── utils │ ├── __init__.py │ ├── knowledge_graph.py │ ├── label.py │ ├── label_name_list.py │ ├── load.py │ ├── misc.py │ ├── query.py │ ├── retrieve.py │ └── split.py └── version.py └── tutorials ├── DGL_User_Group_Demo.ipynb ├── TDC-HuggingFace Interface Demo.ipynb ├── TDC_101_Data_Loader.ipynb ├── TDC_102_Data_Functions.ipynb ├── TDC_103.1_Datasets_Small_Molecules.ipynb ├── TDC_103.2_Datasets_Biologics.ipynb ├── TDC_104_ML_Model_DeepPurpose.ipynb ├── TDC_105_Oracle.ipynb ├── TDC_106_BenchmarkGroup_Submission_Demo.ipynb ├── User_Group ├── 3pbl_ligand.pdb ├── 3pbl_ligand.xyz ├── 3pbl_receptor.pdb ├── UserGroupMeeting_Tianfan.ipynb ├── UserGroupMeeting_Wenhao.ipynb ├── docking.png ├── docking_gflownet.png ├── ga_illustration.pdf ├── ga_illustration.png ├── generation_process.png ├── leaderboard.png ├── leaderboard_generative.png ├── opt_drd3_5000.json ├── oracle.png ├── tdc_problems.png ├── vina.png └── why_docking.png └── graphein_demo_developability.ipynb /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Python CircleCI 2.0 configuration file 2 | # 3 | # Check https://circleci.com/docs/2.0/language-python/ for more details 4 | # Adapted from 
https://github.com/NeuralEnsemble/python-neo 5 | version: 2 6 | workflows: 7 | version: 2 8 | test: 9 | jobs: 10 | - test-3.9 11 | jobs: 12 | test-3.9: 13 | docker: 14 | - image: circleci/python:3.10 15 | 16 | working_directory: ~/repo 17 | 18 | steps: 19 | - checkout 20 | - run: sudo chown -R circleci:circleci /usr/local/bin 21 | 22 | # Download and cache dependencies 23 | - restore_cache: 24 | keys: 25 | - v1-py3-dependencies-{{ checksum "requirements.txt" }} 26 | # fallback to using the latest cache if no exact match is found 27 | - v1-py3-dependencies- 28 | 29 | - run: 30 | name: Install git-lfs 31 | command: | 32 | sudo apt-get install git-lfs 33 | git lfs install 34 | 35 | - run: 36 | name: install dependencies 37 | command: | 38 | python -m venv venv 39 | . venv/bin/activate 40 | pip install --upgrade pip 41 | pip install -r requirements.txt 42 | pip install pytest 43 | pip install pytest-cov 44 | 45 | 46 | - save_cache: 47 | paths: 48 | - ./venv 49 | key: v1-py3-dependencies-{{ checksum "requirements.txt" }} 50 | 51 | 52 | # run tests! 53 | - run: 54 | name: run tests 55 | no_output_timeout: 30m 56 | command: | 57 | . venv/bin/activate 58 | pytest --ignore=tdc/test/dev_tests/ --ignore=tdc/test/test_resources.py --ignore=tdc/test/test_dataloaders.py --ignore=tdc/test/test_model_server.py --ignore=tdc/test/test_data_process.py 59 | 60 | - store_artifacts: 61 | path: test-reports 62 | destination: test-reports 63 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve TDC 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. 
Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Environment:** 27 | - OS: 28 | - Python version: 29 | - TDC version: 30 | - Any other relevant information: 31 | 32 | **Additional context** 33 | Add any other context about the problem here. 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for TDC 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the problem** 11 | A clear and concise description of what the problem is. 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. Ideally, with pseudo-code. 15 | 16 | **Additional context** 17 | Add any other context or screenshots about the feature request here. 
18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/general-question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: General question 3 | about: General Q&A about TDC 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe your question.** 11 | 12 | 13 | -------------------------------------------------------------------------------- /.github/workflows/conda-tests.yml: -------------------------------------------------------------------------------- 1 | # Docs for the Azure Web Apps Deploy action: https://github.com/Azure/webapps-deploy 2 | # More GitHub Actions for Azure: https://github.com/Azure/actions 3 | # More info on Python, GitHub Actions, and Azure App Service: https://aka.ms/python-webapps-actions 4 | 5 | name: Build Conda Environment And Run Python Tests 6 | 7 | on: 8 | push: 9 | branches: 10 | - main 11 | - '*' 12 | pull_request: 13 | branches: [ "main" ] 14 | workflow_dispatch: 15 | 16 | jobs: 17 | build: 18 | runs-on: ubuntu-latest 19 | 20 | defaults: 21 | run: 22 | shell: bash -l {0} 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | 27 | - name: Set up Python version 28 | uses: actions/setup-python@v1 29 | with: 30 | python-version: '3.10' 31 | 32 | - name: Install git-lfs 33 | run: | 34 | sudo apt-get install git-lfs 35 | git lfs install 36 | 37 | - name: Setup Miniconda 38 | uses: conda-incubator/setup-miniconda@v2 39 | with: 40 | miniconda-version: "latest" 41 | channels: bioconda, conda-forge, defaults, dgl 42 | use-only-tar-bz2: true 43 | auto-update-conda: true 44 | auto-activate-base: true 45 | 46 | - name: Create and start Conda environment. 
Run tests 47 | run: | 48 | echo "Creating Conda Environment from environment.yml" 49 | conda env create -q -f environment.yml 50 | conda activate tdc-conda-env 51 | python run_tests.py 52 | conda deactivate 53 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | branches: ["main"] 6 | push: {} 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | with: 15 | python-version: '3.10' 16 | - uses: pre-commit/action@v3.0.1 17 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks.git 6 | rev: v4.0.1 7 | hooks: 8 | - id: trailing-whitespace 9 | - id: end-of-file-fixer 10 | 11 | - repo: https://github.com/google/yapf 12 | rev: v0.43.0 13 | hooks: 14 | - id: yapf 15 | name: "yapf" 16 | args: [--style=google, --recursive, --in-place] 17 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/conf.py 5 | 6 | python: 7 | version: 3.7 8 | install: 9 | - requirements: docs/requirements.txt 10 | - requirements: requirements.txt 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2020 TDC Team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software 
and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune tdc/test 2 | prune tests 3 | prune tutorials 4 | include README.md 5 | include requirements.txt 6 | include LICENSE 7 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/install.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | It is recommended to use **pip** for installation. 5 | 6 | .. code-block:: bash 7 | 8 | pip install PyTDC # normal install 9 | pip install PyTDC --upgrade # or update if needed 10 | 11 | Alternatively, you could use **conda** for installation: 12 | 13 | .. code-block:: bash 14 | 15 | conda install -c conda-forge pytdc 16 | 17 | 18 | **Core Required Dependencies**\ : 19 | 20 | * numpy 21 | * pandas 22 | * tqdm 23 | * seaborn 24 | * scikit_learn 25 | * fuzzywuzzy 26 | 27 | To use some of TDC features, you need to install additional packages. Those packages will be automatically installed. Detailed installation instructions will be displayed in the terminal in case automatic installation fails. 28 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | furo 2 | rdkit-pypi 3 | networkx 4 | setuptools 5 | DeepPurpose 6 | rxn4chemistry 7 | pandas_flavor 8 | descriptastorus @ git+https://github.com/bp-kelley/descriptastorus.git@2.4.0 9 | rd_filters @ git+https://github.com/PatWalters/rd_filters.git 10 | -------------------------------------------------------------------------------- /docs/tdc.base_dataset.rst: -------------------------------------------------------------------------------- 1 | tdc.base\_dataset 2 | ======================= 3 | 4 | .. automodule:: tdc.base_dataset 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tdc.benchmark_group.rst: -------------------------------------------------------------------------------- 1 | tdc.benchmark\_group 2 | ============================ 3 | 4 | tdc.benchmark\_group.base\_group module 5 | --------------------------------------- 6 | 7 | .. automodule:: tdc.benchmark_group.base_group 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | tdc.benchmark\_group.admet\_group module 13 | ---------------------------------------- 14 | 15 | .. 
automodule:: tdc.benchmark_group.admet_group 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | tdc.benchmark\_group.docking\_group module 21 | ------------------------------------------ 22 | 23 | .. automodule:: tdc.benchmark_group.docking_group 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | tdc.benchmark\_group.drugcombo\_group module 29 | -------------------------------------------- 30 | 31 | .. automodule:: tdc.benchmark_group.drugcombo_group 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | tdc.benchmark\_group.dti\_dg\_group module 37 | ------------------------------------------ 38 | 39 | .. automodule:: tdc.benchmark_group.dti_dg_group 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | -------------------------------------------------------------------------------- /docs/tdc.chem_utils.rst: -------------------------------------------------------------------------------- 1 | tdc.chem\_utils 2 | ======================= 3 | 4 | 5 | tdc.chem\_utils.featurize module 6 | -------------------------------- 7 | 8 | tdc.chem\_utils.featurize.molconvert submodule 9 | ^^^^^^^^^ 10 | 11 | .. automodule:: tdc.chem_utils.featurize.molconvert 12 | :members: 13 | :undoc-members: 14 | :show-inheritance: 15 | 16 | tdc.chem\_utils.oracle module 17 | -------------------------------- 18 | 19 | tdc.chem\_utils.oracle.filter submodule 20 | ^^^^^^^^^ 21 | 22 | .. automodule:: tdc.chem_utils.oracle.filter 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | tdc.chem\_utils.oracle.oracle submodule 28 | ^^^^^^^^^ 29 | 30 | .. automodule:: tdc.chem_utils.oracle.oracle 31 | :members: 32 | :undoc-members: 33 | :show-inheritance: 34 | 35 | tdc.chem\_utils.evaluator module 36 | -------------------------------- 37 | 38 | .. 
automodule:: tdc.chem_utils.evaluator 39 | :members: 40 | :undoc-members: 41 | :show-inheritance: 42 | -------------------------------------------------------------------------------- /docs/tdc.evaluator.rst: -------------------------------------------------------------------------------- 1 | tdc.evaluator 2 | ======================= 3 | 4 | .. automodule:: tdc.evaluator 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tdc.generation.rst: -------------------------------------------------------------------------------- 1 | tdc.generation 2 | ====================== 3 | 4 | tdc.generation.generation\_dataset module 5 | ----------------------------------------- 6 | 7 | .. automodule:: tdc.generation.generation_dataset 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | tdc.generation.molgen module 13 | ---------------------------- 14 | 15 | .. automodule:: tdc.generation.molgen 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | tdc.generation.reaction module 21 | ------------------------------ 22 | 23 | .. automodule:: tdc.generation.reaction 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | tdc.generation.retrosyn module 29 | ------------------------------ 30 | 31 | .. automodule:: tdc.generation.retrosyn 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | -------------------------------------------------------------------------------- /docs/tdc.metadata.rst: -------------------------------------------------------------------------------- 1 | tdc.metadata 2 | ======================= 3 | 4 | .. 
automodule:: tdc.metadata 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tdc.oracles.rst: -------------------------------------------------------------------------------- 1 | tdc.oracles 2 | ======================= 3 | 4 | .. automodule:: tdc.oracles 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tdc.single_pred.rst: -------------------------------------------------------------------------------- 1 | tdc.single\_pred 2 | ======================== 3 | 4 | tdc.single\_pred.single\_pred\_dataset module 5 | --------------------------------------------- 6 | 7 | .. automodule:: tdc.single_pred.single_pred_dataset 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | tdc.single\_pred.adme module 13 | ---------------------------- 14 | 15 | .. automodule:: tdc.single_pred.adme 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | tdc.single\_pred.crispr\_outcome module 21 | --------------------------------------- 22 | 23 | .. automodule:: tdc.single_pred.crispr_outcome 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | tdc.single\_pred.develop module 29 | ------------------------------- 30 | 31 | .. automodule:: tdc.single_pred.develop 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | tdc.single\_pred.epitope module 37 | ------------------------------- 38 | 39 | .. automodule:: tdc.single_pred.epitope 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | tdc.single\_pred.hts module 45 | --------------------------- 46 | 47 | .. automodule:: tdc.single_pred.hts 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | tdc.single\_pred.paratope module 53 | -------------------------------- 54 | 55 | .. 
automodule:: tdc.single_pred.paratope 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | tdc.single\_pred.qm module 61 | -------------------------- 62 | 63 | .. automodule:: tdc.single_pred.qm 64 | :members: 65 | :undoc-members: 66 | :show-inheritance: 67 | 68 | 69 | tdc.single\_pred.test\_single\_pred module 70 | ------------------------------------------ 71 | 72 | .. automodule:: tdc.single_pred.test_single_pred 73 | :members: 74 | :undoc-members: 75 | :show-inheritance: 76 | 77 | tdc.single\_pred.tox module 78 | --------------------------- 79 | 80 | .. automodule:: tdc.single_pred.tox 81 | :members: 82 | :undoc-members: 83 | :show-inheritance: 84 | 85 | tdc.single\_pred.yields module 86 | ------------------------------ 87 | 88 | .. automodule:: tdc.single_pred.yields 89 | :members: 90 | :undoc-members: 91 | :show-inheritance: 92 | -------------------------------------------------------------------------------- /docs/tdc.utils.rst: -------------------------------------------------------------------------------- 1 | tdc.utils 2 | ================= 3 | 4 | tdc.utils.label module 5 | ---------------------- 6 | 7 | .. automodule:: tdc.utils.label 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | tdc.utils.label\_name\_list module 13 | ---------------------------------- 14 | 15 | .. automodule:: tdc.utils.label_name_list 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | tdc.utils.load module 21 | --------------------- 22 | 23 | .. automodule:: tdc.utils.load 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | tdc.utils.misc module 29 | --------------------- 30 | 31 | .. automodule:: tdc.utils.misc 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | tdc.utils.query module 37 | ---------------------- 38 | 39 | .. automodule:: tdc.utils.query 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | tdc.utils.retrieve module 45 | ------------------------- 46 | 47 | .. 
automodule:: tdc.utils.retrieve 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | tdc.utils.split module 53 | ---------------------- 54 | 55 | .. automodule:: tdc.utils.split 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: tdc-conda-env 2 | channels: 3 | - bioconda 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - anndata=0.6.22 8 | - biopython=1.78 9 | - dataclasses=0.8 10 | - fuzzywuzzy=0.18.0 11 | - huggingface_hub=0.20.3 12 | - mygene=3.2.2 13 | - numpy=1.26.4 14 | - openpyxl=3.0.10 15 | - python=3.10 16 | - pip=23.3.1 17 | - pandas=2.1.4 18 | - requests=2.31.0 19 | - scikit-learn=1.2.2 20 | - seaborn=0.12.2 21 | - tqdm=4.65.0 22 | - pip: 23 | - accelerate==0.33.0 24 | - cellxgene-census==1.15.0 25 | - datasets<2.20.0 26 | - dgl==1.1.3 27 | - evaluate==0.4.2 28 | - gget==0.28.4 29 | - moleculeace==3.0.0 30 | - pydantic==2.6.3 31 | - pytest==8.3.2 32 | - rdkit==2023.9.5 33 | - scvi-tools==1.2.0 34 | - tiledbsoma==1.11.4 35 | - torch==2.1.1 36 | - torch_geometric==2.5.3 37 | - torchvision==0.16.1 38 | - transformers==4.43.4 39 | - yapf==0.40.2 40 | 41 | variables: 42 | KMP_DUPLICATE_LIB_OK: "TRUE" 43 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/GCPN/LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD-3-Clause 2 | 3 | Copyright 2019 Jiaxuan You, Bowen Liu 4 | 5 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 6 | 7 | 1.
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 8 | 9 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 10 | 11 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
14 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/GCPN/README.md: -------------------------------------------------------------------------------- 1 | # GCPN docking for TDC leaderboard 2 | 3 | 4 | ## conda 5 | 6 | ```bash 7 | source activate gcpn 8 | ``` 9 | 10 | ## Installation 11 | - Install rdkit, please refer to the offical website for further details, using anaconda is recommended: 12 | ```bash 13 | conda create -c rdkit -n my-rdkit-env rdkit 14 | ``` 15 | - Install mpi4py, networkx: 16 | ```bash 17 | conda install mpi4py 18 | pip install networkx=1.11 19 | ``` 20 | - Install OpenAI baseline dependencies: 21 | ```bash 22 | cd rl-baselines 23 | pip install -e . 24 | ``` 25 | - Install customized molecule gym environment: 26 | ```bash 27 | cd gym-molecule 28 | pip install -e . 29 | ``` 30 | 31 | 32 | ## Code description 33 | There are 4 important files: 34 | - `run_molecule.py` is the main code for running the program. You may tune all kinds of hyper-parameters there. 35 | - The molecule environment code is in `gym-molecule/gym_molecule/envs/molecule.py`. 36 | - RL related code is in `rl-baselines/baselines/ppo1` folder: `gcn_policy.py` is the GCN policy network; `pposgd_simple_gcn.py` is the PPO algorithm specifically tuned for GCN policy. 
37 | 38 | ## Run docking 39 | 40 | ```bash 41 | source activate gcpn 42 | 43 | cd /project/molecular_data/graphnn/GCPN 44 | 45 | rm -rf ckpt result oracle_call_cnt 46 | 47 | mkdir ckpt result 48 | 49 | export PATH=$PATH:/project/molecular_data/graphnn/mol_dqn_docking/package_install/ADFRsuite_x86_64Linux_1.0/bin 50 | 51 | export PATH=$PATH:/project/molecular_data/graphnn/mol_dqn_docking/package_install/autodock_vina_1_1_2_linux_x86/bin 52 | 53 | python run_molecule.py 54 | ``` 55 | 56 | ## Run 57 | - single process run 58 | ```bash 59 | python run_molecule.py 60 | ``` 61 | - mutiple processes run 62 | ```bash 63 | mpirun -np 8 python run_molecule.py 2>/dev/null 64 | ``` 65 | `2>/dev/null` will hide the warning info provided by rdkit package. 66 | 67 | We highly recommend using tensorboard to monitor the training process. To do this, you may run 68 | ```bash 69 | tensorboard --logdir runs 70 | ``` 71 | 72 | All the generated molecules along the training process are stored in the `molecule_gen` folder, each run configuration is stored in a different csv file. Molecules are stored using SMILES strings, along with the desired properties scores. 73 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/GCPN/docking_iter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/GCPN/docking_iter.png -------------------------------------------------------------------------------- /examples/generation/docking_generation/GCPN/download.sh: -------------------------------------------------------------------------------- 1 | 2 | scp -r tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/GCPN/result.1.done . 3 | 4 | scp -r tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/GCPN/result.2.done . 
5 | 6 | scp -r tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/GCPN/result.3.done . 7 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/GCPN/gcpn_docking.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/GCPN/gcpn_docking.png -------------------------------------------------------------------------------- /examples/generation/docking_generation/GCPN/gym-molecule/README.md: -------------------------------------------------------------------------------- 1 | # Molecule generation environment, compatible with openai gym 2 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/GCPN/gym-molecule/gym_molecule/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id="molecule-v0", 5 | entry_point="gym_molecule.envs:MoleculeEnv", 6 | ) 7 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/GCPN/gym-molecule/gym_molecule/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from gym_molecule.envs.molecule import MoleculeEnv 2 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/GCPN/gym-molecule/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name="gym-molecule", 4 | version="0.0.1", 5 | install_requires=["gym>=0.2.3", "pandas"]) 6 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/GCPN/upload.sh: 
-------------------------------------------------------------------------------- 1 | 2 | scp -r $1 tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/GCPN/ 3 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/MARS/README.md: -------------------------------------------------------------------------------- 1 | # MARS: Markov Molecular Sampling for Multi-objective Drug Discovery 2 | 3 | Thanks for your interest! This is the code repository for our ICLR 2021 paper [MARS: Markov Molecular Sampling for Multi-objective Drug Discovery](https://openreview.net/pdf?id=kHSu4ebxFXY). 4 | 5 | ## Dependencies 6 | 7 | The `conda` environment is exported as `environment.yml`. You can also manually install these packages: 8 | 9 | ```bash 10 | conda install -c conda-forge rdkit 11 | conda install tqdm tensorboard scikit-learn 12 | conda install pytorch cudatoolkit=11.1 -c pytorch -c conda-forge 13 | conda install -c dglteam dgl-cuda11.1 14 | 15 | # for cpu only 16 | conda install pytorch cpuonly -c pytorch 17 | conda install -c dglteam dgl 18 | ``` 19 | 20 | ## Run 21 | 22 | > Note: Run the commands **outside** the `MARS` directory. 
23 | 24 | To extract molecular fragments from a database: 25 | 26 | 27 | ```bash 28 | python preprocess_zinc.py 29 | ``` 30 | output is `data/zinc.txt` 31 | 32 | ```bash 33 | python -m MARS.datasets.prepro_vocab 34 | ``` 35 | 36 | 37 | 38 | To sample molecules: 39 | 40 | ```bash 41 | cd /project/molecular_data/graphnn/MARS 42 | 43 | source activate MARS 44 | 45 | rm -rf MARS/runs/try; 46 | 47 | python -m MARS.main --run_dir runs/try 48 | ``` 49 | 50 | ## docking 51 | 52 | ```bash 53 | cd /project/molecular_data/graphnn/MARS 54 | 55 | source activate MARS 56 | 57 | export PATH=$PATH:/project/molecular_data/graphnn/docking/ADFRsuite_installed_directory/bin 58 | 59 | export PATH=$PATH:/project/molecular_data/graphnn/docking/autodock_vina_1_1_2_linux_x86/bin 60 | 61 | 62 | ./upload.sh main.py 63 | ./upload.sh sampler.py 64 | ./upload.sh estimator 65 | ``` 66 | 67 | ``` 68 | rm -rf MARS/runs/try4; python -m MARS.main --run_dir runs/try4 69 | ``` 70 | 71 | 72 | 73 | ```python 74 | import pyscreener 75 | from tdc import Oracle 76 | oracle2 = Oracle(name = 'Docking_Score', software='vina', pyscreener_path = './', pdbids=['5WIU'], center=(-18.2, 14.4, -16.1), size=(15.4, 13.9, 14.5), buffer=10, path='./', num_worker=1, ncpu=4) 77 | ``` 78 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/MARS/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/MARS/__init__.py -------------------------------------------------------------------------------- /examples/generation/docking_generation/MARS/docking_iter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/MARS/docking_iter.png 
import torch
import torch.nn.functional as F
from torch import nn
from dgl.nn.pytorch.glob import Set2Set

from ..common.nn import GraphEncoder, MLP


class Discriminator(nn.Module):
    """Graph-level binary classifier used as an adversarial discriminator.

    Encodes a batched DGL graph with ``GraphEncoder``, pools the node states
    with Set2Set, and classifies the pooled representation with a small MLP
    into 2 logits (fake/real).
    """

    def __init__(self, config):
        super().__init__()
        self.device = config["device"]
        self.encoder = GraphEncoder(
            config["n_atom_feat"],
            config["n_node_hidden"],
            config["n_bond_feat"],
            config["n_edge_hidden"],
            config["n_layers"],
        )
        self.set2set = Set2Set(config["n_node_hidden"], n_iters=6, n_layers=2)
        # Set2Set doubles the feature dimension, hence the * 2.
        self.classifier = MLP(config["n_node_hidden"] * 2, 2)

    def forward(self, g):
        # Input transfer only — no gradients are needed to move the raw
        # features onto the target device.
        with torch.no_grad():
            g = g.to(self.device)
            x_node = g.ndata["n_feat"].to(self.device)
            x_edge = g.edata["e_feat"].to(self.device)

        h = self.encoder(g, x_node, x_edge)
        h = self.set2set(g, h)
        h = self.classifier(h)
        return h

    def loss(self, batch, metrics=("loss",)):
        """Compute cross-entropy loss and the requested metric values.

        @params:
            batch: batch from the dataset
                g: batched dgl.DGLGraph
                targs: prediction targets
            metrics: metric names to return; any of
                "loss", "acc", "rec", "prec", "f1", "tp"
        @returns:
            g.batch_size
            metric_values: cared metric values for
                training and recording
        """
        g, targs = batch
        targs = targs.to(self.device)
        logits = self(g)  # (batch_size, 2)
        loss = F.cross_entropy(logits, targs)
        with torch.no_grad():
            pred = logits.argmax(dim=1)
            true = pred == targs
            acc = true.float().sum() / g.batch_size
            tp = (true * targs).float().sum()
            # + 1e-6 guards against division by zero on all-negative batches.
            rec = tp / (targs.long().sum() + 1e-6)
            prec = tp / (pred.long().sum() + 1e-6)
            f1 = 2 * rec * prec / (rec + prec + 1e-6)

        # Explicit lookup table instead of the original locals() indirection:
        # the supported metric names are now visible and greppable.
        # NOTE: original default was the mutable list ["loss"]; a tuple avoids
        # the shared-mutable-default pitfall and stays iterable for callers.
        available = {
            "loss": loss,
            "acc": acc,
            "rec": rec,
            "prec": prec,
            "f1": f1,
            "tp": tp,
        }
        metric_values = [available[metric] for metric in metrics]
        return g.batch_size, metric_values
input_file = "data/zinc.tab"
output_file = "data/zinc.txt"


def convert(src=input_file, dst=output_file):
    """Convert the ZINC .tab export into a plain SMILES-per-line file.

    Skips the header row and strips the surrounding quote characters from
    each data row. Streams line by line instead of loading the whole file
    into memory.

    Args:
        src: path of the input .tab file (one quoted SMILES per line after
            a header row).
        dst: path of the plain-text output file.
    """
    with open(src, "r") as fin, open(dst, "w") as fout:
        next(fin, None)  # drop the header row; tolerate an empty file
        for line in fin:
            # Each data row is a quoted SMILES string; drop the quotes.
            fout.write(line.strip()[1:-1] + "\n")


if __name__ == "__main__":
    convert()
CCC(O)c1cccc(Cl)c1 6.3 5 | CCNC(=O)c1ccc(Cl)cc1 6.2 6 | CCCOc1ccc(Cl)cc1Cl 6.0 7 | CCC=C(C)CCC=C(C)C 6.0 8 | CCc1ccc(C(C)=O)cc1 6.0 9 | CCc1c(F)cccc1Cl 5.6 10 | CCC(=O)NCc1ccco1 5.6 11 | CCOc1ccc(Cl)cc1 5.5 12 | CCNCCCN1CCCC1=O 5.4 13 | CCC1CCN(C)CC1 5.0 14 | CCc1ccc(Br)s1 4.7 15 | CCC1SC(=O)NC1=O 4.7 16 | CCCN(CCC)CCC 4.6 17 | CCNC(=O)CC(C)C 4.6 18 | CCOc1cccnc1 4.6 19 | CCOC(=O)C(C)(C)C 4.4 20 | CCC(=O)C1CC1 4.3 21 | CCc1ccoc1 4.3 22 | CCC1CCCN1C 4.3 23 | CCC(=O)NC(C)C 4.2 24 | CCOCCN(CC)CC 4.1 25 | CCOC(=O)CSCC 4.0 26 | CCCS(=O)(=O)O 3.8 27 | CCOP(=O)(O)O 3.7 28 | C=CCSCC 3.5 29 | CCF 2.3 30 | CC 1.8 31 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/MARS/result/150.txt: -------------------------------------------------------------------------------- 1 | CC(c1ccoc1)c1cccc(F)c1 7.1 2 | CCc1ccccc1C(F)(F)F 6.5 3 | CCOP(=O)(O)Oc1ccc(C(=N)N)cc1 6.5 4 | CCOc1ccc(C(=N)N)cc1 6.3 5 | CCC(=O)Nc1cccc(C)c1 6.3 6 | CCC(O)c1cccc(Cl)c1 6.3 7 | CCC1CCN(C)C(Nc2cc(C(C)(C)C)on2)C1 6.3 8 | CCNC(=O)c1ccc(Cl)cc1 6.2 9 | C=CCSC(C)c1c[nH]c2ncccc12 6.2 10 | CCNCC(CN1CCCC1=O)N(CCO)CCO 6.2 11 | CCc1cc(F)c(F)cc1F 6.1 12 | CCNc1ccc2[nH]ncc2c1 6.1 13 | CCCOc1ccc(Cl)cc1Cl 6.0 14 | CCC=C(C)CCC=C(C)C 6.0 15 | CCc1ccc(C(C)=O)cc1 6.0 16 | CCNCc1ccccc1F 5.9 17 | CCOC(=O)C(C)(C)Cc1ccc(OC)cc1O 5.9 18 | CCOc1cccnc1OCC(O)CNC(C)C 5.9 19 | CCc1ccc(OC(C)=O)cc1 5.8 20 | CCCOc1ccccc1Cl 5.7 21 | N=C(N)c1ccccc1 5.7 22 | CCc1c(F)cccc1Cl 5.6 23 | CCC(=O)NCc1ccco1 5.6 24 | CC=CCCC(C)=CCC 5.6 25 | CCS(=O)(=O)c1ccccc1 5.6 26 | CCC1NC2CCC1C2 5.6 27 | CCOc1ccc(Cl)cc1 5.5 28 | CCNCCCN1CCCC1=O 5.4 29 | CCNc1cccc(O)c1 5.4 30 | CCNc1cccc(C)n1 5.3 31 | CC(F)C[n+]1ccccc1 5.1 32 | CCC1CCN(C)CC1 5.0 33 | CCc1ccc(Br)s1 4.7 34 | CCC1SC(=O)NC1=O 4.7 35 | CCC(=O)NCCCOC 4.7 36 | Clc1ccccc1 4.7 37 | CCCN(CCC)CCC 4.6 38 | CCNC(=O)CC(C)C 4.6 39 | CCOc1cccnc1 4.6 40 | CCC1CCCN1 4.5 41 | CCOC(=O)C(C)(C)C 4.4 42 | CCC(=O)C1CC1 4.3 43 | CCc1ccoc1 4.3 44 | 
CCC1CCCN1C 4.3 45 | CCC(=O)NCCO 4.3 46 | CCOC(=O)CSC(C)[O-] 4.3 47 | CCC(=O)NC(C)C 4.2 48 | CCOCCN(CC)CC 4.1 49 | CCCN(C)CCC 4.1 50 | CCOC(=O)CSCC 4.0 51 | CCCS(=O)(=O)O 3.8 52 | CC(=O)NC(C)C 3.8 53 | CCNCC(=O)O 3.8 54 | CCOP(=O)(O)O 3.7 55 | C=CCSCC 3.5 56 | CCF 2.3 57 | CC 1.8 58 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/MARS/result/200.txt: -------------------------------------------------------------------------------- 1 | CCOCCN(CC)CCN1CCN(S(C)(=O)=O)CC1 8.7 2 | CCNc1ccc2[nH]nc(Nc3ccccc3OCC)c2c1 8.4 3 | N=C(N)c1ccccc1-c1ccc(Cl)c(Cl)c1 8.0 4 | CCC1(c2ccc(NC(C)=O)cc2)NC2CCC1C2 7.6 5 | CC(NCc1ccccc1)Oc1ccc(Cl)cc1 7.3 6 | CC(c1ccoc1)c1cccc(F)c1 7.1 7 | O=C(CCc1cc2ccccc2[nH]1)NCCO 7.0 8 | C=CC(SCC(N)=O)SC(C)c1c[nH]c2ncccc12 6.9 9 | CCOP(=O)(O)Oc1ccc(C(=N)N)cc1C(O)CC(O)CC(=O)O 6.6 10 | CCc1ccccc1C(F)(F)F 6.5 11 | CCOP(=O)(O)Oc1ccc(C(=N)N)cc1 6.5 12 | CCNC(=O)c1ccc(Cl)cc1[O-] 6.4 13 | CCOc1ccc(C(=N)N)cc1 6.3 14 | CCC(=O)Nc1cccc(C)c1 6.3 15 | CCC(O)c1cccc(Cl)c1 6.3 16 | CCC1CCN(C)C(Nc2cc(C(C)(C)C)on2)C1 6.3 17 | COCCCNC(=O)C(C)Oc1ccc(Cl)c(Cl)c1 6.3 18 | CCNC(=O)c1ccc(Cl)cc1 6.2 19 | C=CCSC(C)c1c[nH]c2ncccc12 6.2 20 | CCNCC(CN1CCCC1=O)N(CCO)CCO 6.2 21 | CCc1cc(F)c(F)cc1F 6.1 22 | CCNc1ccc2[nH]ncc2c1 6.1 23 | C=CCOc1ccc(NCC)cc1O 6.1 24 | CCCOc1ccc(Cl)cc1Cl 6.0 25 | CCC=C(C)CCC=C(C)C 6.0 26 | CCc1ccc(C(C)=O)cc1 6.0 27 | CCC(O)CCS(=O)(=O)c1ccccc1 6.0 28 | CCNCc1ccccc1F 5.9 29 | CCOC(=O)C(C)(C)Cc1ccc(OC)cc1O 5.9 30 | CCOc1cccnc1OCC(O)CNC(C)C 5.9 31 | CC(=O)N(c1ccccc1O)C(C)C 5.9 32 | CCNC(=O)C1CCCCC1 5.9 33 | CCc1ccc(OC(C)=O)cc1 5.8 34 | CCCOc1ccccc1Cl 5.7 35 | N=C(N)c1ccccc1 5.7 36 | CCNCCC1=CCCCC1 5.7 37 | CC=CC(CC(C)=CCC)C(=O)N(CCC)CCC 5.7 38 | CCc1c(F)cccc1Cl 5.6 39 | CCC(=O)NCc1ccco1 5.6 40 | CC=CCCC(C)=CCC 5.6 41 | CCS(=O)(=O)c1ccccc1 5.6 42 | CCC1NC2CCC1C2 5.6 43 | CCOc1ccc(Cl)cc1 5.5 44 | Fc1ccc(F)c(F)c1 5.5 45 | CCNCCCN1CCCC1=O 5.4 46 | CCNc1cccc(O)c1 5.4 47 | CCNc1cccc(C)n1 5.3 48 | 
CC(F)C[n+]1ccccc1 5.1 49 | CCC1CCN(C)CC1 5.0 50 | CCc1ccc(Br)s1 4.7 51 | CCC1SC(=O)NC1=O 4.7 52 | CCC(=O)NCCCOC 4.7 53 | Clc1ccccc1 4.7 54 | CCCN(CCC)CCC 4.6 55 | CCNC(=O)CC(C)C 4.6 56 | CCOc1cccnc1 4.6 57 | CCC1CCCN1 4.5 58 | CCOC(=O)C(C)(C)C 4.4 59 | CCC(=O)C1CC1 4.3 60 | CCc1ccoc1 4.3 61 | CCC1CCCN1C 4.3 62 | CCC(=O)NCCO 4.3 63 | CCOC(=O)CSC(C)[O-] 4.3 64 | CCC(=O)NC(C)C 4.2 65 | CCOCCN(CC)CC 4.1 66 | CCCN(C)CCC 4.1 67 | CCOC(=O)CSCC 4.0 68 | CCCS(=O)(=O)O 3.8 69 | CC(=O)NC(C)C 3.8 70 | CCNCC(=O)O 3.8 71 | CC(=O)C1CC1 3.8 72 | CCC1CCN(C)C(N(C(=O)N2CC(C)OC(C)C2)c2cc(C(C)(C)C)on2)C1 3.8 73 | CCOP(=O)(O)O 3.7 74 | C=CCSCC 3.5 75 | CCNO 3.0 76 | CS(=O)(=O)O 2.9 77 | CCC 2.5 78 | CCN 2.4 79 | CCF 2.3 80 | CC 1.8 81 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/MARS/upload.sh: -------------------------------------------------------------------------------- 1 | 2 | scp -r $1 tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/MARS/MARS 3 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/download.sh: -------------------------------------------------------------------------------- 1 | 2 | # rm -r results/smiles_lstm_hc_1 results/smiles_lstm_hc_2 results/smiles_lstm_hc_3 c 3 | # scp -r tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/pyscreener/smiles_lstm_hc/results.run.1 ./results/smiles_lstm_hc_1 4 | # scp -r tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/pyscreener/smiles_lstm_hc/results.run.2 ./results/smiles_lstm_hc_2 5 | # scp -r tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/pyscreener/smiles_lstm_hc/results.run.3 ./results/smiles_lstm_hc_3 6 | 7 | rm -r results/graph_ga_1 results/graph_ga_2 results/graph_ga_3 8 | scp -r tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/pyscreener/graph_ga/results ./results/graph_ga_1 9 | scp -r 
from random import shuffle


def _load_scores(path, randomize=False):
    """Parse 'SMILES<ws>score' lines into an insertion-ordered dict.

    Args:
        path: file with one whitespace-separated (SMILES, score) pair per line.
        randomize: shuffle the lines first, so iteration order — and hence
            which candidate is picked by the caller — is random.
    """
    with open(path, "r") as fin:
        lines = fin.readlines()
    if randomize:
        shuffle(lines)
    return {line.split()[0]: float(line.split()[1]) for line in lines}


def top_up(file_100, file_200):
    """Add one random molecule from *file_200* that is missing from
    *file_100*, then rewrite *file_100* sorted ascending by score.

    If every molecule of file_200 is already present, file_100 is simply
    rewritten (sorted) unchanged — same as the original script.
    """
    scores = _load_scores(file_100)
    candidates = _load_scores(file_200, randomize=True)
    for smi, val in candidates.items():
        if smi not in scores:
            scores[smi] = val
            break
    with open(file_100, "w") as fout:
        for smi, val in sorted(scores.items(), key=lambda kv: kv[1]):
            fout.write(smi + "\t" + str(val) + "\n")


def main(prefix="results/graph_ga_", runs=range(1, 4)):
    """Top the *_100.txt result files up to 100 molecules from *_200.txt."""
    for i in runs:
        top_up(prefix + str(i) + "_100.txt", prefix + str(i) + "_200.txt")


if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol/.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | python: 4 | # Actually not needed since rdkit installation with conda overwrites it 5 | - "3.6" 6 | 7 | before_install: 8 | # download and install miniconda 9 | - wget http://repo.continuum.io/miniconda/Miniconda3-4.1.11-Linux-x86_64.sh -O miniconda.sh; 10 | - bash miniconda.sh -b -p $HOME/conda 11 | - export PATH="$HOME/conda/bin:$PATH" 12 | - conda config --set always_yes yes --set changeps1 no 13 | - conda update -q conda 14 | 15 | - conda create -n test_env python=$TRAVIS_PYTHON_VERSION pip cmake 16 | - source activate test_env 17 | 18 | install: 19 | # install the most recent rdkit package from the RDKit anaconda channel. 20 | - conda install -q -c rdkit rdkit 21 | - pip install -r requirements.txt 22 | 23 | script: 24 | # Style guide enforcement 25 | - flake8 guacamol && flake8 tests 26 | # Static typing enforcement 27 | - mypy guacamol && mypy tests 28 | # Test suite 29 | - python -m pytest tests 30 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 BenevolentAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission 
notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol/MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Include the file holdout SMILES strings when the guacamol package is generated 2 | include guacamol/data/holdout_set_gcm_v1.smiles 3 | # Marker file to say that guacamol supports type checking 4 | include guacamol/py.typed 5 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol/dockers/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | 3 | RUN apt-get update && apt-get install -y --no-install-recommends \ 4 | build-essential \ 5 | cmake ca-certificates \ 6 | libglib2.0-0 libxext6 libsm6 libxrender1 \ 7 | wget \ 8 | curl \ 9 | bash \ 10 | bzip2 \ 11 | && \ 12 | apt-get clean && \ 13 | rm -rf /var/lib/apt/lists/* 14 | 15 | # MiniConda 16 | RUN curl -LO --silent https://repo.continuum.io/miniconda/Miniconda3-4.5.11-Linux-x86_64.sh && \ 17 | bash Miniconda3-4.5.11-Linux-x86_64.sh -p /miniconda -b && \ 18 | rm Miniconda3-4.5.11-Linux-x86_64.sh 19 | 20 | ENV PATH=/miniconda/bin:${PATH} 21 | 22 | # RDKit 23 | RUN conda install -y -q -c rdkit rdkit=2018.09.1.0 24 | 25 | # python deps 26 | RUN pip 
from abc import ABCMeta, abstractmethod
from typing import List


class DistributionMatchingGenerator(metaclass=ABCMeta):
    """Abstract interface for distribution-matching molecule generators."""

    @abstractmethod
    def generate(self, number_samples: int) -> List[str]:
        """Draw *number_samples* SMILES strings from the generator.

        Args:
            number_samples: how many molecules to sample

        Returns:
            A list of SMILES strings.
        """
from rdkit import Chem
from rdkit.Chem import Descriptors, Mol, rdMolDescriptors


def logP(mol: Mol) -> float:
    """Wildman-Crippen logP as computed by RDKit."""
    return Descriptors.MolLogP(mol)


def qed(mol: Mol) -> float:
    """Quantitative estimate of drug-likeness (QED)."""
    return Descriptors.qed(mol)


def tpsa(mol: Mol) -> float:
    """Topological polar surface area."""
    return Descriptors.TPSA(mol)


def bertz(mol: Mol) -> float:
    """Bertz complexity index."""
    return Descriptors.BertzCT(mol)


def mol_weight(mol: Mol) -> float:
    """Molecular weight."""
    return Descriptors.MolWt(mol)


def num_H_donors(mol: Mol) -> int:
    """Number of hydrogen-bond donors."""
    return Descriptors.NumHDonors(mol)


def num_H_acceptors(mol: Mol) -> int:
    """Number of hydrogen-bond acceptors."""
    return Descriptors.NumHAcceptors(mol)


def num_rotatable_bonds(mol: Mol) -> int:
    """Number of rotatable bonds."""
    return Descriptors.NumRotatableBonds(mol)


def num_rings(mol: Mol) -> int:
    """Total number of rings."""
    return rdMolDescriptors.CalcNumRings(mol)


def num_aromatic_rings(mol: Mol) -> int:
    """Number of aromatic rings."""
    return rdMolDescriptors.CalcNumAromaticRings(mol)


def num_atoms(mol: Mol) -> int:
    """
    Return the total number of atoms, hydrogens included.
    """
    return Chem.AddHs(mol).GetNumAtoms()


class AtomCounter:
    """Callable that counts atoms of one fixed element within a molecule."""

    def __init__(self, element: str) -> None:
        """
        Args:
            element: element symbol to count within a molecule
        """
        self.element = element

    def __call__(self, mol: Mol) -> int:
        """
        Count the number of atoms of the configured element.

        Args:
            mol: molecule

        Returns:
            The number of atoms of the given type.
        """
        # Hydrogens are usually implicit in RDKit molecules, so make them
        # explicit before counting when the target element is H.
        if self.element == "H":
            mol = Chem.AddHs(mol)

        # Booleans sum as 0/1, so this is an exact count.
        return sum(atom.GetSymbol() == self.element for atom in mol.GetAtoms())
12 | """ 13 | 14 | def get_fingerprint(self, mol: Mol, fp_type: str): 15 | method_name = "get_" + fp_type 16 | method = getattr(self, method_name) 17 | if method is None: 18 | raise Exception(f"{fp_type} is not a supported fingerprint type.") 19 | return method(mol) 20 | 21 | def get_AP(self, mol: Mol): 22 | return AllChem.GetAtomPairFingerprint(mol, maxLength=10) 23 | 24 | def get_PHCO(self, mol: Mol): 25 | return Generate.Gen2DFingerprint(mol, Gobbi_Pharm2D.factory) 26 | 27 | def get_BPF(self, mol: Mol): 28 | return GetBPFingerprint(mol) 29 | 30 | def get_BTF(self, mol: Mol): 31 | return GetBTFingerprint(mol) 32 | 33 | def get_PATH(self, mol: Mol): 34 | return AllChem.RDKFingerprint(mol) 35 | 36 | def get_ECFP4(self, mol: Mol): 37 | return AllChem.GetMorganFingerprint(mol, 2) 38 | 39 | def get_ECFP6(self, mol: Mol): 40 | return AllChem.GetMorganFingerprint(mol, 3) 41 | 42 | def get_FCFP4(self, mol: Mol): 43 | return AllChem.GetMorganFingerprint(mol, 2, useFeatures=True) 44 | 45 | def get_FCFP6(self, mol: Mol): 46 | return AllChem.GetMorganFingerprint(mol, 3, useFeatures=True) 47 | 48 | 49 | def get_fingerprint(mol: Mol, fp_type: str): 50 | return _FingerprintCalculator().get_fingerprint(mol=mol, fp_type=fp_type) 51 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol/guacamol/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_default_logger(): 5 | """ 6 | Call this function in your main function to initialize a basic logger. 7 | 8 | To have more control on the format or level, call `logging.basicConfig()` directly instead. 9 | 10 | If you don't initialize any logger, log entries from the guacamol package will not appear anywhere. 
from typing import List

import numpy as np


def arithmetic_mean(values: List[float]) -> float:
    """
    Compute the arithmetic mean of a list of values.
    """
    total = sum(values)
    return total / len(values)


def geometric_mean(values: List[float]) -> float:
    """
    Compute the geometric mean of a list of values.
    """
    arr = np.asarray(values)
    # nth root of the product of the n values.
    return arr.prod() ** (1.0 / arr.size)
import io
import re
from os import path

from setuptools import setup

# Extract the package version from guacamol/__init__.py without importing it.
# Adapted from https://stackoverflow.com/a/39671214
_init_source = io.open("guacamol/__init__.py", encoding="utf_8_sig").read()
__version__ = re.search(r'__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
                        _init_source).group(1)

# Use the README as the long description shown on PyPI.
this_directory = path.abspath(path.dirname(__file__))
with open(path.join(this_directory, "README.md"), encoding="utf-8") as f:
    long_description = f.read()

setup(
    name="guacamol",
    version=__version__,
    author="BenevolentAI",
    author_email="guacamol@benevolent.ai",
    description="Guacamol: benchmarks for de novo molecular design",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/BenevolentAI/guacamol",
    packages=["guacamol", "guacamol.data", "guacamol.utils"],
    license="MIT",
    install_requires=[
        "joblib>=0.12.5",
        "numpy>=1.15.2",
        "scipy>=1.1.0",
        "tqdm>=4.26.0",
        "FCD==1.1",
        # FCD doesn't pin the tensorflow and Keras dependencies, so we have to do the honours
        "tensorflow==1.15.4",
        "Keras==2.1.0",
        "h5py==2.10.0",
    ],
    python_requires=">=3.6",
    extras_require={
        "rdkit": ["rdkit>=2018.09.1.0"],
    },
    include_package_data=True,
    zip_safe=False,
)
-------------------------------------------------------------------------------- 1 | 2 | scp -r $1 tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/guacamol_tdc/guacamol 3 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501, E731 3 | exclude = .git,__pycache__ 4 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 BenevolentAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
class ChemblFileReader:
    """
    Re-iterable source of SMILES strings read from a text file.

    Unlike a file handle, each call to ``iter()`` opens the file afresh,
    so the content can be traversed any number of times.
    """

    def __init__(self, smiles_file_path: str):
        """
        Args:
            smiles_file_path: Path of a file containing a list of SMILES strings.
        """
        self.smiles_file_path = smiles_file_path

    def __iter__(self):
        # Open lazily on each iteration so the reader stays re-iterable.
        with open(self.smiles_file_path) as handle:
            yield from (line.strip() for line in handle)
import heapq
from typing import List, Optional, Tuple

import joblib
from guacamol.goal_directed_generator import GoalDirectedGenerator
from guacamol.scoring_function import ScoringFunction
from joblib import delayed

from guacamol.scoring_function import max_oracle_num

from .chembl_file_reader import ChemblFileReader


class BestFromChemblOptimizer(GoalDirectedGenerator):
    """
    Goal-directed molecule generator that simply looks for the most adequate
    molecules present in a reference SMILES file.
    """

    def __init__(self, smiles_reader: ChemblFileReader, n_jobs: int) -> None:
        """
        Args:
            smiles_reader: re-iterable source of reference SMILES strings
            n_jobs: number of parallel scoring jobs (-1 means all cores)
        """
        self.pool = joblib.Parallel(n_jobs=n_jobs)
        # Materialize the reference set once.
        self.smiles = list(smiles_reader)

        # Limit oracle calls: score only a random subset of the reference set.
        from random import shuffle

        shuffle(self.smiles)
        self.smiles = self.smiles[:max_oracle_num]

    def top_k(self, smiles, scoring_function, k):
        """Score all SMILES in parallel and return the k highest-scoring ones."""
        joblist = (delayed(scoring_function.score)(s) for s in smiles)
        scores = self.pool(joblist)
        scored_smiles = sorted(zip(scores, smiles),
                               key=lambda pair: pair[0],
                               reverse=True)
        # Slice before unpacking so we never build a full second list.
        return [smile for _score, smile in scored_smiles[:k]]

    def generate_optimized_molecules(
        self,
        scoring_function: ScoringFunction,
        number_molecules: int,
        starting_population: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Iterate through the reference set of SMILES strings and select the
        best molecules.

        Args:
            scoring_function: objective to maximize
            number_molecules: number of molecules to return
            starting_population: ignored by this optimizer

        Returns:
            The ``number_molecules`` highest-scoring reference SMILES strings.
        """
        return self.top_k(self.smiles, scoring_function, number_molecules)
ADD . /app 34 | 35 | # Launch inside the folder 36 | WORKDIR /app/ 37 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/dockers/requirements.txt: -------------------------------------------------------------------------------- 1 | guacamol==0.5.3 2 | matplotlib==3.0.2 3 | torch>=1.0.0 4 | joblib==0.12.5 5 | numpy==1.15.2 6 | tqdm==4.26.0 7 | cython==0.29 8 | nltk==3.4.5 9 | flake8==3.5.0 10 | git+https://github.com/molecularsets/moses.git@master#egg=moses 11 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/fetch_guacamol_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | 3 | mkdir -p data 4 | 5 | wget https://ndownloader.figshare.com/files/13612745 -O data/guacamol_v1_all.smiles 6 | 7 | wget https://ndownloader.figshare.com/files/13612760 -O data/guacamol_v1_train.smiles 8 | 9 | wget https://ndownloader.figshare.com/files/13612766 -O data/guacamol_v1_valid.smiles 10 | 11 | wget https://ndownloader.figshare.com/files/13612757 -O data/guacamol_v1_test.smiles 12 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/graph_mcts/p1.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/graph_mcts/p1.p -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/graph_mcts/p_ring.p: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/graph_mcts/p_ring.p -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/graph_mcts/r_s1.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/graph_mcts/r_s1.p -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/graph_mcts/rs_make_ring.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/graph_mcts/rs_make_ring.p -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/graph_mcts/rs_ring.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/graph_mcts/rs_ring.p -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/graph_mcts/size_stats.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/graph_mcts/size_stats.p 
-------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/moses_baselines/README.md: -------------------------------------------------------------------------------- 1 | # Wrappers for baseline models from MOSES 2 | 3 | MOSES (https://github.com/molecularsets/moses) provides baseline implementations of a series of generative models: 4 | * VAE 5 | * ORGAN 6 | * AAE 7 | * JT-VAE 8 | * charRNN 9 | 10 | Here, we wrap those models for applying the GuacaMol benchmarks. 11 | Since MOSES does not consider goal-directed optimization, only distribution-learning benchmarks are done with these models. 12 | 13 | JT-VAE is not included because training it with the ChEMBL dataset leads to errors. 14 | Also, `charRNN` is not included because it is similar to the SMILES LSTM model. 15 | 16 | 17 | ## Execution 18 | 19 | Execute the following commands from the root of the repository. 20 | Replace `--device cpu` by `--device cuda` if your machine is GPU-enabled. 
21 | 22 | ### AAE 23 | 24 | Train: 25 | ```bash 26 | python -m moses_baselines.aae_train --device cpu --train_load data/guacamol_v1_train.smiles 27 | ``` 28 | 29 | Benchmark: 30 | ```bash 31 | python -m moses_baselines.aae_distribution_learning --device cpu --n_samples 0 32 | ``` 33 | 34 | ### VAE 35 | 36 | Train: 37 | ```bash 38 | python -m moses_baselines.vae_train --device cpu --train_load data/guacamol_v1_train.smiles 39 | ``` 40 | 41 | Benchmark: 42 | ```bash 43 | python -m moses_baselines.vae_distribution_learning --device cpu --n_samples 0 44 | ``` 45 | 46 | ### ORGAN 47 | 48 | Train: 49 | ```bash 50 | python -m moses_baselines.organ_train --device cpu --train_load data/guacamol_v1_train.smiles 51 | ``` 52 | 53 | Benchmark: 54 | ```bash 55 | python -m moses_baselines.organ_distribution_learning --device cpu --n_samples 0 56 | ``` 57 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/moses_baselines/aae_train.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/molecularsets/moses/blob/master/scripts/aae/train.py 2 | 3 | import torch 4 | from guacamol.utils.helpers import setup_default_logger 5 | 6 | from moses.aae import AAE, AAETrainer, get_parser as aae_parser 7 | from moses.script_utils import add_train_args, read_smiles_csv, set_seed 8 | from moses.utils import CharVocab 9 | 10 | from moses_baselines.common import read_smiles 11 | 12 | 13 | def get_parser(): 14 | return add_train_args(aae_parser()) 15 | 16 | 17 | def main(config): 18 | setup_default_logger() 19 | 20 | set_seed(config.seed) 21 | 22 | train = read_smiles(config.train_load) 23 | 24 | vocab = CharVocab.from_data(train) 25 | torch.save(config, config.config_save) 26 | torch.save(vocab, config.vocab_save) 27 | 28 | device = torch.device(config.device) 29 | 30 | model = AAE(vocab, config) 31 | model = model.to(device) 32 | 
from typing import List


def read_smiles(smiles_file: str) -> List[str]:
    """
    Read a file of SMILES strings, one per line.

    Args:
        smiles_file: path of the text file to read

    Returns:
        A list of SMILES strings with surrounding whitespace stripped.
    """
    # Iterate the file handle directly instead of materializing it twice
    # via f.readlines(); behavior is identical but avoids the extra list.
    with open(smiles_file, "r") as f:
        return [line.strip() for line in f]
def main(config): 39 | set_seed(config.seed) 40 | 41 | train = read_smiles(config.train_load) 42 | vocab = CharVocab.from_data(train) 43 | device = torch.device(config.device) 44 | 45 | with Pool(config.n_jobs) as pool: 46 | reward_func = MetricsReward( 47 | train, 48 | config.n_ref_subsample, 49 | config.rollouts, 50 | pool, 51 | config.addition_rewards, 52 | ) 53 | model = ORGAN(vocab, config, reward_func) 54 | model = model.to(device) 55 | 56 | trainer = ORGANTrainer(config) 57 | trainer.fit(model, train) 58 | 59 | torch.save(model.state_dict(), config.model_save) 60 | torch.save(config, config.config_save) 61 | torch.save(vocab, config.vocab_save) 62 | 63 | 64 | if __name__ == "__main__": 65 | parser = get_parser() 66 | config = parser.parse_known_args()[0] 67 | main(config) 68 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/moses_baselines/vae_train.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/molecularsets/moses/blob/master/scripts/vae/train.py 2 | 3 | import torch 4 | 5 | from moses.script_utils import add_train_args, set_seed 6 | from moses.vae.config import get_parser as vae_parser 7 | from moses.vae.corpus import OneHotCorpus 8 | from moses.vae.model import VAE 9 | from moses.vae.trainer import VAETrainer 10 | 11 | from moses_baselines.common import read_smiles 12 | 13 | 14 | def get_parser(): 15 | return add_train_args(vae_parser()) 16 | 17 | 18 | def main(config): 19 | set_seed(config.seed) 20 | 21 | train = read_smiles(config.train_load) 22 | 23 | device = torch.device(config.device) 24 | 25 | # For CUDNN to work properly: 26 | if device.type.startswith("cuda"): 27 | torch.cuda.set_device(device.index or 0) 28 | 29 | corpus = OneHotCorpus(config.n_batch, device) 30 | train = corpus.fit(train).transform(train) 31 | 32 | model = VAE(corpus.vocab, config).to(device) 33 | 34 
| trainer = VAETrainer(config) 35 | 36 | torch.save(config, config.config_save) 37 | torch.save(corpus.vocab, config.vocab_save) 38 | trainer.fit(model, train) 39 | 40 | 41 | if __name__ == "__main__": 42 | parser = get_parser() 43 | config = parser.parse_known_args()[0] 44 | main(config) 45 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/random_smiles_sampler/distribution_learning.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from guacamol.assess_distribution_learning import assess_distribution_learning 5 | from guacamol.utils.helpers import setup_default_logger 6 | 7 | from .generator import RandomSmilesSampler 8 | 9 | if __name__ == "__main__": 10 | setup_default_logger() 11 | 12 | parser = argparse.ArgumentParser( 13 | description= 14 | "Molecule distribution learning benchmark for random smiles sampler", 15 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 16 | ) 17 | parser.add_argument("--dist_file", default="data/guacamol_v1_all.smiles") 18 | parser.add_argument("--output_dir", default=None, help="Output directory") 19 | parser.add_argument("--suite", default="v2") 20 | args = parser.parse_args() 21 | 22 | if args.output_dir is None: 23 | args.output_dir = os.path.dirname(os.path.realpath(__file__)) 24 | 25 | with open(args.dist_file, "r") as smiles_file: 26 | smiles_list = [line.strip() for line in smiles_file.readlines()] 27 | 28 | generator = RandomSmilesSampler(molecules=smiles_list) 29 | 30 | json_file_path = os.path.join(args.output_dir, 31 | "distribution_learning_results.json") 32 | 33 | assess_distribution_learning( 34 | generator, 35 | chembl_training_file=args.dist_file, 36 | json_output_file=json_file_path, 37 | benchmark_version=args.suite, 38 | ) 39 | -------------------------------------------------------------------------------- 
"""Goal-directed generation benchmark driver for the random SMILES sampler.

Samples candidate molecules uniformly from a reference SMILES file and scores
them with the GuacaMol goal-directed benchmark suite.
"""
import argparse
import os

from guacamol.assess_goal_directed_generation import assess_goal_directed_generation
from guacamol.utils.helpers import setup_default_logger

from .generator import RandomSmilesSampler
from .optimizer import RandomSamplingOptimizer

if __name__ == "__main__":
    setup_default_logger()

    parser = argparse.ArgumentParser(
        # Fixed copy-pasted description: this is the goal-directed benchmark,
        # not distribution learning.
        description=
        "Molecule goal-directed generation benchmark for random smiles sampler",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--smiles_file", default="data/guacamol_v1_all.smiles")
    parser.add_argument("--output_dir", default=None, help="Output directory")
    args = parser.parse_args()

    if args.output_dir is None:
        # Default to this script's directory for the results JSON.
        args.output_dir = os.path.dirname(os.path.realpath(__file__))

    # BUG FIX: the original used bare readlines(), which left a trailing "\n"
    # on every SMILES handed to RandomSmilesSampler (and thus on every sampled
    # molecule). Strip each line, matching distribution_learning.py.
    with open(args.smiles_file, "r") as smiles_file:
        smiles_list = [line.strip() for line in smiles_file]

    sampler = RandomSmilesSampler(molecules=smiles_list)

    optimizer = RandomSamplingOptimizer(sampler=sampler)

    json_file_path = os.path.join(args.output_dir, "goal_directed_results.json")

    assess_goal_directed_generation(optimizer, json_output_file=json_file_path)
/examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/smiles_ga/cfg_util.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | 3 | import numpy as np 4 | 5 | from . import smiles_grammar 6 | 7 | 8 | def get_smiles_tokenizer(cfg): 9 | long_tokens = [a for a in cfg._lexical_index.keys() if len(a) > 1] 10 | # there are currently 6 double letter entities in the grammar 11 | # these are their replacement, with no particular meaning 12 | # they need to be ascii and not part of the SMILES symbol vocabulary 13 | replacements = ["!", "?", ".", ",", ";", "$"] 14 | assert len(long_tokens) == len(replacements) 15 | for token in replacements: 16 | assert token not in cfg._lexical_index 17 | 18 | def tokenize(smiles): 19 | for i, token in enumerate(long_tokens): 20 | smiles = smiles.replace(token, replacements[i]) 21 | tokens = [] 22 | for token in smiles: 23 | try: 24 | ix = replacements.index(token) 25 | tokens.append(long_tokens[ix]) 26 | except Exception: 27 | tokens.append(token) 28 | return tokens 29 | 30 | return tokenize 31 | 32 | 33 | def encode(smiles): 34 | GCFG = smiles_grammar.GCFG 35 | tokenize = get_smiles_tokenizer(GCFG) 36 | tokens = tokenize(smiles) 37 | parser = nltk.ChartParser(GCFG) 38 | parse_tree = parser.parse(tokens).__next__() 39 | productions_seq = parse_tree.productions() 40 | productions = GCFG.productions() 41 | prod_map = {} 42 | for ix, prod in enumerate(productions): 43 | prod_map[prod] = ix 44 | indices = np.array([prod_map[prod] for prod in productions_seq], dtype=int) 45 | return indices 46 | 47 | 48 | def prods_to_eq(prods): 49 | seq = [prods[0].lhs()] 50 | for prod in prods: 51 | if str(prod.lhs()) == "Nothing": 52 | break 53 | for ix, s in enumerate(seq): 54 | if s == prod.lhs(): 55 | seq = seq[:ix] + list(prod.rhs()) + seq[ix + 1:] 56 | break 57 | try: 58 | return "".join(seq) 59 | except Exception: 60 | return "" 61 | 62 | 63 | def decode(rule): 64 | productions = 
smiles_grammar.GCFG.productions() 65 | prod_seq = [productions[i] for i in rule] 66 | return prods_to_eq(prod_seq) 67 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/smiles_ga/smiles_grammar.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | 3 | # smiles grammar 4 | gram = """smiles -> chain 5 | atom -> bracket_atom 6 | atom -> aliphatic_organic 7 | atom -> aromatic_organic 8 | aliphatic_organic -> 'B' 9 | aliphatic_organic -> 'C' 10 | aliphatic_organic -> 'F' 11 | aliphatic_organic -> 'H' 12 | aliphatic_organic -> 'I' 13 | aliphatic_organic -> 'N' 14 | aliphatic_organic -> 'O' 15 | aliphatic_organic -> 'P' 16 | aliphatic_organic -> 'S' 17 | aliphatic_organic -> 'Cl' 18 | aliphatic_organic -> 'Br' 19 | aliphatic_organic -> 'Si' 20 | aliphatic_organic -> 'Se' 21 | aromatic_organic -> 'b' 22 | aromatic_organic -> 'c' 23 | aromatic_organic -> 'n' 24 | aromatic_organic -> 'o' 25 | aromatic_organic -> 'p' 26 | aromatic_organic -> 's' 27 | aromatic_organic -> 'se' 28 | bracket_atom -> '[' BAI ']' 29 | BAI -> isotope symbol BAC 30 | BAI -> symbol BAC 31 | BAI -> isotope symbol 32 | BAI -> symbol 33 | BAC -> chiral BAH 34 | BAC -> BAH 35 | BAC -> chiral 36 | BAH -> hcount BACH 37 | BAH -> BACH 38 | BAH -> hcount 39 | BACH -> charge 40 | symbol -> aliphatic_organic 41 | symbol -> aromatic_organic 42 | isotope -> DIGIT 43 | isotope -> DIGIT DIGIT 44 | isotope -> DIGIT DIGIT DIGIT 45 | DIGIT -> '1' 46 | DIGIT -> '2' 47 | DIGIT -> '3' 48 | DIGIT -> '4' 49 | DIGIT -> '5' 50 | DIGIT -> '6' 51 | DIGIT -> '7' 52 | DIGIT -> '8' 53 | DIGIT -> '9' 54 | DIGIT -> '0' 55 | chiral -> '@' 56 | chiral -> '@@' 57 | hcount -> 'H' 58 | hcount -> 'H' DIGIT 59 | charge -> '-' 60 | charge -> '-' DIGIT 61 | charge -> '-' DIGIT DIGIT 62 | charge -> '+' 63 | charge -> '+' DIGIT 64 | charge -> '+' DIGIT DIGIT 65 | bond -> '-' 66 | bond -> 
'=' 67 | bond -> '#' 68 | bond -> '/' 69 | bond -> '\\' 70 | ringbond -> DIGIT 71 | ringbond -> bond DIGIT 72 | branched_atom -> atom 73 | branched_atom -> atom RB 74 | branched_atom -> atom BB 75 | branched_atom -> atom RB BB 76 | RB -> RB ringbond 77 | RB -> ringbond 78 | BB -> BB branch 79 | BB -> branch 80 | branch -> '(' chain ')' 81 | branch -> '(' bond chain ')' 82 | chain -> branched_atom 83 | chain -> chain branched_atom 84 | chain -> chain bond branched_atom 85 | Nothing -> None""" 86 | 87 | # form the CFG and get the start symbol 88 | GCFG = nltk.CFG.fromstring(gram) 89 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/smiles_lstm_hc/distribution_learning.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from pathlib import Path 5 | 6 | import torch 7 | 8 | from guacamol.assess_distribution_learning import assess_distribution_learning 9 | from guacamol.utils.helpers import setup_default_logger 10 | 11 | from .rnn_utils import load_rnn_model, set_random_seed 12 | from .smiles_rnn_generator import SmilesRnnGenerator 13 | 14 | if __name__ == "__main__": 15 | setup_default_logger() 16 | logger = logging.getLogger(__name__) 17 | 18 | parser = argparse.ArgumentParser( 19 | description="Distribution learning benchmark for SMILES RNN", 20 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 21 | ) 22 | parser.add_argument("--seed", default=42, type=int, help="Random seed") 23 | parser.add_argument("--model_path", 24 | default=None, 25 | help="Full path to SMILES RNN model") 26 | parser.add_argument("--output_dir", default=None, help="Output directory") 27 | parser.add_argument("--dist_file", 28 | default="data/guacamol_v1_all.smiles", 29 | help="Distribution file") 30 | parser.add_argument("--suite", default="v2") 31 | 32 | args = parser.parse_args() 33 | 34 | 
device = "cuda" if torch.cuda.is_available() else "cpu" 35 | logger.info(f"device:\t{device}") 36 | 37 | set_random_seed(args.seed, device) 38 | 39 | if args.output_dir is None: 40 | args.output_dir = os.path.dirname(os.path.realpath(__file__)) 41 | 42 | if args.model_path is None: 43 | dir_path = os.path.dirname(os.path.realpath(__file__)) 44 | args.model_path = os.path.join(dir_path, "pretrained_model", 45 | "model_final_0.473.pt") 46 | 47 | model_def = Path(args.model_path).with_suffix(".json") 48 | model = load_rnn_model(model_def, args.model_path, device, copy_to_cpu=True) 49 | generator = SmilesRnnGenerator(model=model, device=device) 50 | 51 | json_file_path = os.path.join(args.output_dir, 52 | "distribution_learning_results.json") 53 | assess_distribution_learning( 54 | generator, 55 | chembl_training_file=args.dist_file, 56 | json_output_file=json_file_path, 57 | benchmark_version=args.suite, 58 | ) 59 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/smiles_lstm_hc/rnn_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .action_sampler import ActionSampler 4 | from .rnn_model import SmilesRnn 5 | from .smiles_char_dict import SmilesCharDictionary 6 | 7 | 8 | class SmilesRnnSampler: 9 | """ 10 | Samples molecules from an RNN smiles language model 11 | """ 12 | 13 | def __init__(self, device: str, batch_size=64) -> None: 14 | """ 15 | Args: 16 | device: cpu | cuda 17 | batch_size: number of concurrent samples to generate 18 | """ 19 | self.device = device 20 | self.batch_size = batch_size 21 | self.sd = SmilesCharDictionary() 22 | 23 | def sample(self, model: SmilesRnn, num_to_sample: int, max_seq_len=100): 24 | """ 25 | 26 | Args: 27 | model: RNN to sample from 28 | num_to_sample: number of samples to produce 29 | max_seq_len: maximum length of the samples 30 | batch_size: number of 
concurrent samples to generate 31 | 32 | Returns: a list of SMILES string, with no beginning nor end symbols 33 | 34 | """ 35 | sampler = ActionSampler( 36 | max_batch_size=self.batch_size, 37 | max_seq_length=max_seq_len, 38 | device=self.device, 39 | ) 40 | 41 | model.eval() 42 | with torch.no_grad(): 43 | indices = sampler.sample(model, num_samples=num_to_sample) 44 | return self.sd.matrix_to_smiles(indices) 45 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/smiles_lstm_hc/smiles_rnn_generator.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from guacamol.distribution_matching_generator import DistributionMatchingGenerator 4 | 5 | from .rnn_model import SmilesRnn 6 | from .rnn_sampler import SmilesRnnSampler 7 | 8 | 9 | class SmilesRnnGenerator(DistributionMatchingGenerator): 10 | """ 11 | Wraps SmilesRnn in a class satisfying the DistributionMatchingGenerator interface. 12 | """ 13 | 14 | def __init__(self, model: SmilesRnn, device: str) -> None: 15 | self.model = model 16 | self.device = device 17 | 18 | def generate(self, number_samples: int) -> List[str]: 19 | sampler = SmilesRnnSampler(device=self.device) 20 | return sampler.sample(model=self.model, num_to_sample=number_samples) 21 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/smiles_lstm_ppo/molecule_batch.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from guacamol.utils.chemistry import canonicalize 4 | from guacamol.utils.data import remove_duplicates 5 | 6 | 7 | class MoleculeBatch(object): 8 | """ 9 | Delivers useful properties about a batch of generated SMILES strings. 
10 | 11 | Canonicalization of the SMILES strings, and removal of the duplicates, will be 12 | done only one time, and only if necessary. 13 | """ 14 | 15 | def __init__(self, smiles: List[str]) -> None: 16 | self._smiles = smiles 17 | self._canonical_smiles = None 18 | self._unique_canonical_smiles = None 19 | 20 | @property 21 | def canonical_smiles(self): 22 | self._canonicalize() 23 | return self._canonical_smiles 24 | 25 | @property 26 | def unique_canonical_smiles(self): 27 | self._remove_duplicates() 28 | return self._unique_canonical_smiles 29 | 30 | @property 31 | def size(self): 32 | return len(self._smiles) 33 | 34 | @property 35 | def number_valid(self): 36 | self._canonicalize() 37 | return len(self._canonical_smiles) 38 | 39 | @property 40 | def number_unique(self): 41 | self._remove_duplicates() 42 | return len(self._unique_canonical_smiles) 43 | 44 | @property 45 | def ratio_valid(self): 46 | return self.number_valid / self.size 47 | 48 | @property 49 | def ratio_unique(self): 50 | """The ratio of unique valid molecules compared to the total size""" 51 | return self.number_unique / self.size 52 | 53 | @property 54 | def ratio_unique_among_valid(self): 55 | return self.number_unique / self.number_valid 56 | 57 | def _canonicalize(self): 58 | if self._canonical_smiles is not None: 59 | return 60 | 61 | canonical = [canonicalize(mol) for mol in self._smiles] 62 | self._canonical_smiles = [s for s in canonical if s is not None] 63 | 64 | def _remove_duplicates(self): 65 | if self._unique_canonical_smiles is not None: 66 | return 67 | 68 | self._canonicalize() 69 | self._unique_canonical_smiles = remove_duplicates( 70 | self._canonical_smiles) 71 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/smiles_lstm_ppo/ppo_generator.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import 
logging 4 | 5 | from smiles_lstm_ppo.ppo_trainer import PPOTrainer, OptResult 6 | from smiles_lstm_ppo.rnn_model import SmilesRnnActorCritic 7 | from smiles_lstm_hc.rnn_sampler import SmilesRnnSampler 8 | from guacamol.scoring_function import ScoringFunction 9 | 10 | logger = logging.getLogger(__name__) 11 | logger.addHandler(logging.NullHandler()) 12 | 13 | 14 | class PPOMoleculeGenerator: 15 | 16 | def __init__(self, model: SmilesRnnActorCritic, max_seq_length, 17 | device) -> None: 18 | self.model = model 19 | self.max_seq_length = max_seq_length 20 | self.device = device 21 | self.sampler = SmilesRnnSampler(device=device, batch_size=512) 22 | 23 | def optimise(self, objective: ScoringFunction, start_population: list, 24 | **kwargs) -> List[OptResult]: 25 | if start_population: 26 | logger.warning( 27 | "PPO algorithm does not support (yet) a starting population") 28 | num_epochs = kwargs["num_epochs"] 29 | episode_size = kwargs["optimize_episode_size"] 30 | batch_size = kwargs["optimize_batch_size"] 31 | entropy_weight = kwargs["entropy_weight"] 32 | kl_div_weight = kwargs["kl_div_weight"] 33 | clip_param = kwargs["clip_param"] 34 | 35 | trainer = PPOTrainer( 36 | self.model, 37 | objective, 38 | device=self.device, 39 | max_seq_length=self.max_seq_length, 40 | batch_size=batch_size, 41 | num_epochs=num_epochs, 42 | clip_param=clip_param, 43 | episode_size=episode_size, 44 | entropy_weight=entropy_weight, 45 | kl_div_weight=kl_div_weight, 46 | ) 47 | trainer.train() 48 | 49 | return sorted(trainer.smiles_history, reverse=True) 50 | 51 | def sample(self, num_mols) -> List[str]: 52 | return self.sampler.sample( 53 | self.model.smiles_rnn, 54 | num_to_sample=num_mols, 55 | max_seq_len=self.max_seq_length, 56 | ) 57 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/smiles_lstm_ppo/rnn_model.py: 
import torch.nn as nn


class SmilesRnnActorCritic(nn.Module):
    """Actor-critic wrapper around a SMILES RNN language model.

    The pretrained language model supplies the actor (policy) head; a single
    linear layer added here supplies the critic (value) head.
    """

    def __init__(self, smiles_rnn) -> None:
        """
        Creates an Actor-Critic model from a Smiles RNN Language model.

        Args:
            smiles_rnn: a SmilesRnn object whose encoder/rnn/decoder are
                reused as the actor
        """
        super().__init__()

        self.smiles_rnn = smiles_rnn

        # Value head: one scalar per timestep, read off the RNN hidden state.
        self.critic_decoder = nn.Linear(self.smiles_rnn.hidden_size, 1)

        self.init_weights()

    def init_weights(self):
        """Xavier-initialize the critic head; the actor keeps its pretrained weights."""
        nn.init.xavier_uniform_(self.critic_decoder.weight)
        nn.init.constant_(self.critic_decoder.bias, 0)

    def forward(self, x, hidden):
        emb = self.smiles_rnn.encoder(x)
        rnn_out, next_hidden = self.smiles_rnn.rnn(emb, hidden)
        policy_logits = self.smiles_rnn.decoder(rnn_out)
        value = self.critic_decoder(rnn_out)
        return policy_logits, value, next_hidden
7 | initial_value: Initial reward 8 | """ 9 | assert keep_factor >= 0.0 10 | assert keep_factor <= 1.0 11 | 12 | self._keep_factor = keep_factor 13 | self._reward = initial_value 14 | self.last_added = initial_value 15 | 16 | @property 17 | def value(self): 18 | """Get the current running reward.""" 19 | return self._reward 20 | 21 | def update(self, reward): 22 | """Update the running reward with a new value.""" 23 | self._reward *= self._keep_factor 24 | self._reward += reward * (1.0 - self._keep_factor) 25 | self.last_added = reward 26 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/guacamol_baselines/upload.sh: -------------------------------------------------------------------------------- 1 | 2 | scp -r $1 tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/guacamol_tdc/guacamol_baselines 3 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/results0/graph_ga/goal_directed_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "generations": 1000, 3 | "mutation_rate": 0.01, 4 | "n_jobs": -1, 5 | "offspring_size": 200, 6 | "output_dir": "/net/sunlab/psunlab1/molecular_data/graphnn/guacamol_tdc/guacamol_baselines/graph_ga", 7 | "patience": 5, 8 | "population_size": 100, 9 | "random_start": false, 10 | "seed": 0, 11 | "smiles_file": "data/guacamol_v1_all.smiles", 12 | "suite": "v3" 13 | } 14 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/results0/smiles_ga/goal_directed_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "gene_size": 300, 3 | "generations": 1000, 4 | "n_jobs": -1, 5 | "n_mutations": 200, 6 | "output_dir": 
"/net/sunlab/psunlab1/molecular_data/graphnn/guacamol_tdc/guacamol_baselines/smiles_ga", 7 | "patience": 5, 8 | "population_size": 100, 9 | "random_start": false, 10 | "seed": 42, 11 | "smiles_file": "data/guacamol_v1_all.smiles", 12 | "suite": "v3" 13 | } 14 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/results0/smiles_lstm_ppo/goal_directed_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 1024, 3 | "clip_param": 0.2, 4 | "entropy_weight": 1, 5 | "episode_size": 8192, 6 | "kl_div_weight": 10, 7 | "model_path": "/net/sunlab/psunlab1/molecular_data/graphnn/guacamol_tdc/guacamol_baselines/smiles_lstm_ppo/pretrained_model/model_final_0.473.pt", 8 | "n_jobs": -1, 9 | "num_epochs": 20, 10 | "output_dir": "/net/sunlab/psunlab1/molecular_data/graphnn/guacamol_tdc/guacamol_baselines/smiles_lstm_ppo", 11 | "seed": 42, 12 | "suite": "v3" 13 | } 14 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/results0/smiles_lstm_ppo/pretrained_model/model_final_0.473.json: -------------------------------------------------------------------------------- 1 | {"input_size": 47, "hidden_size": 1024, "output_size": 47, "n_layers": 3, "rnn_dropout": 0.2} 2 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/guacamol_tdc/upload.sh: -------------------------------------------------------------------------------- 1 | 2 | # scp -r guacamol tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/guacamol_tdc/ 3 | 4 | # scp -r guacamol_baselines/graph_ga tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/pyscreener/ 5 | 6 | scp -r guacamol_baselines/smiles_lstm_hc tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/pyscreener/ 7 | 8 | scp 
guacamol_baselines/graph_ga/graph_ga_run.py tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/pyscreener/graph_ga_run.py 9 | 10 | scp guacamol_baselines/smiles_lstm_hc/run_smiles_lstm_hc.py tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/pyscreener/run_smiles_lstm_hc.py 11 | 12 | scp -r guacamol_baselines/smiles_lstm_hc tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/pyscreener/ 13 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
## similarity of two SMILES
def similarity(a, b):
    """Tanimoto similarity between two SMILES strings.

    Uses radius-2 Morgan fingerprints (2048 bits, chirality ignored).
    Returns 0.0 when either input is None or fails to parse.
    """
    if a is None or b is None:
        return 0.0
    mol_a = Chem.MolFromSmiles(a)
    mol_b = Chem.MolFromSmiles(b)
    if mol_a is None or mol_b is None:
        return 0.0
    fp_a, fp_b = (AllChem.GetMorganFingerprintAsBitVect(mol,
                                                        2,
                                                        nBits=2048,
                                                        useChirality=False)
                  for mol in (mol_a, mol_b))
    return DataStructs.TanimotoSimilarity(fp_a, fp_b)
"batch_norm": false, 31 | "save_frequency": 200, 32 | "max_num_checkpoints": 10 33 | } 34 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/configs/bootstrap_dqn_opt_800.json: -------------------------------------------------------------------------------- 1 | { 2 | "atom_types": ["C", "N", "O"], 3 | "max_steps_per_episode": 20, 4 | "allow_removal": true, 5 | "allow_no_modification": true, 6 | "allow_bonds_between_rings": false, 7 | "allowed_ring_sizes": [5, 6], 8 | "replay_buffer_size": 5000, 9 | "learning_rate": 0.0001, 10 | "learning_rate_decay_steps": 10000, 11 | "learning_rate_decay_rate": 0.9, 12 | "num_episodes": 40000, 13 | "batch_size": 128, 14 | "learning_frequency": 4, 15 | "update_frequency": 20, 16 | "grad_clipping": 10, 17 | "gamma": 1.0, 18 | "discount_factor": 0.9, 19 | "double_q": true, 20 | "num_bootstrap_heads": 12, 21 | "prioritized": false, 22 | "prioritized_alpha": 0.6, 23 | "prioritized_beta": 0.4, 24 | "prioritized_epsilon": 0.000001, 25 | "fingerprint_radius": 3, 26 | "fingerprint_length": 2048, 27 | "dense_layers": [1024, 512, 128, 32], 28 | "activation": "relu", 29 | "optimizer": "Adam", 30 | "batch_norm": false, 31 | "save_frequency": 200, 32 | "max_num_checkpoints": 10 33 | } 34 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/configs/bootstrap_dqn_step1.json: -------------------------------------------------------------------------------- 1 | { 2 | "atom_types": ["C", "O", "N"], 3 | "max_steps_per_episode": 40, 4 | "allow_removal": true, 5 | "allow_no_modification": true, 6 | "allow_bonds_between_rings": false, 7 | "allowed_ring_sizes": [5, 6, 7], 8 | "replay_buffer_size": 5000, 9 | "learning_rate": 0.0001, 10 | "learning_rate_decay_steps": 10000, 11 | "learning_rate_decay_rate": 0.9, 12 | "num_episodes": 5000, 13 | "batch_size": 128, 14 | 
"learning_frequency": 4, 15 | "update_frequency": 20, 16 | "grad_clipping": 10, 17 | "gamma": 1.0, 18 | "discount_factor": 0.9, 19 | "double_q": true, 20 | "num_bootstrap_heads": 12, 21 | "prioritized": false, 22 | "prioritized_alpha": 0.6, 23 | "prioritized_beta": 0.4, 24 | "prioritized_epsilon": 0.000001, 25 | "fingerprint_radius": 3, 26 | "fingerprint_length": 2048, 27 | "dense_layers": [1024, 512, 128, 32], 28 | "activation": "relu", 29 | "optimizer": "Adam", 30 | "batch_norm": false, 31 | "save_frequency": 200, 32 | "max_num_checkpoints": 10 33 | } 34 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/configs/bootstrap_dqn_step2.json: -------------------------------------------------------------------------------- 1 | { 2 | "atom_types": ["C", "O", "N"], 3 | "max_steps_per_episode": 40, 4 | "allow_removal": true, 5 | "allow_no_modification": true, 6 | "allow_bonds_between_rings": false, 7 | "allowed_ring_sizes": [5, 6, 7], 8 | "replay_buffer_size": 5000, 9 | "learning_rate": 0.0001, 10 | "learning_rate_decay_steps": 10000, 11 | "learning_rate_decay_rate": 0.9, 12 | "num_episodes": 5000, 13 | "batch_size": 128, 14 | "learning_frequency": 4, 15 | "update_frequency": 20, 16 | "grad_clipping": 10, 17 | "gamma": 1.0, 18 | "discount_factor": 0.9, 19 | "double_q": true, 20 | "num_bootstrap_heads": 12, 21 | "prioritized": false, 22 | "prioritized_alpha": 0.6, 23 | "prioritized_beta": 0.4, 24 | "prioritized_epsilon": 0.000001, 25 | "fingerprint_radius": 3, 26 | "fingerprint_length": 2048, 27 | "dense_layers": [1024, 512, 128, 32], 28 | "activation": "relu", 29 | "optimizer": "Adam", 30 | "batch_norm": false, 31 | "save_frequency": 200, 32 | "max_num_checkpoints": 10 33 | } 34 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/configs/multi_obj_dqn.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "atom_types": ["C", "O", "N"], 3 | "max_steps_per_episode": 20, 4 | "allow_removal": true, 5 | "allow_no_modification": true, 6 | "allow_bonds_between_rings": false, 7 | "allowed_ring_sizes": [5, 6], 8 | "replay_buffer_size": 5000, 9 | "learning_rate": 0.0001, 10 | "learning_rate_decay_steps": 10000, 11 | "learning_rate_decay_rate": 0.9, 12 | "num_episodes": 3000, 13 | "batch_size": 128, 14 | "learning_frequency": 4, 15 | "update_frequency": 20, 16 | "grad_clipping": 10, 17 | "gamma": 1.0, 18 | "discount_factor": 0.9, 19 | "double_q": true, 20 | "num_bootstrap_heads": 12, 21 | "prioritized": false, 22 | "prioritized_alpha": 0.6, 23 | "prioritized_beta": 0.4, 24 | "prioritized_epsilon": 0.000001, 25 | "fingerprint_radius": 3, 26 | "fingerprint_length": 2048, 27 | "dense_layers": [1024, 512, 128, 32], 28 | "activation": "relu", 29 | "optimizer": "Adam", 30 | "batch_norm": false, 31 | "save_frequency": 200, 32 | "max_num_checkpoints": 10 33 | } 34 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/configs/naive_dqn.json: -------------------------------------------------------------------------------- 1 | { 2 | "atom_types": ["C", "O", "N"], 3 | "max_steps_per_episode": 40, 4 | "allow_removal": true, 5 | "allow_no_modification": true, 6 | "allow_bonds_between_rings": false, 7 | "allowed_ring_sizes": [5, 6], 8 | "replay_buffer_size": 5000, 9 | "learning_rate": 0.0001, 10 | "learning_rate_decay_steps": 10000, 11 | "learning_rate_decay_rate": 0.9, 12 | "num_episodes": 10000, 13 | "batch_size": 128, 14 | "learning_frequency": 4, 15 | "update_frequency": 20, 16 | "grad_clipping": 10, 17 | "gamma": 1.0, 18 | "discount_factor": 0.9, 19 | "double_q": true, 20 | "num_bootstrap_heads": 0, 21 | "prioritized": false, 22 | "prioritized_alpha": 0.6, 23 | "prioritized_beta": 0.4, 24 | "prioritized_epsilon": 
0.000001, 25 | "fingerprint_radius": 3, 26 | "fingerprint_length": 2048, 27 | "dense_layers": [1024, 512, 128, 32], 28 | "activation": "relu", 29 | "optimizer": "Adam", 30 | "batch_norm": false, 31 | "save_frequency": 200, 32 | "max_num_checkpoints": 10 33 | } 34 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/configs/naive_dqn_opt_800.json: -------------------------------------------------------------------------------- 1 | { 2 | "atom_types": ["C", "O", "N"], 3 | "max_steps_per_episode": 20, 4 | "allow_removal": true, 5 | "allow_no_modification": true, 6 | "allow_bonds_between_rings": false, 7 | "allowed_ring_sizes": [5, 6], 8 | "replay_buffer_size": 5000, 9 | "learning_rate": 0.0001, 10 | "learning_rate_decay_steps": 10000, 11 | "learning_rate_decay_rate": 0.9, 12 | "num_episodes": 40000, 13 | "batch_size": 128, 14 | "learning_frequency": 4, 15 | "update_frequency": 20, 16 | "grad_clipping": 10, 17 | "gamma": 1.0, 18 | "discount_factor": 0.9, 19 | "double_q": true, 20 | "num_bootstrap_heads": 0, 21 | "prioritized": false, 22 | "prioritized_alpha": 0.6, 23 | "prioritized_beta": 0.4, 24 | "prioritized_epsilon": 0.000001, 25 | "fingerprint_radius": 3, 26 | "fingerprint_length": 2048, 27 | "dense_layers": [1024, 512, 128, 32], 28 | "activation": "relu", 29 | "optimizer": "Adam", 30 | "batch_norm": false, 31 | "save_frequency": 200, 32 | "max_num_checkpoints": 10 33 | } 34 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/configs/qed_logp_jnk_gsk.json: -------------------------------------------------------------------------------- 1 | { 2 | "atom_types": ["C", "O", "N"], 3 | "max_steps_per_episode": 40, 4 | "allow_removal": true, 5 | "allow_no_modification": true, 6 | "allow_bonds_between_rings": false, 7 | "allowed_ring_sizes": [5, 6], 8 | "replay_buffer_size": 5000, 9 | 
"learning_rate": 0.0001, 10 | "learning_rate_decay_steps": 10000, 11 | "learning_rate_decay_rate": 0.9, 12 | "num_episodes": 5000, 13 | "batch_size": 128, 14 | "learning_frequency": 4, 15 | "update_frequency": 20, 16 | "grad_clipping": 10, 17 | "gamma": 1.0, 18 | "discount_factor": 0.9, 19 | "double_q": true, 20 | "num_bootstrap_heads": 0, 21 | "prioritized": false, 22 | "prioritized_alpha": 0.6, 23 | "prioritized_beta": 0.4, 24 | "prioritized_epsilon": 0.000001, 25 | "fingerprint_radius": 3, 26 | "fingerprint_length": 2048, 27 | "dense_layers": [1024, 512, 128, 32], 28 | "activation": "relu", 29 | "optimizer": "Adam", 30 | "batch_norm": false, 31 | "save_frequency": 200, 32 | "max_num_checkpoints": 10 33 | } 34 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/configs/target_sas.json: -------------------------------------------------------------------------------- 1 | { 2 | "atom_types": ["C", "O", "N"], 3 | "max_steps_per_episode": 40, 4 | "allow_removal": true, 5 | "allow_no_modification": true, 6 | "allow_bonds_between_rings": false, 7 | "allowed_ring_sizes": [3, 4, 5, 6, 7], 8 | "replay_buffer_size": 5000, 9 | "learning_rate": 0.0001, 10 | "learning_rate_decay_steps": 10000, 11 | "learning_rate_decay_rate": 0.9, 12 | "num_episodes": 5000, 13 | "batch_size": 128, 14 | "learning_frequency": 4, 15 | "update_frequency": 20, 16 | "grad_clipping": 10, 17 | "gamma": 1.0, 18 | "discount_factor": 0.9, 19 | "double_q": true, 20 | "num_bootstrap_heads": 20, 21 | "prioritized": false, 22 | "prioritized_alpha": 0.6, 23 | "prioritized_beta": 0.4, 24 | "prioritized_epsilon": 0.000001, 25 | "fingerprint_radius": 3, 26 | "fingerprint_length": 2048, 27 | "dense_layers": [1024, 512, 128, 32], 28 | "activation": "relu", 29 | "optimizer": "Adam", 30 | "batch_norm": false, 31 | "save_frequency": 200, 32 | "max_num_checkpoints": 10 33 | } 34 | 
-------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/docking_iter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/moldqn/chemgraph/docking_iter.png -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/docking_smilesvaluelst_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/moldqn/chemgraph/docking_smilesvaluelst_1.pkl -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/docking_smilesvaluelst_10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/moldqn/chemgraph/docking_smilesvaluelst_10.pkl -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/docking_smilesvaluelst_11.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/moldqn/chemgraph/docking_smilesvaluelst_11.pkl -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/docking_smilesvaluelst_12.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/moldqn/chemgraph/docking_smilesvaluelst_12.pkl -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/docking_smilesvaluelst_3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/moldqn/chemgraph/docking_smilesvaluelst_3.pkl -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/docking_smilesvaluelst_9.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/examples/generation/docking_generation/moldqn/chemgraph/docking_smilesvaluelst_9.pkl -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/download.sh: -------------------------------------------------------------------------------- 1 | 2 | scp -r tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/mol_dqn_docking/chemgraph/docking_smilesvaluelst_* . 3 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/dqn/py/SA_Score/README: -------------------------------------------------------------------------------- 1 | RDKit-based implementation of the method described in: 2 | 3 | Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 4 | Peter Ertl and Ansgar Schuffenhauer 5 | Journal of Cheminformatics 1:8 (2009) 6 | http://www.jcheminf.com/content/1/1/8 7 | 8 | Contribution from Peter Ertl and Greg Landrum 9 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/dqn/py/SA_Score/UnitTestSAScore.py: -------------------------------------------------------------------------------- 1 | from rdkit import RDConfig 2 | from rdkit import Chem 3 | import unittest, os.path 4 | import sascorer 5 | 6 | print(sascorer.__file__) 7 | 8 | 9 | class TestCase(unittest.TestCase): 10 | 11 | def test1(self): 12 | with open("data/zim.100.txt") as f: 13 | testData = [x.strip().split("\t") for x in f] 14 | testData.pop(0) 15 | for row in testData: 16 | smi = row[0] 17 | m = Chem.MolFromSmiles(smi) 18 | tgt = float(row[2]) 19 | val = sascorer.calculateScore(m) 20 | self.assertAlmostEqual(tgt, val, 3) 21 | 22 | 23 | if __name__ == "__main__": 24 | import sys, getopt, re 25 | 26 | doLong = 0 27 | if len(sys.argv) > 1: 28 | args, extras = getopt.getopt(sys.argv[1:], 
"l") 29 | for arg, val in args: 30 | if arg == "-l": 31 | doLong = 1 32 | sys.argv.remove("-l") 33 | if doLong: 34 | for methName in dir(TestCase): 35 | if re.match("_test", methName): 36 | newName = re.sub("_test", "test", methName) 37 | exec("TestCase.%s = TestCase.%s" % (newName, methName)) 38 | 39 | unittest.main() 40 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/dqn/py/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/dqn/tensorflow_core/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/result_analysis0.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | result_folder = "result" 6 | num_lst = [] 7 | num2dockingscore = {} 8 | file_lst = os.listdir(result_folder) 9 | for file in file_lst: 10 | file = os.path.join(result_folder, file) 11 | num = int(file.split("-")[1]) 12 | num_lst.append(num) 13 | with open(file, "r") as fin: 14 | lines = fin.readlines() 15 | smiles_score_lst = [ 16 | (line.split()[0], float(line.split()[1])) for line in lines 17 | ] 18 | score_lst = [i[1] for i in smiles_score_lst] 19 | score_lst.sort() 20 | assert score_lst[0] <= score_lst[1] 21 | num2dockingscore[num] = (score_lst[0], np.mean(score_lst[:10]), 22 | np.mean(score_lst)) 23 | num_lst.sort() 24 | 25 | top_1 = [num2dockingscore[num][0] for num in num_lst] 26 | top_10 = [num2dockingscore[num][1] for num in num_lst] 27 | top_100 = [num2dockingscore[num][2] for num in num_lst] 28 | num_lst = [i for i in range(len(num_lst))] 29 | num_lst = [i / max(num_lst) * 5000 for i in num_lst] 30 | 31 | plt.plot(num_lst, top_1, color="b", label="top-1") 32 | plt.plot(num_lst, top_10, color="r", label="top-10") 33 | plt.plot(num_lst, top_100, color="y", label="top-100") 34 | plt.legend() 35 | plt.xlabel("# docking call") 36 | plt.ylabel("docking score (DRD3) achieved by MolDQN") 37 | 
plt.savefig("docking_iter.png") 38 | """ 39 | cd chemgraph 40 | scp -r tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/mol_dqn_docking/chemgraph/result . 41 | 42 | """ 43 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/chemgraph/try.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.chdir("/project/molecular_data/graphnn") 4 | from .dqn import deep_q_networks 5 | 6 | exit() 7 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | baselines 3 | networkx 4 | numpy 5 | rdkit 6 | tensorflow 7 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/requirements2.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | baselines 3 | networkx 4 | numpy 5 | tensorflow 6 | -------------------------------------------------------------------------------- /examples/generation/docking_generation/moldqn/upload.sh: -------------------------------------------------------------------------------- 1 | 2 | scp -r $1 tfu42@orcus1.cc.gatech.edu:/project/molecular_data/graphnn/mol_dqn_docking/ 3 | -------------------------------------------------------------------------------- /examples/huggingface_examples/herg/.gitignore: -------------------------------------------------------------------------------- 1 | tutorial_model 2 | result 3 | mlruns 4 | -------------------------------------------------------------------------------- /examples/multi_pred/drugcombo/README.md: -------------------------------------------------------------------------------- 1 | ## Drug Combination Benchmark Group MLP Baselines 2 | 3 | This directory contains 
the code used to train a simple baseline MLP regression model for the DrugCombo benchmark. 4 | 5 | To run the model, simply: 6 | 7 | ```python 8 | python train_MLP.py --epochs 10 --batch_size 128 --cuda True 9 | 10 | ''' 11 | {'drugcomb_css': [16.858, 0.005], 'drugcomb_css_kidney': [14.57, 0.003], 'drugcomb_css_lung': [15.653, 0.017], 'drugcomb_css_breast': [13.432, 0.049], 'drugcomb_css_hematopoietic_lymphoid': [28.764, 0.201], 'drugcomb_css_colon': [17.729, 0.042], 'drugcomb_css_prostate': [15.692, 0.005], 'drugcomb_css_ovary': [15.263, 0.041], 'drugcomb_css_skin': [15.663, 0.065], 'drugcomb_css_brain': [15.694, 0.006], 'drugcomb_hsa': [4.453, 0.002], 'drugcomb_loewe': [9.184, 0.001], 'drugcomb_bliss': [4.56, 0.0], 'drugcomb_zip': [4.027, 0.003]} 12 | ''' 13 | 14 | ``` 15 | -------------------------------------------------------------------------------- /examples/multi_pred/dti_dg/README.md: -------------------------------------------------------------------------------- 1 | # TDC DTI Domain Generalization Leaderboard 2 | 3 | We adapt code from [domainbed](https://arxiv.org/abs/2007.01434) to add 7 baselines for this leaderboard. For the backbone model, we use [DeepDTA](https://academic.oup.com/bioinformatics/article/34/17/i821/5093245), one of the SOTA baselines for DTI affinity prediction. 4 | 5 | ### Environment 6 | 7 | `torch, numpy, pandas, tqdm, scikit-learn` 8 | 9 | ### Run 10 | 11 | ```python 12 | cd domainbed/ 13 | python train.py --algorithm GroupDRO --seed 0 14 | 15 | # supported model: ERM/IRM/GroupDRO/MMD/CORAL/AndMask/MTL 16 | ``` 17 | 18 | 19 | ### Add your own domain generalization algorithm 20 | 21 | Go to `domainbed/algorithm.py` script to add your algorithm. 22 | -------------------------------------------------------------------------------- /examples/multi_pred/dti_dg/domainbed/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | -------------------------------------------------------------------------------- /examples/multi_pred/dti_dg/domainbed/lib/reporting.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import collections 4 | 5 | import json 6 | import os 7 | 8 | import tqdm 9 | 10 | from domainbed.lib.query import Q 11 | 12 | 13 | def load_records(path): 14 | records = [] 15 | for i, subdir in tqdm.tqdm(list(enumerate(os.listdir(path))), 16 | ncols=80, 17 | leave=False): 18 | results_path = os.path.join(path, subdir, "results.jsonl") 19 | try: 20 | with open(results_path, "r") as f: 21 | for line in f: 22 | records.append(json.loads(line[:-1])) 23 | except IOError: 24 | pass 25 | 26 | return Q(records) 27 | 28 | 29 | def get_grouped_records(records): 30 | """Group records by (trial_seed, dataset, algorithm, test_env). Because 31 | records can have multiple test envs, a given record may appear in more than 32 | one group.""" 33 | result = collections.defaultdict(lambda: []) 34 | for r in records: 35 | for test_env in r["args"]["test_envs"]: 36 | group = (r["args"]["trial_seed"], r["args"]["dataset"], 37 | r["args"]["algorithm"], test_env) 38 | result[group].append(r) 39 | return Q([{ 40 | "trial_seed": t, 41 | "dataset": d, 42 | "algorithm": a, 43 | "test_env": e, 44 | "records": Q(r) 45 | } for (t, d, a, e), r in result.items()]) 46 | -------------------------------------------------------------------------------- /examples/multi_pred/geneperturb/prepare_benchmark_dataset.py: -------------------------------------------------------------------------------- 1 | for dataset in [ 2 | "scperturb_gene_NormanWeissman2019", 3 | "scperturb_gene_ReplogleWeissman2022_rpe1", 4 | "scperturb_gene_ReplogleWeissman2022_k562_essential" 5 | ]: 6 | 7 | from tdc.benchmark_group import geneperturb_group 8 | group = geneperturb_group.GenePerturbGroup() 9 | train, 
val = group.get_train_valid_split(dataset=dataset) 10 | test = group.get_test() 11 | 12 | import anndata as ad 13 | adata = ad.concat([train, val, test]) 14 | adata.obs['cell_type'] = adata.obs['cell_line'] 15 | adata.var['gene_name'] = adata.var.index.values 16 | from scipy.sparse import csr_matrix 17 | adata.X = csr_matrix(adata.X) 18 | 19 | from gears import PertData 20 | 21 | pert_data = PertData('./data') # specific saved folder 22 | pert_data.new_data_process( 23 | dataset_name=dataset, 24 | adata=adata) # specific dataset name and adata object 25 | -------------------------------------------------------------------------------- /examples/multi_pred/geneperturb/run_gears.py: -------------------------------------------------------------------------------- 1 | from tdc.benchmark_group import geneperturb_group 2 | from gears.utils import filter_pert_in_go 3 | import numpy as np 4 | from gears import PertData, GEARS 5 | import pickle 6 | 7 | group = geneperturb_group.GenePerturbGroup() 8 | dataset = 'scperturb_gene_NormanWeissman2019' 9 | 10 | train, val = group.get_train_valid_split(dataset=dataset) 11 | test = group.get_test() 12 | 13 | set2conditions = { 14 | 'train': train.obs.condition.unique().tolist(), 15 | 'val': val.obs.condition.unique().tolist(), 16 | 'test': test.obs.condition.unique().tolist() 17 | } 18 | 19 | pert_data = PertData('./data') # specific saved folder 20 | pert_data.load(data_path='./data/' + 21 | dataset.lower()) # specific dataset name and adata object 22 | 23 | train_not_seen = [ 24 | i for i in set2conditions['train'] 25 | if not filter_pert_in_go(i, pert_data.pert_names) 26 | ] 27 | val_not_seen = [ 28 | i for i in set2conditions['val'] 29 | if not filter_pert_in_go(i, pert_data.pert_names) 30 | ] 31 | test_not_seen = [ 32 | i for i in set2conditions['test'] 33 | if not filter_pert_in_go(i, pert_data.pert_names) 34 | ] 35 | print('test perts not in gears', test_not_seen) 36 | print('train perts not in gears', train_not_seen) 37 | 
print('val perts not in gears', val_not_seen) 38 | set2conditions['train'] = np.setdiff1d(set2conditions['train'], train_not_seen) 39 | set2conditions['val'] = np.setdiff1d(set2conditions['val'], val_not_seen) 40 | set2conditions['test'] = np.setdiff1d(set2conditions['test'], test_not_seen) 41 | 42 | pickle.dump(set2conditions, open(dataset + '_set2conditions.pkl', 'wb')) 43 | split_dict_path = dataset + '_set2conditions.pkl' 44 | 45 | pert_data.prepare_split(split='custom', seed=1, split_dict_path=split_dict_path) 46 | pert_data.get_dataloader(batch_size=32, test_batch_size=128) 47 | 48 | # set up and train a model 49 | gears_model = GEARS(pert_data, device='cuda:1') 50 | gears_model.model_initialize(hidden_size=64) 51 | gears_model.train(epochs=20) 52 | 53 | # save/load model 54 | gears_model.save_model('gears_' + dataset) 55 | -------------------------------------------------------------------------------- /examples/single_pred/admet/README.md: -------------------------------------------------------------------------------- 1 | ## ADMET Benchmark Group DeepPurpose Baselines 2 | 3 | Paper: [https://doi.org/10.1093/bioinformatics/btaa1005](https://doi.org/10.1093/bioinformatics/btaa1005) 4 | 5 | GitHub: [https://github.com/kexinhuang12345/DeepPurpose](https://github.com/kexinhuang12345/DeepPurpose) 6 | 7 | In this directory, we show how to use DeepPurpose to build three models for ADMET predictions. It is roughly around 50 lines of codes for the entire 22 benchmarks with 5 random seeds in ADMET benchmark group. 8 | 9 | 10 | ## Installation 11 | 12 | ```bash 13 | conda create -n DeepPurpose python=3.6 14 | conda activate DeepPurpose 15 | conda install -c conda-forge rdkit 16 | pip install git+https://github.com/bp-kelley/descriptastorus 17 | pip install DeepPurpose 18 | pip install PyTDC 19 | ``` 20 | 21 | For build from source installation, checkout the [installation page](https://github.com/kexinhuang12345/DeepPurpose#install--usage) of DeepPurpose. 
22 | 23 | ## Reproduce Results 24 | 25 | ```python 26 | python run.py --model Morgan 27 | ``` 28 | 29 | You can select the following model in the --model parameter: 'Morgan', 'RDKit2D', 'CNN', 'NeuralFP', 'MPNN', 'AttentiveFP', 'AttrMasking', 'ContextPred' 30 | 31 | ## Sample Output 32 | 33 | ```python 34 | python run.py --model RDKit2D 35 | ''' 36 | {'caco2_wang': [0.393, 0.024], 37 | 'hia_hou': [0.972, 0.008], 38 | 'pgp_broccatelli': [0.918, 0.007], 39 | 'bioavailability_ma': [0.672, 0.021], 40 | 'lipophilicity_astrazeneca': [0.574, 0.017], 41 | 'solubility_aqsoldb': [0.827, 0.047], 42 | 'bbb_martins': [0.889, 0.016], 43 | 'ppbr_az': [9.994, 0.319], 44 | 'vdss_lombardo': [0.561, 0.025], 45 | 'cyp2d6_veith': [0.616, 0.007], 46 | 'cyp3a4_veith': [0.829, 0.007], 47 | 'cyp2c9_veith': [0.742, 0.006], 48 | 'cyp2d6_substrate_carbonmangels': [0.677, 0.047], 49 | 'cyp3a4_substrate_carbonmangels': [0.639, 0.012], 50 | 'cyp2c9_substrate_carbonmangels': [0.36, 0.04], 51 | 'half_life_obach': [0.184, 0.111], 52 | 'clearance_microsome_az': [0.586, 0.014], 53 | 'clearance_hepatocyte_az': [0.382, 0.007], 54 | 'herg': [0.841, 0.02], 55 | 'ames': [0.823, 0.011], 56 | 'dili': [0.875, 0.019], 57 | 'ld50_zhu': [0.678, 0.003]} 58 | ''' 59 | ``` 60 | ## Contact 61 | 62 | Please contact [Kexin](mailto:kexinhuang@hsph.harvard.edu) if you have any question! 
63 | -------------------------------------------------------------------------------- /fig/TDCneurips.pptx(1).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/fig/TDCneurips.pptx(1).png -------------------------------------------------------------------------------- /fig/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/fig/logo.png -------------------------------------------------------------------------------- /fig/tdc_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/fig/tdc_overview.png -------------------------------------------------------------------------------- /fig/tdc_problems.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/fig/tdc_problems.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.33.0 2 | dataclasses>=0.6,<1.0 3 | datasets<2.20.0 4 | evaluate==0.4.2 5 | fuzzywuzzy>=0.18.0,<1.0 6 | huggingface_hub>=0.20.3,<1.0 7 | numpy>=1.26.4,<2.0.0 8 | openpyxl>=3.0.10,<4.0.0 9 | pandas>=2.1.4,<3.0.0 10 | requests>=2.31.0,<3.0.0 11 | scikit-learn>=1.2.2 12 | seaborn>=0.12.2,<1.0.0 13 | tqdm>=4.65.0,<5.0.0 14 | transformers>=4.43.0,<4.51.0 15 | cellxgene-census==1.15.0 16 | gget>=0.28.4,<1.0.0 17 | pydantic>=2.6.3,<3.0.0 18 | rdkit>=2023.9.5,<2024.3.1 19 | tiledbsoma>=1.7.2,<2.0.0 20 | -------------------------------------------------------------------------------- /run_tests.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | import sys 3 | 4 | if __name__ == '__main__': 5 | loader = unittest.TestLoader() 6 | start_dir = 'tdc/test' 7 | 8 | # Check if a specific test is provided as a command-line argument 9 | if len(sys.argv) > 1: 10 | test_name = sys.argv[1] 11 | suite = loader.loadTestsFromName(test_name) 12 | else: 13 | suite = loader.discover(start_dir) 14 | 15 | runner = unittest.TextTestRunner() 16 | res = runner.run(suite) 17 | if res.wasSuccessful(): 18 | print("All base tests passed") 19 | else: 20 | raise RuntimeError("Some base tests failed") 21 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | # read the contents of README file 4 | from os import path 5 | from io import open # for Python 2 and 3 compatibility 6 | 7 | # get __version__ from _version.py 8 | ver_file = path.join("tdc", "version.py") 9 | with open(ver_file) as f: 10 | exec(f.read()) 11 | 12 | this_directory = path.abspath(path.dirname(__file__)) 13 | 14 | 15 | # read the contents of README.md 16 | def readme(): 17 | with open(path.join(this_directory, "README.md"), encoding="utf-8") as f: 18 | return f.read() 19 | 20 | 21 | # read the contents of requirements.txt 22 | with open(path.join(this_directory, "requirements.txt"), encoding="utf-8") as f: 23 | requirements = f.read().splitlines() 24 | 25 | setup( 26 | name="pytdc", 27 | version=__version__, 28 | license="MIT", 29 | description="Therapeutics Commons", 30 | long_description=readme(), 31 | long_description_content_type="text/markdown", 32 | 
url="https://github.com/mims-harvard/TDC", 33 | author="PyTDC Team", 34 | author_email="amva13@alum.mit.edu", 35 | packages=find_packages(exclude=["test"]), 36 | zip_safe=False, 37 | include_package_data=True, 38 | install_requires=requirements, 39 | setup_requires=["setuptools>=38.6.0"], 40 | ) 41 | -------------------------------------------------------------------------------- /tdc/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluator import Evaluator 2 | from .oracles import Oracle 3 | from .benchmark_deprecated import BenchmarkGroup 4 | from .model_server.tdc_hf import tdc_hf_interface 5 | from tdc.utils.knowledge_graph import KnowledgeGraph 6 | -------------------------------------------------------------------------------- /tdc/benchmark_group/__init__.py: -------------------------------------------------------------------------------- 1 | from .admet_group import admet_group 2 | from .drugcombo_group import drugcombo_group 3 | from .dti_dg_group import dti_dg_group 4 | from .docking_group import docking_group 5 | -------------------------------------------------------------------------------- /tdc/benchmark_group/admet_group.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | from .base_group import BenchmarkGroup 6 | 7 | 8 | class admet_group(BenchmarkGroup): 9 | """Create ADMET Group Class object. 10 | 11 | Args: 12 | path (str, optional): the path to store/retrieve the ADMET group datasets. 
13 | """ 14 | 15 | def __init__(self, path="./data"): 16 | """Create an ADMET benchmark group class.""" 17 | super().__init__(name="ADMET_Group", path=path, file_format="csv") 18 | -------------------------------------------------------------------------------- /tdc/benchmark_group/drugcombo_group.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | from .base_group import BenchmarkGroup 6 | 7 | 8 | class drugcombo_group(BenchmarkGroup): 9 | """create a drug combination benchmark group 10 | 11 | Args: 12 | path (str, optional): path to save/load benchmarks 13 | """ 14 | 15 | def __init__(self, path="./data"): 16 | """create a drug combination benchmark group""" 17 | super().__init__(name="DrugCombo_Group", path=path, file_format="pkl") 18 | 19 | def get_cell_line_meta_data(self): 20 | import os 21 | from ..utils.load import download_wrapper 22 | from ..utils import load_dict 23 | name = download_wrapper('drug_comb_meta_data', self.path, 24 | ['drug_comb_meta_data']) 25 | return load_dict(os.path.join(self.path, name + '.pkl')) 26 | -------------------------------------------------------------------------------- /tdc/benchmark_group/dti_dg_group.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | from .base_group import BenchmarkGroup 6 | 7 | 8 | class dti_dg_group(BenchmarkGroup): 9 | """Create a DTI domain generalization benchmark group 10 | 11 | Args: 12 | path (str, optional): path to save/load benchmarks 13 | """ 14 | 15 | def __init__(self, path="./data"): 16 | """Create a DTI domain generalization benchmark group""" 17 | super().__init__(name="DTI_DG_Group", path=path, file_format="csv") 18 | -------------------------------------------------------------------------------- /tdc/chem_utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .evaluator import ( 2 | validity, 3 | uniqueness, 4 | novelty, 5 | diversity, 6 | kl_divergence, 7 | fcd_distance, 8 | ) 9 | from .featurize.molconvert import MolConvert 10 | from .oracle.oracle import ( 11 | PyScreener_meta, 12 | Vina_3d, 13 | Score_3d, 14 | Vina_smiles, 15 | molecule_one_retro, 16 | ibm_rxn, 17 | askcos, 18 | isomers_c7h8n2o2, 19 | isomers_c9h10n2o2pf2cl, 20 | isomers_c11h24, 21 | valsartan_smarts, 22 | scaffold_hop, 23 | deco_hop, 24 | sitagliptin_mpo, 25 | zaleplon_mpo, 26 | amlodipine_mpo, 27 | sitagliptin_mpo_prev, 28 | zaleplon_mpo_prev, 29 | perindopril_mpo, 30 | ranolazine_mpo, 31 | fexofenadine_mpo, 32 | osimertinib_mpo, 33 | median1, 34 | median2, 35 | aripiprazole_similarity, 36 | albuterol_similarity, 37 | mestranol_similarity, 38 | celecoxib_rediscovery, 39 | troglitazone_rediscovery, 40 | thiothixene_rediscovery, 41 | median_meta, 42 | isomer_meta, 43 | rediscovery_meta, 44 | similarity_meta, 45 | jnk3, 46 | gsk3b, 47 | SA, 48 | cyp3a4_veith, 49 | drd2, 50 | qed, 51 | penalized_logp, 52 | ) 53 | from .oracle.filter import MolFilter 54 | -------------------------------------------------------------------------------- /tdc/chem_utils/featurize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tdc/chem_utils/featurize/__init__.py -------------------------------------------------------------------------------- /tdc/chem_utils/oracle/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tdc/chem_utils/oracle/__init__.py -------------------------------------------------------------------------------- /tdc/chem_utils/oracle/docking.py: 
import sys
from vina import Vina
from time import time

# Command-line arguments: ligand .pdbqt path, receptor .pdbqt path, output
# file path, then three floats for the docking-box center (x, y, z) and three
# for the box size (x, y, z).
ligand_pdbqt_file = sys.argv[1]
receptor_pdbqt_file = sys.argv[2]
output_file = sys.argv[3]
center = [sys.argv[4], sys.argv[5], sys.argv[6]]
center = [float(i) for i in center]
box_size = [sys.argv[7], sys.argv[8], sys.argv[9]]
box_size = [float(i) for i in box_size]


# print(ligand_pdbqt_file, receptor_pdbqt_file, output_file, center, box_size)
def docking(ligand_pdbqt_file, receptor_pdbqt_file, output_file, center,
            box_size):
    """Score a ligand against a rigid receptor with AutoDock Vina and write
    the first optimized-energy term to ``output_file``.

    Args:
        ligand_pdbqt_file (str): path to the ligand .pdbqt file.
        receptor_pdbqt_file (str): path to the rigid receptor .pdbqt file.
        output_file (str): file the minimized energy value is written to.
        center (list[float]): x/y/z center of the docking box.
        box_size (list[float]): x/y/z dimensions of the docking box.
    """
    t1 = time()
    v = Vina(sf_name="vina")
    v.set_receptor(rigid_pdbqt_filename=receptor_pdbqt_file)
    v.set_ligand_from_file(ligand_pdbqt_file)
    v.compute_vina_maps(center=center, box_size=box_size)
    # NOTE(review): `energy` is computed but never used — presumably called
    # for Vina's internal state before optimize(); confirm it can be removed.
    energy = v.score()
    energy_minimized = v.optimize()
    t2 = time()
    # elapsed wall-clock time, truncated to 5 characters
    print("vina takes seconds: ", str(t2 - t1)[:5])
    with open(output_file, "w") as fout:
        # optimize() returns a sequence; index 0 is the total minimized energy
        fout.write(str(energy_minimized[0]))


docking(ligand_pdbqt_file, receptor_pdbqt_file, output_file, center, box_size)
"""
Example:
python XXXX.py data/1iep_ligand.pdbqt ./data/1iep_receptor.pdbqt ./data/out 15.190 53.903 16.917 20 20 20
"""
class CellXGeneConfig(DatasetConfig):
    """Configuration for CellXGene census datasets.

    Runs two processing steps on the loaded SOMA dataframe: dropping
    zero-expression rows, then renaming the columns to the expected
    cell_idx / gene_idx / expression layout.
    """

    def __init__(self):
        super(CellXGeneConfig, self).__init__(
            data_processing_class=CellXGeneFeatureGenerator,
            functions_to_run=[
                "get_dense_soma_dataframe", "format_cellxgene_dataframe"
            ],
            # BUG FIX: one kwargs dict per entry of functions_to_run. The
            # original passed three dicts for two functions, leaving the
            # lists misaligned.
            args_for_functions=[{}, {}],
        )
class ConfigMap(dict):
    """Dict mapping a dataset's string name to its DatasetConfig class."""

    def __init__(self):
        # Insertion order is preserved deliberately: brown, the scPerturb
        # drug datasets, the scPerturb gene datasets, then opentargets.
        self["brown_mdm2_ace2_12ca5"] = BrownProteinPeptideConfig
        self.update(dict.fromkeys(scperturb_datasets, SCPerturb))
        self.update(dict.fromkeys(scperturb_gene_datasets, SCPerturb_Gene))
        self["opentargets_dti"] = OpentargetsDTI
class AnnDataToDataFrame(DataFeatureGenerator):
    """Customizations for transforming AnnData objects into pandas DataFrames."""

    @classmethod
    def anndata_to_df(cls, dataset=None, obs_cols=None):
        """Convert an AnnData object into a DataFrame of expression values,
        optionally merged with (some of) its ``obs`` annotation columns.

        Args:
            dataset: the AnnData object to convert. Required.
            obs_cols: None to return expression only; "ALL" to left-merge every
                obs column; or a list of obs column names to merge.

        Returns:
            pandas.DataFrame indexed by obs_names with var_names as columns.

        Raises:
            ValueError: if dataset is None or obs_cols has an invalid type.
        """
        if dataset is None:
            raise ValueError("dataset must be specified")
        adata = dataset
        # Densify sparse matrices so DataFrame construction works.
        if not isinstance(adata.X, np.ndarray):
            adata.X = adata.X.todense()
        # BUG FIX: the original wrote `adata.X if adata.X is not None else
        # adata.X`, a conditional whose branches are identical — simplified
        # to the expression itself.
        df_main = pd.DataFrame(adata.X,
                               columns=adata.var_names,
                               index=adata.obs_names)
        dfobs = pd.DataFrame(adata.obs,
                             columns=adata.obs.keys(),
                             index=adata.obs.index)
        if obs_cols is None:
            return df_main
        elif obs_cols == "ALL":
            return df_main.merge(dfobs,
                                 left_index=True,
                                 right_index=True,
                                 how='left')
        elif isinstance(obs_cols, list):
            return df_main.merge(dfobs[obs_cols],
                                 left_index=True,
                                 right_index=True,
                                 how='left')
        else:
            raise ValueError("obs_cols must be a list of column names or 'ALL'")
class CellXGeneFeatureGenerator(DataFeatureGenerator):
    """Helpers for shaping CellXGene SOMA query results."""

    @classmethod
    def get_dense_soma_dataframe(cls, dataset):
        """Return only the rows whose 'soma_data' value is non-zero."""
        nonzero_mask = dataset["soma_data"] != 0
        return dataset[nonzero_mask]

    @classmethod
    def format_cellxgene_dataframe(cls, dataset):
        """Rename the three columns to cell_idx / gene_idx / expression, in place."""
        dataset.columns = ["cell_idx", "gene_idx", "expression"]
        return dataset
class MolGen(generation_dataset.DataLoader):
    """Data loader for the molecular generation task (distribution learning).

    The goal is to generate diverse, novel molecules with desirable chemical
    properties; can be combined with oracle functions.

    Args:
        name (str): the name of the dataset.
        path (str, optional): directory where the saved data file lives.
        print_stats (bool, optional): whether to print basic statistics.
        column_name (str, optional): column holding the molecular data.
    """

    def __init__(self, name, path="./data", print_stats=False,
                 column_name="smiles"):
        """Instantiate the molecular generation data loader."""
        super().__init__(name, path, print_stats, column_name)
class scVILoader():
    """Downloads pretrained scVI model weights for a given census version,
    falling back to the 2024-07-01 release when the requested version is
    unavailable."""

    def __init__(self):
        pass

    def load(self, census_version):
        """Download ``model.pt`` for *census_version* into ./scvi_model.

        Args:
            census_version (str): the census release, e.g. "2024-07-01".

        Raises:
            requests.HTTPError: if even the fallback download fails.
        """
        import requests
        import os

        url_template = ("https://cellxgene-contrib-public.s3.us-west-2"
                        ".amazonaws.com/models/scvi/{}/homo_sapiens/model.pt")
        scvi_url = url_template.format(census_version)
        os.makedirs(os.path.join(os.getcwd(), 'scvi_model'), exist_ok=True)

        output_path = os.path.join('scvi_model', 'model.pt')

        try:
            # NOTE(review): verify=False disables TLS certificate validation;
            # kept from the original, but confirm whether it can be removed.
            response = requests.get(scvi_url, verify=False)
            if response.status_code == 404:
                raise Exception(
                    'Census version not found, defaulting to version 2024-07-01'
                )
            # BUG FIX: surface non-404 HTTP errors (e.g. 500) too, instead of
            # silently writing an error payload to model.pt.
            response.raise_for_status()
        except Exception as e:
            print(e)
            census_version = "2024-07-01"
            scvi_url = url_template.format(census_version)
            response = requests.get(scvi_url, verify=False)
            # BUG FIX: fail loudly if the fallback download also fails rather
            # than saving an error page as the model checkpoint.
            response.raise_for_status()

        with open(output_path, "wb") as file:
            file.write(response.content)

        print(
            f'scVI version {census_version} downloaded to {output_path} in current directory'
        )
class scGPTTokenizer:
    """Thin class-style interface over the module-level scGPT tokenizer."""

    def __init__(self):
        pass

    @classmethod
    def tokenize_cell_vectors(cls, data, gene_names):
        """Tokenize single-cell gene expression vectors (AnnData-style
        matrices), delegating to :func:`tokenize_batch`."""
        return tokenize_batch(data, gene_names)
class DataLoader(DL):
    """Multi-pred data loader that keeps the raw AnnData object and, unless
    ``no_convert`` is set, converts it to a DataFrame via a dataset-specific
    config (falling back to a plain AnnData -> DataFrame dump)."""

    def __init__(self,
                 name,
                 path,
                 print_stats=False,
                 dataset_names=None,
                 no_convert=True):
        super(DataLoader, self).__init__(name, path, print_stats, dataset_names)
        self.adata = self.df  # the loaded self.df is in AnnData format
        if no_convert:
            # caller wants the raw AnnData; leave self.df as-is
            return
        self.cmap = ConfigMap()
        self.config = self.cmap.get(name)
        if self.config is None:
            # no registered config: convert adata to a dataframe as-is
            self.df = AnnDataToDataFrame.anndata_to_df(self.adata)
        else:
            self.df = self.config().processing_callback(self.adata)
Given the amino acid sequence of antibody and antigen, predict their binding affinity. 20 | 21 | 22 | Args: 23 | name (str): the dataset name. 24 | path (str, optional): 25 | The path to save the data file, defaults to './data' 26 | label_name (str, optional): 27 | For multi-label dataset, specify the label name, defaults to None 28 | print_stats (bool, optional): 29 | Whether to print basic statistics of the dataset, defaults to False 30 | 31 | 32 | """ 33 | 34 | def __init__(self, name, path="./data", label_name=None, print_stats=False): 35 | """Create Antibody-antigen Affinity dataloader object""" 36 | super().__init__( 37 | name, 38 | path, 39 | label_name, 40 | print_stats, 41 | dataset_names=dataset_names["AntibodyAff"], 42 | ) 43 | self.entity1_name = "Antibody" 44 | self.entity2_name = "Antigen" 45 | self.two_types = True 46 | 47 | if print_stats: 48 | self.print_stats() 49 | 50 | print("Done!", flush=True, file=sys.stderr) 51 | -------------------------------------------------------------------------------- /tdc/multi_pred/catalyst.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings("ignore") 8 | import sys 9 | 10 | from ..utils import print_sys 11 | from . import bi_pred_dataset, multi_pred_dataset 12 | from ..metadata import dataset_names 13 | 14 | 15 | class Catalyst(bi_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in Catalyst Prediction task 17 | More info: https://tdcommons.ai/multi_pred_tasks/catalyst/ 18 | 19 | Task Description: Given reactant and product set X, predict the catalyst Y from a set of most common catalysts. 20 | 21 | 22 | Args: 23 | name (str): the dataset name. 
class DDI(bi_pred_dataset.DataLoader):
    """Data loader class to load datasets in Drug-Drug Interaction Prediction task
    More info: https://tdcommons.ai/multi_pred_tasks/ddi/

    Task Description: Multi-class classification. Given the SMILES strings of two drugs, predict their interaction type.

    Args:
        name (str): the dataset name.
        path (str, optional):
            The path to save the data file, defaults to './data'
        label_name (str, optional):
            For multi-label dataset, specify the label name, defaults to None
        print_stats (bool, optional):
            Whether to print basic statistics of the dataset, defaults to False

    """

    def __init__(self, name, path="./data", label_name=None, print_stats=False):
        """Create Drug-Drug Interaction (DDI) Prediction dataloader object"""
        super().__init__(name,
                         path,
                         label_name,
                         print_stats,
                         dataset_names=dataset_names["DDI"])
        self.entity1_name = "Drug1"
        self.entity2_name = "Drug2"
        # both entities are drugs, so they share one entity type
        self.two_types = False

        if print_stats:
            self.print_stats()

        print("Done!", flush=True, file=sys.stderr)

    def print_stats(self):
        """Print dataset statistics: unique drug count and number of pairs."""
        print_sys("--- Dataset Statistics ---")
        # IDIOM/PERF: a set union counts unique drugs directly, without
        # concatenating two lists and sorting them through np.unique.
        n_unique = len(set(self.entity1.tolist()) | set(self.entity2.tolist()))
        print(
            "There are " + str(n_unique) + " unique drugs.",
            flush=True,
            file=sys.stderr,
        )
        print(
            "There are " + str(len(self.y)) + " drug-drug pairs.",
            flush=True,
            file=sys.stderr,
        )
        print_sys("--------------------------")
18 | More info: https://tdcommons.ai/multi_pred_tasks/drugres/ 19 | 20 | Task Description: Regression. Given the gene expression of cell lines and the SMILES of drug, predict the drug sensitivity level. 21 | 22 | Args: 23 | name (str): the dataset name. 24 | path (str, optional): 25 | The path to save the data file, defaults to './data' 26 | label_name (str, optional): 27 | For multi-label dataset, specify the label name, defaults to None 28 | print_stats (bool, optional): 29 | Whether to print basic statistics of the dataset, defaults to False 30 | 31 | """ 32 | 33 | def __init__(self, name, path="./data", label_name=None, print_stats=False): 34 | """Create Drug Response Prediction dataloader object""" 35 | super().__init__(name, 36 | path, 37 | label_name, 38 | print_stats, 39 | dataset_names=dataset_names["DrugRes"]) 40 | self.entity1_name = "Drug" 41 | self.entity2_name = "Cell Line" 42 | self.two_types = True 43 | self.path = path 44 | 45 | if print_stats: 46 | self.print_stats() 47 | 48 | print("Done!", flush=True, file=sys.stderr) 49 | 50 | def get_gene_symbols(self): 51 | """ 52 | Retrieve the gene symbols for the cell line gene expression 53 | """ 54 | path = self.path 55 | name = download_wrapper("gdsc_gene_symbols", path, 56 | ["gdsc_gene_symbols"]) 57 | print_sys("Loading...") 58 | import pandas as pd 59 | import os 60 | 61 | df = pd.read_csv(os.path.join(path, name + ".tab"), sep="\t") 62 | return df.values.reshape(-1,) 63 | -------------------------------------------------------------------------------- /tdc/multi_pred/drugsyn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings("ignore") 8 | import sys 9 | 10 | from ..utils import print_sys 11 | from . 
import bi_pred_dataset, multi_pred_dataset 12 | from ..metadata import dataset_names 13 | 14 | 15 | class DrugSyn(multi_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in Drug Synergy Prediction task. 17 | More info: https://tdcommons.ai/multi_pred_tasks/drugsyn/ 18 | 19 | Task Description: Regression. 20 | Given the gene expression of cell lines and two SMILES strings of the drug combos, 21 | predict the drug synergy level. 22 | 23 | Args: 24 | name (str): the dataset name. 25 | path (str, optional): 26 | The path to save the data file, defaults to './data' 27 | print_stats (bool, optional): 28 | Whether to print basic statistics of the dataset, defaults to False 29 | 30 | """ 31 | 32 | def __init__(self, name, path="./data", print_stats=False): 33 | """Create Drug Synergy Prediction dataloader object""" 34 | super().__init__(name, 35 | path, 36 | print_stats, 37 | dataset_names=dataset_names["DrugSyn"]) 38 | 39 | if print_stats: 40 | self.print_stats() 41 | 42 | print("Done!", flush=True, file=sys.stderr) 43 | -------------------------------------------------------------------------------- /tdc/multi_pred/gda.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings("ignore") 8 | import sys 9 | 10 | from ..utils import print_sys 11 | from . import bi_pred_dataset, multi_pred_dataset 12 | from ..metadata import dataset_names 13 | 14 | 15 | class GDA(bi_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in Gene-Disease Association Prediction task. 17 | More info: https://tdcommons.ai/multi_pred_tasks/gdi/ 18 | 19 | 20 | Task Description: Regression. 21 | Given the disease description and the amino acid sequence of the gene, predict their association. 22 | 23 | Args: 24 | name (str): the dataset name. 
25 | path (str, optional): 26 | The path to save the data file, defaults to './data' 27 | label_name (str, optional): 28 | For multi-label dataset, specify the label name, defaults to None 29 | print_stats (bool, optional): 30 | Whether to print basic statistics of the dataset, defaults to False 31 | 32 | 33 | """ 34 | 35 | def __init__(self, name, path="./data", label_name=None, print_stats=False): 36 | """Create Gene-Disease Association Prediction dataloader object""" 37 | super().__init__(name, 38 | path, 39 | label_name, 40 | print_stats, 41 | dataset_names=dataset_names["GDA"]) 42 | self.entity1_name = "Gene" 43 | self.entity2_name = "Disease" 44 | self.two_types = True 45 | 46 | if print_stats: 47 | self.print_stats() 48 | 49 | print("Done!", flush=True, file=sys.stderr) 50 | -------------------------------------------------------------------------------- /tdc/multi_pred/mti.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings("ignore") 8 | import sys 9 | 10 | from ..utils import print_sys 11 | from . import bi_pred_dataset, multi_pred_dataset 12 | from ..metadata import dataset_names 13 | 14 | 15 | class MTI(bi_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in MicroRNA-Target Interaction Prediction task. 17 | More info: https://tdcommons.ai/multi_pred_tasks/mti/ 18 | 19 | 20 | Task Description: Binary Classification. 21 | Given the miRNA mature sequence and target amino acid sequence, 22 | predict their likelihood of interaction. 23 | 24 | Args: 25 | name (str): the dataset name. 
26 | path (str, optional): 27 | The path to save the data file, defaults to './data' 28 | label_name (str, optional): 29 | For multi-label dataset, specify the label name, defaults to None 30 | print_stats (bool, optional): 31 | Whether to print basic statistics of the dataset, defaults to False 32 | 33 | """ 34 | 35 | def __init__(self, name, path="./data", label_name=None, print_stats=False): 36 | """Create MicroRNA-Target Interaction Prediction dataloader object""" 37 | super().__init__(name, 38 | path, 39 | label_name, 40 | print_stats, 41 | dataset_names=dataset_names["MTI"]) 42 | self.entity1_name = "miRNA" 43 | self.entity2_name = "Target" 44 | self.two_types = True 45 | 46 | if print_stats: 47 | self.print_stats() 48 | 49 | print("Done!", flush=True, file=sys.stderr) 50 | -------------------------------------------------------------------------------- /tdc/multi_pred/peptidemhc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings("ignore") 8 | import sys 9 | 10 | from ..utils import print_sys 11 | from . import bi_pred_dataset, multi_pred_dataset 12 | from ..metadata import dataset_names 13 | 14 | 15 | class PeptideMHC(bi_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in Peptide-MHC Binding Prediction task. 17 | More info: https://tdcommons.ai/multi_pred_tasks/peptidemhc/ 18 | 19 | Task Description: Regression. 20 | Given the amino acid sequence of peptide and the pseudo amino acid sequence of MHC, 21 | predict the binding affinity. 22 | 23 | Args: 24 | name (str): the dataset name. 
25 | path (str, optional): 26 | The path to save the data file, defaults to './data' 27 | label_name (str, optional): 28 | For multi-label dataset, specify the label name, defaults to None 29 | print_stats (bool, optional): 30 | Whether to print basic statistics of the dataset, defaults to False 31 | 32 | 33 | """ 34 | 35 | def __init__(self, name, path="./data", label_name=None, print_stats=False): 36 | """Create Peptide-MHC Prediction dataloader object""" 37 | super().__init__( 38 | name, 39 | path, 40 | label_name, 41 | print_stats, 42 | dataset_names=dataset_names["PeptideMHC"], 43 | ) 44 | self.entity1_name = "Peptide" 45 | self.entity2_name = "MHC" 46 | self.two_types = True 47 | 48 | if print_stats: 49 | self.print_stats() 50 | 51 | print("Done!", flush=True, file=sys.stderr) 52 | -------------------------------------------------------------------------------- /tdc/multi_pred/ppi.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings("ignore") 8 | import sys 9 | 10 | from ..utils import print_sys 11 | from . import bi_pred_dataset, multi_pred_dataset 12 | from ..metadata import dataset_names 13 | 14 | 15 | class PPI(bi_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in Protein-Protein Interaction Prediction task. 17 | More info: https://tdcommons.ai/multi_pred_tasks/ppi/ 18 | 19 | Task Description: Binary Classification. Given the target amino acid sequence pairs, predict if they interact or not. 20 | 21 | 22 | Args: 23 | name (str): the dataset name. 
24 | path (str, optional): 25 | The path to save the data file, defaults to './data' 26 | label_name (str, optional): 27 | For multi-label dataset, specify the label name, defaults to None 28 | print_stats (bool, optional): 29 | Whether to print basic statistics of the dataset, defaults to False 30 | 31 | """ 32 | 33 | def __init__(self, name, path="./data", label_name=None, print_stats=False): 34 | """Create Protein-Protein Interaction Prediction dataloader object""" 35 | super().__init__(name, 36 | path, 37 | label_name, 38 | print_stats, 39 | dataset_names=dataset_names["PPI"]) 40 | self.entity1_name = "Protein1" 41 | self.entity2_name = "Protein2" 42 | self.two_types = False 43 | 44 | if print_stats: 45 | self.print_stats() 46 | 47 | print("Done!", flush=True, file=sys.stderr) 48 | 49 | def print_stats(self): 50 | """print the statistics of the dataset""" 51 | import numpy as np 52 | 53 | print_sys("--- Dataset Statistics ---") 54 | print( 55 | "There are " + 56 | str(len(np.unique(self.entity1.tolist() + self.entity2.tolist()))) + 57 | " unique proteins.", 58 | flush=True, 59 | file=sys.stderr, 60 | ) 61 | print( 62 | "There are " + str(len(self.y)) + " protein-protein pairs.", 63 | flush=True, 64 | file=sys.stderr, 65 | ) 66 | print_sys("--------------------------") 67 | -------------------------------------------------------------------------------- /tdc/multi_pred/proteinpeptide.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings("ignore") 8 | import sys 9 | 10 | from ..utils import print_sys 11 | from . import bi_pred_dataset 12 | from ..metadata import dataset_names 13 | from ..dataset_configs.config_map import ConfigMap 14 | 15 | 16 | class ProteinPeptide(bi_pred_dataset.DataLoader): 17 | """Data loader class to load datasets in Protein-Peptide Binding Prediction task. 
18 | More info: TODO 19 | 20 | Task Description: Regression. 21 | Given the amino acid sequence of peptide and (TODO: complete), 22 | predict the binding affinity. 23 | 24 | Args: 25 | name (str): the dataset name. 26 | path (str, optional): 27 | The path to save the data file, defaults to './data' 28 | label_name (str, optional): 29 | For multi-label dataset, specify the label name, defaults to None 30 | print_stats (bool, optional): 31 | Whether to print basic statistics of the dataset, defaults to False 32 | 33 | 34 | """ 35 | 36 | def __init__(self, name, path="./data", label_name=None, print_stats=False): 37 | """Create Protein-Peptide Prediction dataloader object""" 38 | label_name = label_name if label_name is not None else "KD (nm)" # TODO: this column should be parsed into float and upper/lower 39 | cfm = ConfigMap() 40 | config = cfm.get(name) 41 | super().__init__( 42 | name, 43 | path, 44 | label_name, 45 | print_stats, 46 | dataset_names=dataset_names["ProteinPeptide"], 47 | data_config=config(), 48 | ) 49 | self.entity1_name = "Sequence" # peptide sequence 50 | self.entity2_name = "Protein Target" # protein target label 51 | self.two_types = True 52 | 53 | if print_stats: 54 | self.print_stats() 55 | 56 | print("Done!", flush=True, file=sys.stderr) 57 | -------------------------------------------------------------------------------- /tdc/multi_pred/tcr_epi.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings("ignore") 8 | import sys 9 | 10 | from ..utils import print_sys 11 | from ..utils.load import download_wrapper, pd_load 12 | from . import bi_pred_dataset, multi_pred_dataset 13 | from ..metadata import dataset_names 14 | 15 | 16 | class TCREpitopeBinding(multi_pred_dataset.DataLoader): 17 | """Data loader class to load datasets in T cell receptor (TCR) Specificity Prediction Task. 
18 | More info: 19 | 20 | Task Description: Given the TCR and epitope sequence, predict binding probability. 21 | 22 | Args: 23 | name (str): the dataset name. 24 | path (str, optional): 25 | The path to save the data file, defaults to './data' 26 | print_stats (bool, optional): 27 | Whether to print basic statistics of the dataset, defaults to False 28 | 29 | """ 30 | 31 | def __init__(self, name, path="./data", print_stats=False): 32 | """Create TCR Specificity Prediction dataloader object""" 33 | super().__init__(name, 34 | path, 35 | print_stats, 36 | dataset_names=dataset_names["TCREpitopeBinding"]) 37 | self.entity1_name = "TCR" 38 | self.entity2_name = "Epitope" 39 | 40 | if print_stats: 41 | self.print_stats() 42 | 43 | print("Done!", flush=True, file=sys.stderr) 44 | -------------------------------------------------------------------------------- /tdc/multi_pred/test_multi_pred.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings("ignore") 8 | import sys 9 | 10 | from ..utils import print_sys 11 | from . 
import bi_pred_dataset, multi_pred_dataset 12 | from ..metadata import dataset_names 13 | 14 | 15 | class TestMultiPred(bi_pred_dataset.DataLoader): 16 | """Data loader class to test the bi-instance (multi_pred) prediction data loader. 17 | 18 | Attributes: 19 | entity1_name (str): name of the first entity column ("Antibody") 20 | entity2_name (str): name of the second entity column ("Antigen") 21 | two_types (bool): True, since the two entities are of different types 22 | """ 23 | 24 | def __init__(self, name, path="./data", label_name=None, print_stats=False): 25 | """Create a testing-case dataloader object. 26 | 27 | Args: 28 | name (str): the dataset name 29 | path (str, optional): the path to save the data file, defaults to './data' 30 | label_name (str, optional): for multi-label dataset, specify the label name, defaults to None 31 | print_stats (bool, optional): whether to print basic statistics of the dataset, defaults to False 32 | """ 33 | super().__init__( 34 | name, 35 | path, 36 | label_name, 37 | print_stats, 38 | dataset_names=dataset_names["test_multi_pred"], 39 | ) 40 | self.entity1_name = "Antibody" 41 | self.entity2_name = "Antigen" 42 | self.two_types = True 43 | 44 | if print_stats: 45 | self.print_stats() 46 | 47 | print("Done!", flush=True, file=sys.stderr) 48 | -------------------------------------------------------------------------------- /tdc/multi_pred/trialoutcome.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings("ignore") 8 | import sys 9 | 10 | from ..utils import print_sys 11 | from . import bi_pred_dataset, multi_pred_dataset 12 | from ..metadata import dataset_names 13 | 14 | 15 | class TrialOutcome(multi_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in clinical trial outcome Prediction task. 17 | More info: https://tdcommons.ai/multi_pred_tasks/trialoutcome/ 18 | 19 | 20 | Task Description: Binary Classification. 21 | Given the drug molecule, disease code (ICD) and trial protocol (eligibility criteria), 22 | predict their trial approval rate. 23 | 24 | Args: 25 | name (str): the dataset name.
26 | path (str, optional): 27 | The path to save the data file, defaults to './data' 28 | label_name (str, optional): 29 | NOTE(review): accepted for signature consistency but not forwarded to the base loader below — confirm intended 30 | print_stats (bool, optional): 31 | Whether to print basic statistics of the dataset, defaults to False 32 | 33 | """ 34 | 35 | def __init__(self, name, path="./data", label_name=None, print_stats=False): 36 | """Create Clinical Trial Outcome Prediction dataloader object""" 37 | super().__init__(name,  # NOTE(review): label_name is not passed through here — confirm this is intended 38 | path, 39 | print_stats, 40 | dataset_names=dataset_names["TrialOutcome"]) 41 | self.entity1_name = "drug_molecule" 42 | self.entity2_name = "disease_code" 43 | # self.entity3_name = "eligibility_criteria" 44 | 45 | if print_stats: 46 | self.print_stats() 47 | 48 | print("Done!", flush=True, file=sys.stderr) 49 | -------------------------------------------------------------------------------- /tdc/resource/__init__.py: -------------------------------------------------------------------------------- 1 | from .primekg import PrimeKG 2 | from .cellxgene_census import CensusResource 3 | -------------------------------------------------------------------------------- /tdc/resource/pharmone.py: -------------------------------------------------------------------------------- 1 | from ..utils import general_load 2 | """ 3 | Resource class for the Eve Bio (https://evebio.org/) Pharmone Map.
4 | """  # NOTE(review): this string sits after an import, so it is not a real module docstring — consider moving it to line 1 5 | 6 | 7 | class PharmoneMap(object):  # Resource wrapper: each accessor loads one Pharmone Map TSV table via general_load 8 | 9 | def __init__(self, path="./data"):  # path: directory where downloaded tables are cached 10 | self.path = path 11 | 12 | def get_data(self): 13 | return general_load('evebio_pharmone_v1_detailed_result_table', 14 | self.path, '\t') # Load the Pharmone Map data 15 | 16 | def get_obs_metadata(self):  # Load metadata for the observed data points 17 | return general_load("evebio_pharmone_v1_observed_points_table", 18 | self.path, "\t") 19 | 20 | def get_control_data(self): 21 | return general_load("evebio_pharmone_v1_control_table", self.path, 22 | "\t") # Load the control data 23 | 24 | def get_compound_data(self):  # Load the compound annotation table 25 | return general_load("evebio_pharmone_v1_compound_table", self.path, 26 | "\t") 27 | 28 | def get_target_data(self):  # Load the target annotation table 29 | return general_load("evebio_pharmone_v1_target_table", self.path, "\t") 30 | -------------------------------------------------------------------------------- /tdc/resource/primekg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | """ 5 | This file contains a primekg dataloader. 6 | """ 7 | 8 | import numpy as np 9 | import warnings 10 | 11 | from ..utils import general_load 12 | from ..utils.knowledge_graph import KnowledgeGraph 13 | 14 | warnings.filterwarnings("ignore") 15 | 16 | 17 | class PrimeKG(KnowledgeGraph): 18 | """PrimeKG data loader class to load the knowledge graph with additional support functions.
19 | """ 20 | 21 | def __init__(self, path="./data"): 22 | """load the KG to the specified path""" 23 | df = general_load("primekg", path, ",") 24 | self.df = df 25 | self.path = path 26 | super().__init__(self.df) 27 | 28 | def get_data(self): 29 | return self.df 30 | 31 | def to_nx(self): 32 | import networkx as nx 33 | 34 | G = nx.Graph() 35 | for i in self.df.relation.unique(): 36 | G.add_edges_from( 37 | self.df[self.df.relation == i][["x_name", "y_name"]].values, 38 | relation=i) 39 | return G 40 | 41 | def get_features(self, feature_type): 42 | if feature_type not in ["drug", "disease"]: 43 | raise ValueError("feature_type only supports drug/disease!") 44 | return general_load("primekg_" + feature_type + "_feature", self.path, 45 | "\t") 46 | 47 | def get_node_list(self, node_type): 48 | df = self.df 49 | return np.unique(df[(df.x_type == node_type)].x_id.unique().tolist() + 50 | df[(df.y_type == node_type)].y_id.unique().tolist()) 51 | -------------------------------------------------------------------------------- /tdc/single_pred/__init__.py: -------------------------------------------------------------------------------- 1 | from .adme import ADME 2 | from .crispr_outcome import CRISPROutcome 3 | from .develop import Develop 4 | from .epitope import Epitope 5 | from .hts import HTS 6 | from .paratope import Paratope 7 | from .qm import QM 8 | from .test_single_pred import TestSinglePred 9 | from .tox import Tox 10 | from .yields import Yields 11 | -------------------------------------------------------------------------------- /tdc/single_pred/crispr_outcome.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import sys 6 | import warnings 7 | 8 | warnings.filterwarnings("ignore") 9 | 10 | from . 
import single_pred_dataset 11 | from ..utils import print_sys 12 | from ..metadata import dataset_names 13 | 14 | 15 | class CRISPROutcome(single_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in CRISPROutcome task. More info: https://tdcommons.ai/single_pred_tasks/CRISPROutcome/ 17 | 18 | Args: 19 | name (str): the dataset name. 20 | path (str, optional): 21 | The path to save the data file, defaults to './data' 22 | label_name (str, optional): 23 | For multi-label dataset, specify the label name, defaults to None 24 | print_stats (bool, optional): 25 | Whether to print basic statistics of the dataset, defaults to False 26 | convert_format (str, optional): 27 | Automatic conversion of SMILES to other molecular formats in MolConvert class. Stored as separate column in dataframe, defaults to None 28 | """ 29 | 30 | def __init__( 31 | self, 32 | name, 33 | path="./data", 34 | label_name=None, 35 | print_stats=False, 36 | convert_format=None, 37 | ): 38 | """Create CRISPROutcome dataloader object.""" 39 | super().__init__( 40 | name, 41 | path, 42 | label_name, 43 | print_stats, 44 | dataset_names=dataset_names["CRISPROutcome"], 45 | convert_format=convert_format, 46 | ) 47 | self.entity1_name = "GuideSeq" 48 | if print_stats: 49 | self.print_stats() 50 | print("Done!", flush=True, file=sys.stderr) 51 | -------------------------------------------------------------------------------- /tdc/single_pred/epitope.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import sys 6 | import warnings 7 | 8 | warnings.filterwarnings("ignore") 9 | 10 | from . import single_pred_dataset 11 | from ..utils import print_sys 12 | from ..metadata import dataset_names 13 | 14 | 15 | class Epitope(single_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in Epitope Prediction task. 
More info: https://tdcommons.ai/single_pred_tasks/epitope/ 17 | 18 | Args: 19 | name (str): the dataset name. 20 | path (str, optional): 21 | The path to save the data file, defaults to './data' 22 | label_name (str, optional): 23 | For multi-label dataset, specify the label name, defaults to None 24 | print_stats (bool, optional): 25 | Whether to print basic statistics of the dataset, defaults to False 26 | convert_format (str, optional): 27 | Automatic conversion of SMILES to other molecular formats in MolConvert class. Stored as separate column in dataframe, defaults to None 28 | """ 29 | 30 | def __init__( 31 | self, 32 | name, 33 | path="./data", 34 | label_name=None, 35 | print_stats=False, 36 | convert_format=None, 37 | ): 38 | """Create an Epitope prediction dataloader object.""" 39 | super().__init__( 40 | name, 41 | path, 42 | label_name, 43 | print_stats, 44 | dataset_names=dataset_names["Epitope"], 45 | convert_format=convert_format, 46 | ) 47 | self.entity1_name = "Antigen" 48 | if print_stats: 49 | self.print_stats() 50 | print("Done!", flush=True, file=sys.stderr) 51 | -------------------------------------------------------------------------------- /tdc/single_pred/hts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import sys 6 | import warnings 7 | 8 | warnings.filterwarnings("ignore") 9 | 10 | from . import single_pred_dataset 11 | from ..utils import print_sys 12 | from ..metadata import dataset_names 13 | 14 | 15 | class HTS(single_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in HTS task. More info: https://tdcommons.ai/single_pred_tasks/hts/ 17 | 18 | Args: 19 | name (str): the dataset name. 
20 | path (str, optional): 21 | The path to save the data file, defaults to './data' 22 | label_name (str, optional): 23 | For multi-label dataset, specify the label name, defaults to None 24 | print_stats (bool, optional): 25 | Whether to print basic statistics of the dataset, defaults to False 26 | convert_format (str, optional): 27 | Automatic conversion of SMILES to other molecular formats in MolConvert class. Stored as separate column in dataframe, defaults to None 28 | """ 29 | 30 | def __init__( 31 | self, 32 | name, 33 | path="./data", 34 | label_name=None, 35 | print_stats=False, 36 | convert_format=None, 37 | ): 38 | """Create a HTS dataloader object.""" 39 | super().__init__( 40 | name, 41 | path, 42 | label_name, 43 | print_stats, 44 | dataset_names=dataset_names["HTS"], 45 | convert_format=convert_format, 46 | ) 47 | if print_stats: 48 | self.print_stats() 49 | print("Done!", flush=True, file=sys.stderr) 50 | -------------------------------------------------------------------------------- /tdc/single_pred/paratope.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import sys 6 | import warnings 7 | 8 | warnings.filterwarnings("ignore") 9 | 10 | from . import single_pred_dataset 11 | from ..utils import print_sys 12 | from ..metadata import dataset_names 13 | 14 | 15 | class Paratope(single_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in Paratope Prediction task. More info: https://tdcommons.ai/single_pred_tasks/paratope/ 17 | 18 | Args: 19 | name (str): the dataset name. 
20 | path (str, optional): 21 | The path to save the data file, defaults to './data' 22 | label_name (str, optional): 23 | For multi-label dataset, specify the label name, defaults to None 24 | print_stats (bool, optional): 25 | Whether to print basic statistics of the dataset, defaults to False 26 | convert_format (str, optional): 27 | Automatic conversion of SMILES to other molecular formats in MolConvert class. Stored as separate column in dataframe, defaults to None 28 | """ 29 | 30 | def __init__( 31 | self, 32 | name, 33 | path="./data", 34 | label_name=None, 35 | print_stats=False, 36 | convert_format=None, 37 | ): 38 | """Create a paratope prediction dataloader object.""" 39 | super().__init__( 40 | name, 41 | path, 42 | label_name, 43 | print_stats, 44 | dataset_names=dataset_names["Paratope"], 45 | convert_format=convert_format, 46 | ) 47 | self.entity1_name = "Antibody" 48 | if print_stats: 49 | self.print_stats() 50 | print("Done!", flush=True, file=sys.stderr) 51 | -------------------------------------------------------------------------------- /tdc/single_pred/qm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import sys 6 | import warnings 7 | 8 | warnings.filterwarnings("ignore") 9 | 10 | from . import single_pred_dataset 11 | from ..utils import print_sys 12 | from ..metadata import dataset_names 13 | 14 | 15 | class QM(single_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in QM (Quantum Mechanics Modeling) task. More info: https://tdcommons.ai/single_pred_tasks/qm/ 17 | 18 | Args: 19 | name (str): the dataset name. 
20 | path (str, optional): 21 | The path to save the data file, defaults to './data' 22 | label_name (str, optional): 23 | For multi-label dataset, specify the label name, defaults to None 24 | print_stats (bool, optional): 25 | Whether to print basic statistics of the dataset, defaults to False 26 | convert_format (str, optional): 27 | Automatic conversion of SMILES to other molecular formats in MolConvert class. Stored as separate column in dataframe, defaults to None 28 | """ 29 | 30 | def __init__( 31 | self, 32 | name, 33 | path="./data", 34 | label_name=None, 35 | print_stats=False, 36 | convert_format=None, 37 | raw_format="Raw3D", 38 | ): 39 | """Create QM (Quantum Mechanics Modeling) dataloader object.""" 40 | super().__init__( 41 | name, 42 | path, 43 | label_name, 44 | print_stats, 45 | dataset_names=dataset_names["QM"], 46 | convert_format=convert_format, 47 | raw_format=raw_format, 48 | ) 49 | if print_stats: 50 | self.print_stats() 51 | print("Done!", flush=True, file=sys.stderr) 52 | -------------------------------------------------------------------------------- /tdc/single_pred/test_single_pred.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import sys 6 | import warnings 7 | 8 | warnings.filterwarnings("ignore") 9 | 10 | from . import single_pred_dataset 11 | from ..utils import print_sys 12 | from ..metadata import dataset_names 13 | 14 | 15 | class TestSinglePred(single_pred_dataset.DataLoader): 16 | """Data loader class to test the single instance prediction data loader. 17 | 18 | Args: 19 | name (str): the dataset name. 
20 | path (str, optional): 21 | The path to save the data file, defaults to './data' 22 | label_name (str, optional): 23 | For multi-label dataset, specify the label name, defaults to None 24 | print_stats (bool, optional): 25 | Whether to print basic statistics of the dataset, defaults to False 26 | convert_format (str, optional): 27 | Automatic conversion of SMILES to other molecular formats in MolConvert class. Stored as separate column in dataframe, defaults to None 28 | """ 29 | 30 | def __init__( 31 | self, 32 | name, 33 | path="./data", 34 | label_name=None, 35 | print_stats=False, 36 | convert_format=None, 37 | ): 38 | """ 39 | Create a testing case dataloader. 40 | """ 41 | super().__init__( 42 | name, 43 | path, 44 | label_name, 45 | print_stats, 46 | dataset_names=dataset_names["test_single_pred"], 47 | convert_format=convert_format, 48 | ) 49 | if print_stats: 50 | self.print_stats() 51 | print("Done!", flush=True, file=sys.stderr) 52 | -------------------------------------------------------------------------------- /tdc/single_pred/tox.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import sys 6 | import warnings 7 | 8 | warnings.filterwarnings("ignore") 9 | 10 | from . import single_pred_dataset 11 | from ..utils import print_sys 12 | from ..metadata import dataset_names 13 | 14 | 15 | class Tox(single_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in Tox (Toxicity Prediction) task. More info: https://tdcommons.ai/single_pred_tasks/tox/ 17 | 18 | Args: 19 | name (str): the dataset name. 
20 | path (str, optional): 21 | The path to save the data file, defaults to './data' 22 | label_name (str, optional): 23 | For multi-label dataset, specify the label name, defaults to None 24 | print_stats (bool, optional): 25 | Whether to print basic statistics of the dataset, defaults to False 26 | convert_format (str, optional): 27 | Automatic conversion of SMILES to other molecular formats in MolConvert class. Stored as separate column in dataframe, defaults to None 28 | """ 29 | 30 | def __init__( 31 | self, 32 | name, 33 | path="./data", 34 | label_name=None, 35 | print_stats=False, 36 | convert_format=None, 37 | ): 38 | """Create a Tox (Toxicity Prediction) dataloader object.""" 39 | super().__init__( 40 | name, 41 | path, 42 | label_name, 43 | print_stats, 44 | dataset_names=dataset_names["Tox"], 45 | convert_format=convert_format, 46 | ) 47 | if print_stats: 48 | self.print_stats() 49 | print("Done!", flush=True, file=sys.stderr) 50 | -------------------------------------------------------------------------------- /tdc/single_pred/yields.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: TDC Team 3 | # License: MIT 4 | 5 | import sys 6 | import warnings 7 | 8 | warnings.filterwarnings("ignore") 9 | 10 | from . import single_pred_dataset 11 | from ..utils import print_sys 12 | from ..metadata import dataset_names 13 | 14 | 15 | class Yields(single_pred_dataset.DataLoader): 16 | """Data loader class to load datasets in Yields (Reaction Yields Prediction) task. More info: https://tdcommons.ai/single_pred_tasks/yields/ 17 | 18 | Args: 19 | name (str): the dataset name. 
20 | path (str, optional): 21 | The path to save the data file, defaults to './data' 22 | label_name (str, optional): 23 | For multi-label dataset, specify the label name, defaults to None 24 | print_stats (bool, optional): 25 | Whether to print basic statistics of the dataset, defaults to False 26 | convert_format (str, optional): 27 | Automatic conversion of SMILES to other molecular formats in MolConvert class. Stored as separate column in dataframe, defaults to None 28 | """ 29 | 30 | def __init__( 31 | self, 32 | name, 33 | path="./data", 34 | label_name=None, 35 | print_stats=False, 36 | convert_format=None, 37 | ): 38 | """Create Yields (Reaction Yields Prediction) dataloader object.""" 39 | super().__init__( 40 | name, 41 | path, 42 | label_name, 43 | print_stats, 44 | dataset_names=dataset_names["Yields"], 45 | convert_format=convert_format, 46 | ) 47 | self.entity1_name = "Reaction" 48 | if print_stats: 49 | self.print_stats() 50 | print("Done!", flush=True, file=sys.stderr) 51 | -------------------------------------------------------------------------------- /tdc/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tdc/test/__init__.py -------------------------------------------------------------------------------- /tdc/test/dev_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tdc/test/dev_tests/__init__.py -------------------------------------------------------------------------------- /tdc/test/dev_tests/chem_utils_test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tdc/test/dev_tests/chem_utils_test/__init__.py 
class TestMolConvert(unittest.TestCase):
    """Integration checks for ``tdc.chem_utils.MolConvert``."""

    def setUp(self):
        # cwd matters: the converter may download auxiliary files here
        print(os.getcwd())

    def test_MolConvert(self):
        """SMILES -> Graph2D conversion on two molecules, then list formats."""
        from tdc.chem_utils import MolConvert

        converter = MolConvert(src="SMILES", dst="Graph2D")
        converter([
            "Clc1ccccc1C2C(=C(/N/C(=C2/C(=O)OCC)COCCN)C)\C(=O)OC",
            "CCCOc1cc2ncnc(Nc3ccc4ncsc4c3)c2cc1S(=O)(=O)C(C)(C)C",
        ])

        # fix: MolConvert was re-imported here a second time; one import
        # at the top of the method suffices
        MolConvert.eligible_format()

    def tearDown(self):
        print(os.getcwd())

        # remove any dataset files downloaded into ./data during the test
        if os.path.exists(os.path.join(os.getcwd(), "data")):
            shutil.rmtree(os.path.join(os.getcwd(), "data"))
TestMolFilter(unittest.TestCase): 19 | 20 | def setUp(self): 21 | print(os.getcwd()) 22 | pass 23 | 24 | def test_MolConvert(self): 25 | from tdc.chem_utils import MolFilter 26 | 27 | filters = MolFilter(filters=["PAINS"], HBD=[0, 6]) 28 | filters(["CCSc1ccccc1C(=O)Nc1onc2c1CCC2"]) 29 | 30 | # 31 | def tearDown(self): 32 | print(os.getcwd()) 33 | 34 | if os.path.exists(os.path.join(os.getcwd(), "data")): 35 | shutil.rmtree(os.path.join(os.getcwd(), "data")) 36 | -------------------------------------------------------------------------------- /tdc/test/dev_tests/chem_utils_test/test_oracles.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | 9 | import unittest 10 | import shutil 11 | 12 | # temporary solution for relative imports in case TDC is not installed 13 | # if TDC is installed, no need to use the following line 14 | sys.path.append( 15 | os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) 16 | 17 | 18 | class TestOracle(unittest.TestCase): 19 | 20 | def setUp(self): 21 | print(os.getcwd()) 22 | pass 23 | 24 | def test_Oracle(self): 25 | from tdc import Oracle 26 | 27 | oracle = Oracle(name="SA") 28 | x = oracle([ 29 | "CC(C)(C)[C@H]1CCc2c(sc(NC(=O)COc3ccc(Cl)cc3)c2C(N)=O)C1", 30 | "CCNC(=O)c1ccc(NC(=O)N2CC[C@H](C)[C@H](O)C2)c(C)c1", 31 | "C[C@@H]1CCN(C(=O)CCCc2ccccc2)C[C@@H]1O", 32 | ]) 33 | 34 | oracle = Oracle(name="Hop") 35 | x = oracle(["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C=O"]) 36 | 37 | def test_distribution(self): 38 | from tdc import Evaluator 39 | 40 | evaluator = Evaluator(name="Diversity") 41 | x = evaluator([ 42 | "CC(C)(C)[C@H]1CCc2c(sc(NC(=O)COc3ccc(Cl)cc3)c2C(N)=O)C1", 43 | "C[C@@H]1CCc2c(sc(NC(=O)c3ccco3)c2C(N)=O)C1", 44 | "CCNC(=O)c1ccc(NC(=O)N2CC[C@H](C)[C@H](O)C2)c(C)c1", 45 | "C[C@@H]1CCN(C(=O)CCCc2ccccc2)C[C@@H]1O", 46 | ]) 47 | 48 | 
def tearDown(self): 49 | print(os.getcwd()) 50 | 51 | if os.path.exists(os.path.join(os.getcwd(), "data")): 52 | shutil.rmtree(os.path.join(os.getcwd(), "data")) 53 | if os.path.exists(os.path.join(os.getcwd(), "oracle")): 54 | shutil.rmtree(os.path.join(os.getcwd(), "oracle")) 55 | -------------------------------------------------------------------------------- /tdc/test/dev_tests/utils_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tdc/test/dev_tests/utils_tests/__init__.py -------------------------------------------------------------------------------- /tdc/test/test_functions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | 9 | import unittest 10 | import shutil 11 | 12 | # temporary solution for relative imports in case TDC is not installed 13 | # if TDC is installed, no need to use the following line 14 | sys.path.append( 15 | os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 16 | 17 | 18 | class TestFunctions(unittest.TestCase): 19 | 20 | def setUp(self): 21 | print(os.getcwd()) 22 | pass 23 | 24 | def test_Evaluator(self): 25 | from tdc import Evaluator 26 | 27 | evaluator = Evaluator(name="ROC-AUC") 28 | print(evaluator([0, 1], [0.5, 0.6])) 29 | 30 | def test_binarize(self): 31 | from tdc.single_pred import TestSinglePred 32 | 33 | data = TestSinglePred(name="Test_Single_Pred") 34 | data.binarize(threshold=-5, order="descending") 35 | 36 | def test_convert_to_log(self): 37 | from tdc.single_pred import TestSinglePred 38 | 39 | data = TestSinglePred(name="Test_Single_Pred") 40 | data.convert_to_log() 41 | 42 | def test_print_stats(self): 43 | from tdc.single_pred import TestSinglePred 44 | 45 | data = 
TestSinglePred(name="Test_Single_Pred") 46 | data.print_stats() 47 | 48 | def tearDown(self): 49 | print(os.getcwd()) 50 | 51 | if os.path.exists(os.path.join(os.getcwd(), "data")): 52 | shutil.rmtree(os.path.join(os.getcwd(), "data")) 53 | if os.path.exists(os.path.join(os.getcwd(), "oracle")): 54 | shutil.rmtree(os.path.join(os.getcwd(), "oracle")) 55 | 56 | 57 | if __name__ == "__main__": 58 | unittest.main() 59 | -------------------------------------------------------------------------------- /tdc/test/test_hf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | 9 | import unittest 10 | import shutil 11 | import pytest 12 | 13 | # temporary solution for relative imports in case TDC is not installed 14 | # if TDC is installed, no need to use the following line 15 | sys.path.append( 16 | os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 17 | # TODO: add verification for the generation other than simple integration 18 | 19 | 20 | class TestHF(unittest.TestCase): 21 | 22 | def setUp(self): 23 | print(os.getcwd()) 24 | pass 25 | 26 | @pytest.mark.skip( 27 | reason="This test is skipped due to deeppurpose installation dependency" 28 | ) 29 | @unittest.skip(reason="DeepPurpose") 30 | def test_hf_load_predict(self): 31 | from tdc.single_pred import Tox 32 | data = Tox(name='herg_karim') 33 | 34 | from tdc import tdc_hf_interface 35 | tdc_hf = tdc_hf_interface("hERG_Karim-CNN") 36 | # load deeppurpose model from this repo 37 | dp_model = tdc_hf.load_deeppurpose('./data') 38 | tdc_hf.predict_deeppurpose(dp_model, ['CC(=O)NC1=CC=C(O)C=C1']) 39 | 40 | def test_hf_transformer(self): 41 | from tdc import tdc_hf_interface 42 | # from transformers import Pipeline 43 | from transformers import BertForMaskedLM as BertModel 44 | geneformer = tdc_hf_interface("Geneformer") 45 | model = 
geneformer.load() 46 | # assert isinstance(pipeline, Pipeline) 47 | assert isinstance(model, BertModel), type(model) 48 | 49 | # def test_hf_load_new_pytorch_standard(self): 50 | # from tdc import tdc_hf_interface 51 | # # from tdc.resource.dataloader import DataLoader 52 | # # data = DataLoader(name="pinnacle_dti") 53 | # tdc_hf = tdc_hf_interface("mli-PINNACLE") 54 | # dp_model = tdc_hf.load() 55 | # assert dp_model is not None 56 | 57 | def tearDown(self): 58 | try: 59 | print(os.getcwd()) 60 | shutil.rmtree(os.path.join(os.getcwd(), "data")) 61 | except: 62 | pass 63 | 64 | 65 | if __name__ == "__main__": 66 | unittest.main() 67 | -------------------------------------------------------------------------------- /tdc/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .load import ( 2 | distribution_dataset_load, 3 | generation_paired_dataset_load, 4 | three_dim_dataset_load, 5 | interaction_dataset_load, 6 | multi_dataset_load, 7 | property_dataset_load, 8 | bi_distribution_dataset_load, 9 | oracle_load, 10 | receptor_load, 11 | bm_group_load, 12 | general_load, 13 | ) 14 | from .split import ( 15 | create_fold, 16 | create_fold_setting_cold, 17 | create_combination_split, 18 | create_fold_time, 19 | create_scaffold_split, 20 | create_group_split, 21 | create_combination_generation_split, 22 | ) 23 | from .misc import ( 24 | print_sys, 25 | install, 26 | fuzzy_search, 27 | save_dict, 28 | load_dict, 29 | to_submission_format, 30 | ) 31 | from .label_name_list import dataset2target_lists 32 | from .label import ( 33 | NegSample, 34 | label_transform, 35 | convert_y_unit, 36 | convert_to_log, 37 | convert_back_log, 38 | binarize, 39 | label_dist, 40 | ) 41 | from .retrieve import ( 42 | get_label_map, 43 | get_reaction_type, 44 | retrieve_label_name_list, 45 | retrieve_dataset_names, 46 | retrieve_all_benchmarks, 47 | retrieve_benchmark_names, 48 | ) 49 | from .query import uniprot2seq, cid2smiles 50 | 
# Canonical column layout for a knowledge-graph edge list: each row is one
# directed relation between an "x" node and a "y" node.
kg_columns = [
    'relation', 'display_relation', 'x_id', 'x_type', 'x_name', 'x_source',
    'y_id', 'y_type', 'y_name', 'y_source'
]


class KnowledgeGraph:
    """Thin wrapper around a pandas edge-list dataframe with helpers to
    filter the graph and extract per-source node tables."""

    def __init__(self, df=None):
        """Wrap an existing edge-list dataframe, or start empty.

        Args:
            df (pd.DataFrame, optional): edge list following ``kg_columns``;
                when None an empty frame with those columns is created.
        """
        if df is not None:
            self.df = df
        else:
            # fix: ``pd.DataFrame('', columns=kg_columns)`` raises
            # ValueError (scalar data without an index), so a no-argument
            # KnowledgeGraph() always crashed; an empty frame with the
            # canonical columns is what was intended.
            self.df = pd.DataFrame(columns=kg_columns)

    def copy(self):
        """Return a shallow copy (the underlying dataframe is shared)."""
        return copy(self)

    def run_query(self, query, inplace=True):
        """Build a subgraph of the edges matching a pandas query.

        Args:
            query (str): ``DataFrame.query`` expression over ``kg_columns``.
            inplace (bool): when True, stash the previous frame in
                ``self.df_raw`` and replace ``self.df``; when False, return
                the filtered frame and leave ``self`` untouched.

        Returns:
            pd.DataFrame or None: the filtered edges if ``inplace=False``.
        """
        df_filt = self.df.query(query).reset_index(drop=True)
        if inplace:
            self.df_raw = self.df
            self.df = df_filt
        else:
            return df_filt

    def get_nodes_by_source(self, source):
        """Return the unique nodes of *source* on either end of an edge.

        Args:
            source (str): value matched against ``x_source``/``y_source``.

        Returns:
            pd.DataFrame: deduplicated nodes with columns
            ``id``/``type``/``name``/``source``.
        """
        # edges touching the source on either endpoint; keep the x_* side
        x_df = self.df.query(
            f"x_source == '{source}' | y_source == '{source}'")[[
                col for col in self.df.columns if col.startswith("x_")
            ]]

        # strip the "x_" prefix so both endpoint frames share a schema
        for col in x_df.columns:
            x_df = x_df.rename(columns={col: col[2:]})

        # same edges, y_* side
        y_df = self.df.query(
            f"x_source == '{source}' | y_source == '{source}'")[[
                col for col in self.df.columns if col.startswith("y_")
            ]]
        for col in y_df.columns:
            y_df = y_df.rename(columns={col: col[2:]})
        # merge both endpoints, keep only nodes actually from ``source``,
        # and drop duplicates (a node can appear on many edges)
        out = pd.concat([
            x_df, y_df
        ], axis=0).query(f'source == "{source}"').drop_duplicates().reset_index(
            drop=True)

        return out
df.display_relation = display_relation 65 | 66 | df.x_id = x_id 67 | df.x_type = x_type 68 | df.x_name = x_name 69 | df.x_source = x_source 70 | 71 | df.y_id = y_id 72 | df.y_type = y_type 73 | df.y_name = y_name 74 | df.y_source = y_source 75 | 76 | kg = KnowledgeGraph(df) 77 | 78 | return kg 79 | -------------------------------------------------------------------------------- /tdc/version.py: -------------------------------------------------------------------------------- 1 | """TDC version file 2 | """ 3 | # Based on NiLearn package 4 | # License: simplified BSD 5 | 6 | # PEP0440 compatible formatted version, see: 7 | # https://www.python.org/dev/peps/pep-0440/ 8 | # 9 | # Generic release markers: 10 | # X.Y 11 | # X.Y.Z # For bug fix releases 12 | # 13 | # Admissible pre-release markers: 14 | # X.YaN # Alpha release 15 | # X.YbN # Beta release 16 | # X.YrcN # Release Candidate 17 | # X.Y # Final release 18 | # 19 | # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. 
20 | # 'X.Y.dev0' is the canonical version of 'X.Y.dev' 21 | # 22 | __version__ = "1.1.14" # pragma: no cover 23 | -------------------------------------------------------------------------------- /tutorials/User_Group/3pbl_ligand.xyz: -------------------------------------------------------------------------------- 1 | 26 2 | out.pdbqt 3 | N 7.99200 23.31100 23.04500 4 | H 8.77200 23.69600 23.54900 5 | C 8.00200 21.90900 22.85200 6 | C 8.08300 19.12000 22.41900 7 | C 6.86200 19.79100 22.46600 8 | C 9.26900 19.84000 22.60300 9 | C 6.82500 21.17300 22.67100 10 | C 9.24300 21.23400 22.81300 11 | C 10.55500 21.94800 23.00300 12 | C 8.08200 17.64800 22.19600 13 | N 6.85500 17.01900 22.29900 14 | H 6.12500 17.51000 22.79700 15 | O 9.13100 17.05500 21.97200 16 | C 6.77500 15.57300 22.25800 17 | C 6.42100 15.00000 23.61600 18 | C 6.93500 24.17800 22.88600 19 | O 6.35400 24.39800 21.82800 20 | N 6.53400 24.80900 24.06600 21 | C 5.66200 25.97700 23.93100 22 | C 7.05700 24.50500 25.41000 23 | C 4.98000 26.34300 25.24200 24 | C 7.28100 25.78600 26.22800 25 | C 5.98200 26.60900 26.38100 26 | C 6.24800 28.10600 26.54500 27 | O 8.32300 26.55100 25.61600 28 | H 7.91700 27.12200 24.94000 29 | -------------------------------------------------------------------------------- /tutorials/User_Group/docking.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tutorials/User_Group/docking.png -------------------------------------------------------------------------------- /tutorials/User_Group/docking_gflownet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tutorials/User_Group/docking_gflownet.png -------------------------------------------------------------------------------- /tutorials/User_Group/ga_illustration.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tutorials/User_Group/ga_illustration.pdf -------------------------------------------------------------------------------- /tutorials/User_Group/ga_illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tutorials/User_Group/ga_illustration.png -------------------------------------------------------------------------------- /tutorials/User_Group/generation_process.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tutorials/User_Group/generation_process.png -------------------------------------------------------------------------------- /tutorials/User_Group/leaderboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tutorials/User_Group/leaderboard.png -------------------------------------------------------------------------------- /tutorials/User_Group/leaderboard_generative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tutorials/User_Group/leaderboard_generative.png -------------------------------------------------------------------------------- /tutorials/User_Group/oracle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tutorials/User_Group/oracle.png -------------------------------------------------------------------------------- 
/tutorials/User_Group/tdc_problems.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tutorials/User_Group/tdc_problems.png -------------------------------------------------------------------------------- /tutorials/User_Group/vina.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tutorials/User_Group/vina.png -------------------------------------------------------------------------------- /tutorials/User_Group/why_docking.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/TDC/c310c35f27e3f506411018ac43d97b8ba23ca652/tutorials/User_Group/why_docking.png --------------------------------------------------------------------------------