├── LICENSE ├── README.md ├── code ├── known_class │ ├── data_process │ │ ├── generate_SDF │ │ │ ├── filter_and_merge.py │ │ │ ├── final_data │ │ │ │ └── .gitkeep │ │ │ ├── main.py │ │ │ ├── prepare_dataset.py │ │ │ ├── rdkit_conf_parallel.py │ │ │ ├── sdfdir │ │ │ │ └── .gitkeep │ │ │ └── torchdrug │ │ │ │ └── torchdrug │ │ │ │ ├── __init__.py │ │ │ │ ├── core │ │ │ │ ├── __init__.py │ │ │ │ ├── core.py │ │ │ │ ├── engine.py │ │ │ │ ├── logger.py │ │ │ │ └── meter.py │ │ │ │ ├── data │ │ │ │ ├── __init__.py │ │ │ │ ├── constant.py │ │ │ │ ├── dataloader.py │ │ │ │ ├── dataset.py │ │ │ │ ├── dictionary.py │ │ │ │ ├── feature.py │ │ │ │ ├── graph.py │ │ │ │ ├── molecule.py │ │ │ │ ├── protein.py │ │ │ │ └── rdkit │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── draw.py │ │ │ │ ├── datasets │ │ │ │ ├── __init__.py │ │ │ │ ├── alphafolddb.py │ │ │ │ ├── bace.py │ │ │ │ ├── bbbp.py │ │ │ │ ├── beta_lactamase.py │ │ │ │ ├── binary_localization.py │ │ │ │ ├── bindingdb.py │ │ │ │ ├── cep.py │ │ │ │ ├── chembl_filtered.py │ │ │ │ ├── citeseer.py │ │ │ │ ├── clintox.py │ │ │ │ ├── cora.py │ │ │ │ ├── delaney.py │ │ │ │ ├── enzyme_commission.py │ │ │ │ ├── fb15k.py │ │ │ │ ├── fluorescence.py │ │ │ │ ├── fold.py │ │ │ │ ├── freesolv.py │ │ │ │ ├── gene_ontology.py │ │ │ │ ├── hetionet.py │ │ │ │ ├── hiv.py │ │ │ │ ├── human_ppi.py │ │ │ │ ├── lipophilicity.py │ │ │ │ ├── malaria.py │ │ │ │ ├── moses.py │ │ │ │ ├── muv.py │ │ │ │ ├── opv.py │ │ │ │ ├── pcqm4m.py │ │ │ │ ├── pdbbind.py │ │ │ │ ├── ppi_affinity.py │ │ │ │ ├── proteinnet.py │ │ │ │ ├── pubchem110m.py │ │ │ │ ├── pubmed.py │ │ │ │ ├── qm8.py │ │ │ │ ├── qm9.py │ │ │ │ ├── secondary_structure.py │ │ │ │ ├── sider.py │ │ │ │ ├── solubility.py │ │ │ │ ├── stability.py │ │ │ │ ├── subcellular_localization.py │ │ │ │ ├── tox21.py │ │ │ │ ├── toxcast.py │ │ │ │ ├── uspto50k.py │ │ │ │ ├── wn18.py │ │ │ │ ├── yago310.py │ │ │ │ ├── yeast_ppi.py │ │ │ │ ├── zinc250k.py │ │ │ │ └── zinc2m.py │ │ │ │ ├── layers │ │ │ │ ├── 
__init__.py │ │ │ │ ├── block.py │ │ │ │ ├── common.py │ │ │ │ ├── conv.py │ │ │ │ ├── distribution.py │ │ │ │ ├── flow.py │ │ │ │ ├── functional │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── embedding.py │ │ │ │ │ ├── extension │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── embedding.cpp │ │ │ │ │ │ ├── embedding.cu │ │ │ │ │ │ ├── embedding.h │ │ │ │ │ │ ├── operator.cuh │ │ │ │ │ │ ├── rspmm.cpp │ │ │ │ │ │ ├── rspmm.cu │ │ │ │ │ │ ├── rspmm.h │ │ │ │ │ │ ├── spmm.cpp │ │ │ │ │ │ ├── spmm.cu │ │ │ │ │ │ ├── spmm.h │ │ │ │ │ │ └── util.cuh │ │ │ │ │ ├── functional.py │ │ │ │ │ └── spmm.py │ │ │ │ ├── geometry │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── function.py │ │ │ │ │ └── graph.py │ │ │ │ ├── pool.py │ │ │ │ ├── readout.py │ │ │ │ └── sampler.py │ │ │ │ ├── metrics │ │ │ │ ├── __init__.py │ │ │ │ ├── metric.py │ │ │ │ └── rdkit │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── sascorer.py │ │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ ├── bert.py │ │ │ │ ├── chebnet.py │ │ │ │ ├── cnn.py │ │ │ │ ├── embedding.py │ │ │ │ ├── esm.py │ │ │ │ ├── flow.py │ │ │ │ ├── gat.py │ │ │ │ ├── gcn.py │ │ │ │ ├── gearnet.py │ │ │ │ ├── gin.py │ │ │ │ ├── infograph.py │ │ │ │ ├── kbgat.py │ │ │ │ ├── lstm.py │ │ │ │ ├── mpnn.py │ │ │ │ ├── neuralfp.py │ │ │ │ ├── neurallp.py │ │ │ │ ├── physicochemical.py │ │ │ │ ├── schnet.py │ │ │ │ └── statistic.py │ │ │ │ ├── patch.py │ │ │ │ ├── tasks │ │ │ │ ├── __init__.py │ │ │ │ ├── contact_prediction.py │ │ │ │ ├── generation.py │ │ │ │ ├── pretrain.py │ │ │ │ ├── property_prediction.py │ │ │ │ ├── reasoning.py │ │ │ │ ├── retrosynthesis.py │ │ │ │ └── task.py │ │ │ │ ├── transforms │ │ │ │ ├── __init__.py │ │ │ │ └── transform.py │ │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── comm.py │ │ │ │ ├── decorator.py │ │ │ │ ├── extension │ │ │ │ ├── __init__.py │ │ │ │ └── torch_ext.cpp │ │ │ │ ├── file.py │ │ │ │ ├── io.py │ │ │ │ ├── plot.py │ │ │ │ ├── pretty.py │ │ │ │ ├── template │ │ │ │ └── echarts.html │ │ │ │ └── torch.py │ │ └── 
get_dataset_1st_stage │ │ │ ├── cmp_data │ │ │ └── .gitkeep │ │ │ ├── get_mapping_mol-product_class.py │ │ │ ├── main.py │ │ │ ├── molecule-datasets │ │ │ └── .gitkeep │ │ │ ├── torchdrug │ │ │ ├── __init__.py │ │ │ ├── core │ │ │ │ ├── __init__.py │ │ │ │ ├── core.py │ │ │ │ ├── engine.py │ │ │ │ ├── logger.py │ │ │ │ └── meter.py │ │ │ ├── data │ │ │ │ ├── __init__.py │ │ │ │ ├── constant.py │ │ │ │ ├── dataloader.py │ │ │ │ ├── dataset.py │ │ │ │ ├── dictionary.py │ │ │ │ ├── feature.py │ │ │ │ ├── graph.py │ │ │ │ ├── molecule.py │ │ │ │ ├── protein.py │ │ │ │ └── rdkit │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── draw.py │ │ │ ├── datasets │ │ │ │ ├── __init__.py │ │ │ │ ├── alphafolddb.py │ │ │ │ ├── bace.py │ │ │ │ ├── bbbp.py │ │ │ │ ├── beta_lactamase.py │ │ │ │ ├── binary_localization.py │ │ │ │ ├── bindingdb.py │ │ │ │ ├── cep.py │ │ │ │ ├── chembl_filtered.py │ │ │ │ ├── citeseer.py │ │ │ │ ├── clintox.py │ │ │ │ ├── cora.py │ │ │ │ ├── delaney.py │ │ │ │ ├── enzyme_commission.py │ │ │ │ ├── fb15k.py │ │ │ │ ├── fluorescence.py │ │ │ │ ├── fold.py │ │ │ │ ├── freesolv.py │ │ │ │ ├── gene_ontology.py │ │ │ │ ├── hetionet.py │ │ │ │ ├── hiv.py │ │ │ │ ├── human_ppi.py │ │ │ │ ├── lipophilicity.py │ │ │ │ ├── malaria.py │ │ │ │ ├── moses.py │ │ │ │ ├── muv.py │ │ │ │ ├── opv.py │ │ │ │ ├── pcqm4m.py │ │ │ │ ├── pdbbind.py │ │ │ │ ├── ppi_affinity.py │ │ │ │ ├── proteinnet.py │ │ │ │ ├── pubchem110m.py │ │ │ │ ├── pubmed.py │ │ │ │ ├── qm8.py │ │ │ │ ├── qm9.py │ │ │ │ ├── secondary_structure.py │ │ │ │ ├── sider.py │ │ │ │ ├── solubility.py │ │ │ │ ├── stability.py │ │ │ │ ├── subcellular_localization.py │ │ │ │ ├── tox21.py │ │ │ │ ├── toxcast.py │ │ │ │ ├── uspto50k.py │ │ │ │ ├── wn18.py │ │ │ │ ├── yago310.py │ │ │ │ ├── yeast_ppi.py │ │ │ │ ├── zinc250k.py │ │ │ │ └── zinc2m.py │ │ │ ├── layers │ │ │ │ ├── __init__.py │ │ │ │ ├── block.py │ │ │ │ ├── common.py │ │ │ │ ├── conv.py │ │ │ │ ├── distribution.py │ │ │ │ ├── flow.py │ │ │ │ ├── functional │ │ │ 
│ │ ├── __init__.py │ │ │ │ │ ├── embedding.py │ │ │ │ │ ├── extension │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── embedding.cpp │ │ │ │ │ │ ├── embedding.cu │ │ │ │ │ │ ├── embedding.h │ │ │ │ │ │ ├── operator.cuh │ │ │ │ │ │ ├── rspmm.cpp │ │ │ │ │ │ ├── rspmm.cu │ │ │ │ │ │ ├── rspmm.h │ │ │ │ │ │ ├── spmm.cpp │ │ │ │ │ │ ├── spmm.cu │ │ │ │ │ │ ├── spmm.h │ │ │ │ │ │ └── util.cuh │ │ │ │ │ ├── functional.py │ │ │ │ │ └── spmm.py │ │ │ │ ├── geometry │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── function.py │ │ │ │ │ └── graph.py │ │ │ │ ├── pool.py │ │ │ │ ├── readout.py │ │ │ │ └── sampler.py │ │ │ ├── metrics │ │ │ │ ├── __init__.py │ │ │ │ ├── metric.py │ │ │ │ └── rdkit │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── sascorer.py │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ ├── bert.py │ │ │ │ ├── chebnet.py │ │ │ │ ├── cnn.py │ │ │ │ ├── embedding.py │ │ │ │ ├── esm.py │ │ │ │ ├── flow.py │ │ │ │ ├── gat.py │ │ │ │ ├── gcn.py │ │ │ │ ├── gearnet.py │ │ │ │ ├── gin.py │ │ │ │ ├── infograph.py │ │ │ │ ├── kbgat.py │ │ │ │ ├── lstm.py │ │ │ │ ├── mpnn.py │ │ │ │ ├── neuralfp.py │ │ │ │ ├── neurallp.py │ │ │ │ ├── physicochemical.py │ │ │ │ ├── schnet.py │ │ │ │ └── statistic.py │ │ │ ├── patch.py │ │ │ ├── tasks │ │ │ │ ├── __init__.py │ │ │ │ ├── contact_prediction.py │ │ │ │ ├── generation.py │ │ │ │ ├── pretrain.py │ │ │ │ ├── property_prediction.py │ │ │ │ ├── reasoning.py │ │ │ │ ├── retrosynthesis.py │ │ │ │ └── task.py │ │ │ ├── transforms │ │ │ │ ├── __init__.py │ │ │ │ └── transform.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── comm.py │ │ │ │ ├── decorator.py │ │ │ │ ├── extension │ │ │ │ ├── __init__.py │ │ │ │ └── torch_ext.cpp │ │ │ │ ├── file.py │ │ │ │ ├── io.py │ │ │ │ ├── plot.py │ │ │ │ ├── pretty.py │ │ │ │ ├── template │ │ │ │ └── echarts.html │ │ │ │ └── torch.py │ │ │ └── uspto50k.py │ ├── stage1 │ │ ├── stage1_to_result_dict.py │ │ ├── torchdrug │ │ │ ├── __init__.py │ │ │ ├── core │ │ │ │ ├── __init__.py │ │ │ │ ├── core.py │ │ │ │ ├── 
engine.py │ │ │ │ ├── logger.py │ │ │ │ └── meter.py │ │ │ ├── data │ │ │ │ ├── __init__.py │ │ │ │ ├── constant.py │ │ │ │ ├── dataloader.py │ │ │ │ ├── dataset.py │ │ │ │ ├── dictionary.py │ │ │ │ ├── feature.py │ │ │ │ ├── graph.py │ │ │ │ ├── molecule.py │ │ │ │ └── rdkit │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── draw.py │ │ │ ├── datasets │ │ │ │ ├── __init__.py │ │ │ │ └── uspto50k.py │ │ │ ├── layers │ │ │ │ ├── __init__.py │ │ │ │ ├── common.py │ │ │ │ ├── conv.py │ │ │ │ ├── functional │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── embedding.py │ │ │ │ │ ├── extension │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── embedding.cpp │ │ │ │ │ │ ├── embedding.cu │ │ │ │ │ │ ├── embedding.h │ │ │ │ │ │ ├── operator.cuh │ │ │ │ │ │ ├── rspmm.cpp │ │ │ │ │ │ ├── rspmm.cu │ │ │ │ │ │ ├── rspmm.h │ │ │ │ │ │ ├── spmm.cpp │ │ │ │ │ │ ├── spmm.cu │ │ │ │ │ │ ├── spmm.h │ │ │ │ │ │ └── util.cuh │ │ │ │ │ ├── functional.py │ │ │ │ │ └── spmm.py │ │ │ │ └── readout.py │ │ │ ├── metrics │ │ │ │ ├── __init__.py │ │ │ │ ├── metric.py │ │ │ │ └── rdkit │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── sascorer.py │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ └── gcn.py │ │ │ ├── patch.py │ │ │ ├── tasks │ │ │ │ ├── __init__.py │ │ │ │ ├── retrosynthesis.py │ │ │ │ └── task.py │ │ │ ├── transforms │ │ │ │ ├── __init__.py │ │ │ │ └── transform.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── comm.py │ │ │ │ ├── decorator.py │ │ │ │ ├── doc.py │ │ │ │ ├── extension │ │ │ │ ├── __init__.py │ │ │ │ └── torch_ext.cpp │ │ │ │ ├── file.py │ │ │ │ ├── io.py │ │ │ │ ├── plot.py │ │ │ │ ├── pretty.py │ │ │ │ ├── template │ │ │ │ └── echarts.html │ │ │ │ └── torch.py │ │ └── train.py │ └── stage2 │ │ ├── configs │ │ ├── uspto_gdiffretro.yml │ │ ├── uspto_sample.yml │ │ └── uspto_size.yml │ │ ├── merge_result.py │ │ ├── run_get_results.sh │ │ ├── sample.py │ │ ├── src │ │ ├── __pycache__ │ │ │ └── utils.cpython-38.pyc │ │ ├── const.py │ │ ├── datasets.py │ │ ├── edm.py │ │ ├── egnn.py │ │ ├── 
lightning.py │ │ ├── linker_size.py │ │ ├── linker_size_lightning.py │ │ ├── metrics.py │ │ ├── molecule_builder.py │ │ ├── noise.py │ │ ├── utils.py │ │ └── visualizer.py │ │ ├── train_gdiffretro.py │ │ ├── train_size_gnn.py │ │ ├── vis_get_result.py │ │ └── xyz_split.py └── unknown_class │ ├── data_process │ ├── generate_SDF │ │ ├── filter_and_merge.py │ │ ├── final_data │ │ │ └── .gitkeep │ │ ├── main.py │ │ ├── prepare_dataset.py │ │ ├── rdkit_conf_parallel.py │ │ ├── sdfdir │ │ │ └── .gitkeep │ │ └── torchdrug │ │ │ └── torchdrug │ │ │ ├── __init__.py │ │ │ ├── core │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ ├── engine.py │ │ │ ├── logger.py │ │ │ └── meter.py │ │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── constant.py │ │ │ ├── dataloader.py │ │ │ ├── dataset.py │ │ │ ├── dictionary.py │ │ │ ├── feature.py │ │ │ ├── graph.py │ │ │ ├── molecule.py │ │ │ ├── protein.py │ │ │ └── rdkit │ │ │ │ ├── __init__.py │ │ │ │ └── draw.py │ │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ ├── alphafolddb.py │ │ │ ├── bace.py │ │ │ ├── bbbp.py │ │ │ ├── beta_lactamase.py │ │ │ ├── binary_localization.py │ │ │ ├── bindingdb.py │ │ │ ├── cep.py │ │ │ ├── chembl_filtered.py │ │ │ ├── citeseer.py │ │ │ ├── clintox.py │ │ │ ├── cora.py │ │ │ ├── delaney.py │ │ │ ├── enzyme_commission.py │ │ │ ├── fb15k.py │ │ │ ├── fluorescence.py │ │ │ ├── fold.py │ │ │ ├── freesolv.py │ │ │ ├── gene_ontology.py │ │ │ ├── hetionet.py │ │ │ ├── hiv.py │ │ │ ├── human_ppi.py │ │ │ ├── lipophilicity.py │ │ │ ├── malaria.py │ │ │ ├── moses.py │ │ │ ├── muv.py │ │ │ ├── opv.py │ │ │ ├── pcqm4m.py │ │ │ ├── pdbbind.py │ │ │ ├── ppi_affinity.py │ │ │ ├── proteinnet.py │ │ │ ├── pubchem110m.py │ │ │ ├── pubmed.py │ │ │ ├── qm8.py │ │ │ ├── qm9.py │ │ │ ├── secondary_structure.py │ │ │ ├── sider.py │ │ │ ├── solubility.py │ │ │ ├── stability.py │ │ │ ├── subcellular_localization.py │ │ │ ├── tox21.py │ │ │ ├── toxcast.py │ │ │ ├── uspto50k.py │ │ │ ├── wn18.py │ │ │ ├── yago310.py │ │ │ ├── yeast_ppi.py 
│ │ │ ├── zinc250k.py │ │ │ └── zinc2m.py │ │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── block.py │ │ │ ├── common.py │ │ │ ├── conv.py │ │ │ ├── distribution.py │ │ │ ├── flow.py │ │ │ ├── functional │ │ │ │ ├── __init__.py │ │ │ │ ├── embedding.py │ │ │ │ ├── extension │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── embedding.cpp │ │ │ │ │ ├── embedding.cu │ │ │ │ │ ├── embedding.h │ │ │ │ │ ├── operator.cuh │ │ │ │ │ ├── rspmm.cpp │ │ │ │ │ ├── rspmm.cu │ │ │ │ │ ├── rspmm.h │ │ │ │ │ ├── spmm.cpp │ │ │ │ │ ├── spmm.cu │ │ │ │ │ ├── spmm.h │ │ │ │ │ └── util.cuh │ │ │ │ ├── functional.py │ │ │ │ └── spmm.py │ │ │ ├── geometry │ │ │ │ ├── __init__.py │ │ │ │ ├── function.py │ │ │ │ └── graph.py │ │ │ ├── pool.py │ │ │ ├── readout.py │ │ │ └── sampler.py │ │ │ ├── metrics │ │ │ ├── __init__.py │ │ │ ├── metric.py │ │ │ └── rdkit │ │ │ │ ├── __init__.py │ │ │ │ └── sascorer.py │ │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── bert.py │ │ │ ├── chebnet.py │ │ │ ├── cnn.py │ │ │ ├── embedding.py │ │ │ ├── esm.py │ │ │ ├── flow.py │ │ │ ├── gat.py │ │ │ ├── gcn.py │ │ │ ├── gearnet.py │ │ │ ├── gin.py │ │ │ ├── infograph.py │ │ │ ├── kbgat.py │ │ │ ├── lstm.py │ │ │ ├── mpnn.py │ │ │ ├── neuralfp.py │ │ │ ├── neurallp.py │ │ │ ├── physicochemical.py │ │ │ ├── schnet.py │ │ │ └── statistic.py │ │ │ ├── patch.py │ │ │ ├── tasks │ │ │ ├── __init__.py │ │ │ ├── contact_prediction.py │ │ │ ├── generation.py │ │ │ ├── pretrain.py │ │ │ ├── property_prediction.py │ │ │ ├── reasoning.py │ │ │ ├── retrosynthesis.py │ │ │ └── task.py │ │ │ ├── transforms │ │ │ ├── __init__.py │ │ │ └── transform.py │ │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── comm.py │ │ │ ├── decorator.py │ │ │ ├── extension │ │ │ ├── __init__.py │ │ │ └── torch_ext.cpp │ │ │ ├── file.py │ │ │ ├── io.py │ │ │ ├── plot.py │ │ │ ├── pretty.py │ │ │ ├── template │ │ │ └── echarts.html │ │ │ └── torch.py │ └── get_dataset_1st_stage │ │ ├── cmp_data │ │ └── .gitkeep │ │ ├── main.py │ │ ├── molecule-datasets │ │ └── 
.gitkeep │ │ ├── torchdrug │ │ └── torchdrug │ │ │ ├── __init__.py │ │ │ ├── core │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ ├── engine.py │ │ │ ├── logger.py │ │ │ └── meter.py │ │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── constant.py │ │ │ ├── dataloader.py │ │ │ ├── dataset.py │ │ │ ├── dictionary.py │ │ │ ├── feature.py │ │ │ ├── graph.py │ │ │ ├── molecule.py │ │ │ ├── protein.py │ │ │ └── rdkit │ │ │ │ ├── __init__.py │ │ │ │ └── draw.py │ │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ ├── alphafolddb.py │ │ │ ├── bace.py │ │ │ ├── bbbp.py │ │ │ ├── beta_lactamase.py │ │ │ ├── binary_localization.py │ │ │ ├── bindingdb.py │ │ │ ├── cep.py │ │ │ ├── chembl_filtered.py │ │ │ ├── citeseer.py │ │ │ ├── clintox.py │ │ │ ├── cora.py │ │ │ ├── delaney.py │ │ │ ├── enzyme_commission.py │ │ │ ├── fb15k.py │ │ │ ├── fluorescence.py │ │ │ ├── fold.py │ │ │ ├── freesolv.py │ │ │ ├── gene_ontology.py │ │ │ ├── hetionet.py │ │ │ ├── hiv.py │ │ │ ├── human_ppi.py │ │ │ ├── lipophilicity.py │ │ │ ├── malaria.py │ │ │ ├── moses.py │ │ │ ├── muv.py │ │ │ ├── opv.py │ │ │ ├── pcqm4m.py │ │ │ ├── pdbbind.py │ │ │ ├── ppi_affinity.py │ │ │ ├── proteinnet.py │ │ │ ├── pubchem110m.py │ │ │ ├── pubmed.py │ │ │ ├── qm8.py │ │ │ ├── qm9.py │ │ │ ├── secondary_structure.py │ │ │ ├── sider.py │ │ │ ├── solubility.py │ │ │ ├── stability.py │ │ │ ├── subcellular_localization.py │ │ │ ├── tox21.py │ │ │ ├── toxcast.py │ │ │ ├── uspto50k.py │ │ │ ├── wn18.py │ │ │ ├── yago310.py │ │ │ ├── yeast_ppi.py │ │ │ ├── zinc250k.py │ │ │ └── zinc2m.py │ │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── block.py │ │ │ ├── common.py │ │ │ ├── conv.py │ │ │ ├── distribution.py │ │ │ ├── flow.py │ │ │ ├── functional │ │ │ │ ├── __init__.py │ │ │ │ ├── embedding.py │ │ │ │ ├── extension │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── embedding.cpp │ │ │ │ │ ├── embedding.cu │ │ │ │ │ ├── embedding.h │ │ │ │ │ ├── operator.cuh │ │ │ │ │ ├── rspmm.cpp │ │ │ │ │ ├── rspmm.cu │ │ │ │ │ ├── rspmm.h │ │ │ │ │ ├── 
spmm.cpp │ │ │ │ │ ├── spmm.cu │ │ │ │ │ ├── spmm.h │ │ │ │ │ └── util.cuh │ │ │ │ ├── functional.py │ │ │ │ └── spmm.py │ │ │ ├── geometry │ │ │ │ ├── __init__.py │ │ │ │ ├── function.py │ │ │ │ └── graph.py │ │ │ ├── pool.py │ │ │ ├── readout.py │ │ │ └── sampler.py │ │ │ ├── metrics │ │ │ ├── __init__.py │ │ │ ├── metric.py │ │ │ └── rdkit │ │ │ │ ├── __init__.py │ │ │ │ └── sascorer.py │ │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── bert.py │ │ │ ├── chebnet.py │ │ │ ├── cnn.py │ │ │ ├── embedding.py │ │ │ ├── esm.py │ │ │ ├── flow.py │ │ │ ├── gat.py │ │ │ ├── gcn.py │ │ │ ├── gearnet.py │ │ │ ├── gin.py │ │ │ ├── infograph.py │ │ │ ├── kbgat.py │ │ │ ├── lstm.py │ │ │ ├── mpnn.py │ │ │ ├── neuralfp.py │ │ │ ├── neurallp.py │ │ │ ├── physicochemical.py │ │ │ ├── schnet.py │ │ │ └── statistic.py │ │ │ ├── patch.py │ │ │ ├── tasks │ │ │ ├── __init__.py │ │ │ ├── contact_prediction.py │ │ │ ├── generation.py │ │ │ ├── pretrain.py │ │ │ ├── property_prediction.py │ │ │ ├── reasoning.py │ │ │ ├── retrosynthesis.py │ │ │ └── task.py │ │ │ ├── transforms │ │ │ ├── __init__.py │ │ │ └── transform.py │ │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── comm.py │ │ │ ├── decorator.py │ │ │ ├── extension │ │ │ ├── __init__.py │ │ │ └── torch_ext.cpp │ │ │ ├── file.py │ │ │ ├── io.py │ │ │ ├── plot.py │ │ │ ├── pretty.py │ │ │ ├── template │ │ │ └── echarts.html │ │ │ └── torch.py │ │ └── uspto50k.py │ ├── stage1 │ ├── stage1_to_result_dict.py │ ├── torchdrug │ │ ├── __init__.py │ │ ├── core │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ ├── engine.py │ │ │ ├── logger.py │ │ │ └── meter.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── constant.py │ │ │ ├── dataloader.py │ │ │ ├── dataset.py │ │ │ ├── dictionary.py │ │ │ ├── feature.py │ │ │ ├── graph.py │ │ │ ├── molecule.py │ │ │ └── rdkit │ │ │ │ ├── __init__.py │ │ │ │ └── draw.py │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ └── uspto50k.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── conv.py │ │ │ ├── 
functional │ │ │ │ ├── __init__.py │ │ │ │ ├── embedding.py │ │ │ │ ├── extension │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── embedding.cpp │ │ │ │ │ ├── embedding.cu │ │ │ │ │ ├── embedding.h │ │ │ │ │ ├── operator.cuh │ │ │ │ │ ├── rspmm.cpp │ │ │ │ │ ├── rspmm.cu │ │ │ │ │ ├── rspmm.h │ │ │ │ │ ├── spmm.cpp │ │ │ │ │ ├── spmm.cu │ │ │ │ │ ├── spmm.h │ │ │ │ │ └── util.cuh │ │ │ │ ├── functional.py │ │ │ │ └── spmm.py │ │ │ └── readout.py │ │ ├── metrics │ │ │ ├── __init__.py │ │ │ ├── metric.py │ │ │ └── rdkit │ │ │ │ ├── __init__.py │ │ │ │ └── sascorer.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── gcn.py │ │ ├── patch.py │ │ ├── tasks │ │ │ ├── __init__.py │ │ │ ├── retrosynthesis.py │ │ │ └── task.py │ │ ├── transforms │ │ │ ├── __init__.py │ │ │ └── transform.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── comm.py │ │ │ ├── decorator.py │ │ │ ├── doc.py │ │ │ ├── extension │ │ │ ├── __init__.py │ │ │ └── torch_ext.cpp │ │ │ ├── file.py │ │ │ ├── io.py │ │ │ ├── plot.py │ │ │ ├── pretty.py │ │ │ ├── template │ │ │ └── echarts.html │ │ │ └── torch.py │ └── train.py │ └── stage2 │ ├── configs │ ├── uspto_gdiffretro.yml │ ├── uspto_sample.yml │ └── uspto_size.yml │ ├── merge_result.py │ ├── run_get_results.sh │ ├── sample.py │ ├── src │ ├── __pycache__ │ │ └── utils.cpython-38.pyc │ ├── const.py │ ├── datasets.py │ ├── edm.py │ ├── egnn.py │ ├── lightning.py │ ├── linker_size.py │ ├── linker_size_lightning.py │ ├── metrics.py │ ├── molecule_builder.py │ ├── noise.py │ ├── utils.py │ ├── visualizer-wenhao.py │ └── visualizer.py │ ├── train_gdiffretro.py │ ├── train_size_gnn.py │ ├── vis_get_result.py │ └── xyz_split.py ├── fig ├── fig1_framework.pdf └── framework.png └── requirement.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 sunshy-1 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation 
files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/final_data/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/sdfdir/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import patch 2 | from .data.constant import * 3 | 4 | import sys 5 | import logging 6 | 7 | logger = logging.getLogger("") 8 | logger.setLevel(logging.INFO) 9 | format = logging.Formatter("%(asctime)-10s %(message)s", "%H:%M:%S") 10 | 11 | handler = logging.StreamHandler(sys.stdout) 12 | handler.setFormatter(format) 13 | logger.addHandler(handler) 14 | 15 | __version__ = "0.2.1" 16 | -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import _MetaContainer, Registry, Configurable, make_configurable 2 | from .engine import Engine 3 | from .meter import Meter 4 | from .logger import LoggerBase, LoggingLogger, WandbLogger 5 | 6 | __all__ = [ 7 | "_MetaContainer", "Registry", "Configurable", 8 | "Engine", "Meter", "LoggerBase", "LoggingLogger", "WandbLogger", 9 | ] 10 | -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dictionary import PerfectHash, Dictionary 2 | from .graph import Graph, PackedGraph, cat 3 | from .molecule import Molecule, PackedMolecule 4 | from .protein import Protein, PackedProtein 5 | from .dataset import MoleculeDataset, ReactionDataset, ProteinDataset, \ 6 | ProteinPairDataset, ProteinLigandDataset, \ 7 | NodeClassificationDataset, KnowledgeGraphDataset, SemiSupervised, \ 8 | semisupervised, key_split, scaffold_split, ordered_scaffold_split 9 | from .dataloader import DataLoader, graph_collate 10 | from . import constant 11 | from . 
import feature 12 | 13 | __all__ = [ 14 | "Graph", "PackedGraph", "Molecule", "PackedMolecule", "Protein", "PackedProtein", "PerfectHash", "Dictionary", 15 | "MoleculeDataset", "ReactionDataset", "NodeClassificationDataset", "KnowledgeGraphDataset", "SemiSupervised", 16 | "ProteinDataset", "ProteinPairDataset", "ProteinLigandDataset", 17 | "semisupervised", "key_split", "scaffold_split", "ordered_scaffold_split", 18 | "DataLoader", "graph_collate", "feature", "constant", 19 | ] 20 | -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/data/rdkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/known_class/data_process/generate_SDF/torchdrug/torchdrug/data/rdkit/__init__.py -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/bace.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.BACE") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class BACE(data.MoleculeDataset): 10 | r""" 11 | Binary binding results for a set of inhibitors of human :math:`\beta`-secretase 1(BACE-1). 
12 | 13 | Statistics: 14 | - #Molecule: 1,513 15 | - #Classification task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/bace.csv" 24 | md5 = "ba7f8fa3fdf463a811fa7edea8c982c2" 25 | target_fields = ["Class"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="mol", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/bbbp.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.BBBP") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class BBBP(data.MoleculeDataset): 10 | """ 11 | Binary labels of blood-brain barrier penetration. 
12 | 13 | Statistics: 14 | - #Molecule: 2,039 15 | - #Classification task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/BBBP.csv" 24 | md5 = "66286cb9e6b148bd75d80c870df580fb" 25 | target_fields = ["p_np"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/cep.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.CEP") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class CEP(data.MoleculeDataset): 10 | """ 11 | Photovoltaic efficiency estimated by Havard clean energy project. 
12 | 13 | Statistics: 14 | - #Molecule: 20,000 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://raw.githubusercontent.com/HIPS/neural-fingerprint/master/data/2015-06-02-cep-pce/cep-processed.csv" 24 | md5 = "b6d257ff416917e4e6baa5e1103f3929" 25 | target_fields = ["PCE"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, self.path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) 37 | -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/chembl_filtered.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ChEMBLFiltered") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class ChEMBLFiltered(data.MoleculeDataset): 10 | """ 11 | Statistics: 12 | - #Molecule: 430,710 13 | - #Regression task: 1,310 14 | 15 | Parameters: 16 | path (str): path to store the dataset 17 | verbose (int, optional): output verbose level 18 | **kwargs 19 | """ 20 | 21 | url = "https://zenodo.org/record/5528681/files/chembl_filtered_torchdrug.csv.gz" 22 | md5 = "2fff04fecd6e697f28ebb127e8a37561" 23 | 24 | def __init__(self, path, verbose=1, **kwargs): 25 | path = os.path.expanduser(path) 26 | if not os.path.exists(path): 27 | os.makedirs(path) 28 | self.path = path 29 | 30 | zip_file = utils.download(self.url, path, md5=self.md5) 31 | csv_file = utils.extract(zip_file) 32 | 33 | 
self.target_fields = ["target_{}".format(i) for i in range(1310)] 34 | 35 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/citeseer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.CiteSeer") 8 | class CiteSeer(data.NodeClassificationDataset): 9 | """ 10 | A citation network of scientific publications with binary word features. 11 | 12 | Statistics: 13 | - #Node: 3,327 14 | - #Edge: 8,059 15 | - #Class: 6 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | """ 21 | 22 | url = "https://linqs-data.soe.ucsc.edu/public/lbc/citeseer.tgz" 23 | md5 = "c8ded8ed395b31899576bfd1e91e4d6e" 24 | 25 | def __init__(self, path, verbose=1): 26 | path = os.path.expanduser(path) 27 | if not os.path.exists(path): 28 | os.makedirs(path) 29 | self.path = path 30 | 31 | zip_file = utils.download(self.url, path, md5=self.md5) 32 | node_file = utils.extract(zip_file, "citeseer/citeseer.content") 33 | edge_file = utils.extract(zip_file, "citeseer/citeseer.cites") 34 | 35 | self.load_tsv(node_file, edge_file, verbose=verbose) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/clintox.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ClinTox") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class 
ClinTox(data.MoleculeDataset): 10 | """ 11 | Qualitative data of drugs approved by the FDA and those that have failed clinical 12 | trials for toxicity reasons. 13 | 14 | Statistics: 15 | - #Molecule: 1,478 16 | - #Classification task: 2 17 | 18 | Parameters: 19 | path (str): path to store the dataset 20 | verbose (int, optional): output verbose level 21 | **kwargs 22 | """ 23 | 24 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/clintox.csv.gz" 25 | md5 = "db4f2df08be8ae92814e9d6a2d015284" 26 | target_fields = ["FDA_APPROVED", "CT_TOX"] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | zip_file = utils.download(self.url, path, md5=self.md5) 35 | csv_file = utils.extract(zip_file) 36 | 37 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 38 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/cora.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Cora") 8 | class Cora(data.NodeClassificationDataset): 9 | """ 10 | A citation network of scientific publications with binary word features. 
11 | 12 | Statistics: 13 | - #Node: 2,708 14 | - #Edge: 5,429 15 | - #Class: 7 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | """ 21 | 22 | url = "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz" 23 | md5 = "2fc040bee8ce3d920e4204effd1e9214" 24 | 25 | def __init__(self, path, verbose=1): 26 | path = os.path.expanduser(path) 27 | if not os.path.exists(path): 28 | os.makedirs(path) 29 | self.path = path 30 | 31 | zip_file = utils.download(self.url, path, md5=self.md5) 32 | node_file = utils.extract(zip_file, "cora/cora.content") 33 | edge_file = utils.extract(zip_file, "cora/cora.cites") 34 | 35 | self.load_tsv(node_file, edge_file, verbose=verbose) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/delaney.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Delaney") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class Delaney(data.MoleculeDataset): 10 | """ 11 | Log-scale water solubility of molecules. 
12 | 13 | Statistics: 14 | - #Molecule: 1,128 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/delaney-processed.csv" 24 | md5 = "0c90a51668d446b9e3ab77e67662bd1c" 25 | target_fields = ["measured log solubility in mols per litre"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, self.path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/freesolv.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.FreeSolv") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class FreeSolv(data.MoleculeDataset): 10 | """ 11 | Experimental and calculated hydration free energy of small molecules in water. 
12 | 13 | Statistics: 14 | - #Molecule: 642 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://s3-us-west-1.amazonaws.com/deepchem.io/datasets/molnet_publish/FreeSolv.zip" 24 | md5 = "8d681babd239b15e2f8b2d29f025577a" 25 | target_fields = ["expt"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, self.path, md5=self.md5) 34 | csv_file = utils.extract(zip_file, "SAMPL.csv") 35 | 36 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/hiv.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.HIV") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class HIV(data.MoleculeDataset): 10 | """ 11 | Experimentally measured abilities to inhibit HIV replication. 
12 | 13 | Statistics: 14 | - #Molecule: 41,127 15 | - #Classification task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/HIV.csv" 24 | md5 = "9ad10c88f82f1dac7eb5c52b668c30a7" 25 | target_fields = ["HIV_active"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/lipophilicity.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Lipophilicity") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class Lipophilicity(data.MoleculeDataset): 10 | """ 11 | Experimental results of octanol/water distribution coefficient (logD at pH 7.4). 
12 | 13 | Statistics: 14 | - #Molecule: 4,200 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/Lipophilicity.csv" 24 | md5 = "85a0e1cb8b38b0dfc3f96ff47a57f0ab" 25 | target_fields = ["exp"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, self.path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) 37 | -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/malaria.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Malaria") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class Malaria(data.MoleculeDataset): 10 | """ 11 | Half-maximal effective concentration (EC50) against a parasite that causes malaria. 
12 | 13 | Statistics: 14 | - #Molecule: 10,000 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://raw.githubusercontent.com/HIPS/neural-fingerprint/master/data/2015-06-03-malaria/" \ 24 | "malaria-processed.csv" 25 | md5 = "ef40ddfd164be0e5ed1bd3dd0cce9b88" 26 | target_fields = ["activity"] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | file_name = utils.download(self.url, self.path, md5=self.md5) 35 | 36 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/moses.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | 4 | from torch.utils import data as torch_data 5 | 6 | from torchdrug import data, utils 7 | from torchdrug.core import Registry as R 8 | 9 | 10 | @R.register("datasets.MOSES") 11 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 12 | class MOSES(data.MoleculeDataset): 13 | """ 14 | Subset of ZINC database for molecule generation. 15 | This dataset doesn't contain any label information. 
16 | 17 | Statistics: 18 | - #Molecule: 1,936,963 19 | 20 | Parameters: 21 | path (str): path for the CSV dataset 22 | verbose (int, optional): output verbose level 23 | **kwargs 24 | """ 25 | 26 | url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv" 27 | md5 = "6bdb0d9526ddf5fdeb87d6aa541df213" 28 | target_fields = ["SPLIT"] 29 | 30 | def __init__(self, path, verbose=1, **kwargs): 31 | path = os.path.expanduser(path) 32 | if not os.path.exists(path): 33 | os.makedirs(path) 34 | self.path = path 35 | 36 | file_name = utils.download(self.url, path, md5=self.md5) 37 | 38 | self.load_csv(file_name, smiles_field="SMILES", target_fields=self.target_fields, 39 | lazy=True, verbose=verbose, **kwargs) 40 | 41 | def split(self): 42 | indexes = defaultdict(list) 43 | for i, split in enumerate(self.targets["SPLIT"]): 44 | indexes[split].append(i) 45 | train_set = torch_data.Subset(self, indexes["train"]) 46 | valid_set = torch_data.Subset(self, indexes["valid"]) 47 | test_set = torch_data.Subset(self, indexes["test"]) 48 | return train_set, valid_set, test_set 49 | -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/muv.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.MUV") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class MUV(data.MoleculeDataset): 10 | """ 11 | Subset of PubChem BioAssay by applying a refined nearest neighbor analysis. 
12 | 13 | Statistics: 14 | - #Molecule: 93,087 15 | - #Classification task: 17 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/muv.csv.gz" 24 | md5 = "9c40bd41310991efd40f4d4868fa3ddf" 25 | target_fields = ["MUV-466", "MUV-548", "MUV-600", "MUV-644", "MUV-652", "MUV-689", "MUV-692", "MUV-712", "MUV-713", 26 | "MUV-733", "MUV-737", "MUV-810", "MUV-832", "MUV-846", "MUV-852", "MUV-858", "MUV-859"] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | zip_file = utils.download(self.url, path, md5=self.md5) 35 | csv_file = utils.extract(zip_file) 36 | 37 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 38 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/pcqm4m.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.PCQM4M") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class PCQM4M(data.MoleculeDataset): 10 | """ 11 | Quantum chemistry dataset originally curated under the PubChemQC of molecules. 
12 | 13 | Statistics: 14 | - #Molecule: 3,803,453 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip" 24 | md5 = "5144ebaa7c67d24da1a2acbe41f57f6a" 25 | target_fields = ["homolumogap"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, self.path, md5=self.md5) 34 | zip_file = utils.extract(zip_file, "pcqm4m_kddcup2021/raw/data.csv.gz") 35 | file_name = utils.extract(zip_file) 36 | 37 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 38 | lazy=True, verbose=verbose, **kwargs) 39 | -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/pubchem110m.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | from tqdm import tqdm 4 | 5 | from torchdrug import data, utils 6 | from torchdrug.core import Registry as R 7 | 8 | 9 | @R.register("datasets.PubChem110m") 10 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 11 | class PubChem110m(data.MoleculeDataset): 12 | """ 13 | PubChem. 14 | This dataset doesn't contain any label information. 15 | 16 | Statistics: 17 | - #Molecule: 18 | 19 | Parameters: 20 | path (str): 21 | verbose (int, optional): output verbose level 22 | **kwargs 23 | """ 24 | # TODO: download path & md5. Is it the statistics right? 
25 | 26 | target_fields = [] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | smiles_file = os.path.join(path, "CID-SMILES") 35 | 36 | with open(smiles_file, "r") as fin: 37 | reader = csv.reader(fin, delimiter="\t") 38 | if verbose: 39 | reader = iter(tqdm(reader, "Loading %s" % path, utils.get_line_count(smiles_file))) 40 | smiles_list = [] 41 | 42 | for values in reader: 43 | smiles = values[1] 44 | smiles_list.append(smiles) 45 | 46 | targets = {} 47 | self.load_smiles(smiles_list, targets, lazy=True, verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/sider.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.SIDER") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class SIDER(data.MoleculeDataset): 10 | """ 11 | Marketed drugs and adverse drug reactions (ADR) dataset, grouped into 27 system organ classes. 
12 | 13 | Statistics: 14 | - #Molecule: 1,427 15 | - #Classification task: 27 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/sider.csv.gz" 24 | md5 = "77c0ef421f7cc8ce963c5836c8761fd2" 25 | target_fields = None # pick all targets 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, path, md5=self.md5) 34 | csv_file = utils.extract(zip_file) 35 | 36 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/solubility.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch.utils import data as torch_data 4 | 5 | from torchdrug import data, utils 6 | from torchdrug.core import Registry as R 7 | 8 | 9 | @R.register("datasets.Solubility") 10 | @utils.copy_args(data.ProteinDataset.load_lmdbs, ignore=("target_fields",)) 11 | class Solubility(data.ProteinDataset): 12 | """ 13 | Proteins with binary labels indicating their solubility. 
14 | 15 | Statistics: 16 | - #Train: 62,478 17 | - #Valid: 6,942 18 | - #Test: 1,999 19 | 20 | Parameters: 21 | path (str): the path to store the dataset 22 | verbose (int, optional): output verbose level 23 | **kwargs 24 | """ 25 | 26 | url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/peerdata/solubility.tar.gz" 27 | md5 = "8a8612b7bfa2ed80375db6e465ccf77e" 28 | splits = ["train", "valid", "test"] 29 | target_fields = ["solubility"] 30 | 31 | def __init__(self, path, verbose=1, **kwargs): 32 | path = os.path.expanduser(path) 33 | if not os.path.exists(path): 34 | os.makedirs(path) 35 | self.path = path 36 | 37 | zip_file = utils.download(self.url, path, md5=self.md5) 38 | data_path = utils.extract(zip_file) 39 | lmdb_files = [os.path.join(data_path, "solubility/solubility_%s.lmdb" % split) 40 | for split in self.splits] 41 | 42 | self.load_lmdbs(lmdb_files, target_fields=self.target_fields, verbose=verbose, **kwargs) 43 | 44 | def split(self): 45 | offset = 0 46 | splits = [] 47 | for num_sample in self.num_samples: 48 | split = torch_data.Subset(self, range(offset, offset + num_sample)) 49 | splits.append(split) 50 | offset += num_sample 51 | return splits -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/stability.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch.utils import data as torch_data 4 | 5 | from torchdrug import data, utils 6 | from torchdrug.core import Registry as R 7 | 8 | 9 | @R.register("datasets.Stability") 10 | @utils.copy_args(data.ProteinDataset.load_lmdbs, ignore=("target_fields",)) 11 | class Stability(data.ProteinDataset): 12 | """ 13 | The stability values of proteins under natural environment. 
14 | 15 | Statistics: 16 | - #Train: 53,571 17 | - #Valid: 2,512 18 | - #Test: 12,851 19 | 20 | Parameters: 21 | path (str): the path to store the dataset 22 | verbose (int, optional): output verbose level 23 | **kwargs 24 | """ 25 | 26 | url = "http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/stability.tar.gz" 27 | md5 = "aa1e06eb5a59e0ecdae581e9ea029675" 28 | splits = ["train", "valid", "test"] 29 | target_fields = ["stability_score"] 30 | 31 | def __init__(self, path, verbose=1, **kwargs): 32 | path = os.path.expanduser(path) 33 | if not os.path.exists(path): 34 | os.makedirs(path) 35 | self.path = path 36 | 37 | zip_file = utils.download(self.url, path, md5=self.md5) 38 | data_path = utils.extract(zip_file) 39 | lmdb_files = [os.path.join(data_path, "stability/stability_%s.lmdb" % split) 40 | for split in self.splits] 41 | 42 | self.load_lmdbs(lmdb_files, target_fields=self.target_fields, verbose=verbose, **kwargs) 43 | 44 | def split(self): 45 | offset = 0 46 | splits = [] 47 | for num_sample in self.num_samples: 48 | split = torch_data.Subset(self, range(offset, offset + num_sample)) 49 | splits.append(split) 50 | offset += num_sample 51 | return splits -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/tox21.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Tox21") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class Tox21(data.MoleculeDataset): 10 | """ 11 | Qualitative toxicity measurements on 12 biological targets, including nuclear receptors 12 | and stress response pathways. 
13 | 14 | Statistics: 15 | - #Molecule: 7,831 16 | - #Classification task: 12 17 | 18 | Parameters: 19 | path (str): path to store the dataset 20 | verbose (int, optional): output verbose level 21 | **kwargs 22 | """ 23 | 24 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/tox21.csv.gz" 25 | md5 = "2882d69e70bba0fec14995f26787cc25" 26 | target_fields = ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", 27 | "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"] 28 | 29 | def __init__(self, path, verbose=1, **kwargs): 30 | path = os.path.expanduser(path) 31 | if not os.path.exists(path): 32 | os.makedirs(path) 33 | self.path = path 34 | 35 | zip_file = utils.download(self.url, path, md5=self.md5) 36 | csv_file = utils.extract(zip_file) 37 | 38 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 39 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/toxcast.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ToxCast") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class ToxCast(data.MoleculeDataset): 10 | """ 11 | Toxicology data based on in vitro high-throughput screening. 
12 | 13 | Statistics: 14 | - #Molecule: 8,575 15 | - #Classification task: 617 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/toxcast_data.csv.gz" 24 | md5 = "92911bbf9c1e2ad85231014859388cd6" 25 | target_fields = None # pick all targets 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, path, md5=self.md5) 34 | csv_file = utils.extract(zip_file) 35 | 36 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/zinc250k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ZINC250k") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class ZINC250k(data.MoleculeDataset): 10 | """ 11 | Subset of ZINC compound database for virtual screening. 
12 | 13 | Statistics: 14 | - #Molecule: 498,910 15 | - #Regression task: 2 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/" \ 24 | "250k_rndm_zinc_drugs_clean_3.csv" 25 | md5 = "b59078b2b04c6e9431280e3dc42048d5" 26 | target_fields = ["logP", "qed"] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | file_name = utils.download(self.url, path, md5=self.md5) 35 | 36 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/layers/distribution.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections.abc import Sequence 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class IndependentGaussian(nn.Module): 9 | """ 10 | Independent Gaussian distribution. 11 | 12 | Parameters: 13 | mu (Tensor): mean of shape :math:`(N,)` 14 | sigma2 (Tensor): variance of shape :math:`(N,)` 15 | learnable (bool, optional): learnable parameters or not 16 | """ 17 | 18 | def __init__(self, mu, sigma2, learnable=False): 19 | super(IndependentGaussian, self).__init__() 20 | if learnable: 21 | self.mu = nn.Parameter(torch.as_tensor(mu)) 22 | self.sigma2 = nn.Parameter(torch.as_tensor(sigma2)) 23 | else: 24 | self.register_buffer("mu", torch.as_tensor(mu)) 25 | self.register_buffer("sigma2", torch.as_tensor(sigma2)) 26 | self.dim = len(mu) 27 | 28 | def forward(self, input): 29 | """ 30 | Compute the likelihood of input data. 
31 | 32 | Parameters: 33 | input (Tensor): input data of shape :math:`(..., N)` 34 | """ 35 | log_likelihood = -0.5 * (math.log(2 * math.pi) + self.sigma2.log() + (input - self.mu) ** 2 / self.sigma2) 36 | return log_likelihood 37 | 38 | def sample(self, *size): 39 | """ 40 | Draw samples from the distribution. 41 | 42 | Parameters: 43 | size (tuple of int): shape of the samples 44 | """ 45 | if len(size) == 1 and isinstance(size[0], Sequence): 46 | size = size[0] 47 | size = list(size) + [self.dim] 48 | 49 | sample = torch.randn(size, device=self.mu.device) * self.sigma2.sqrt() + self.mu 50 | return sample 51 | -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/layers/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .functional import multinomial, masked_mean, mean_with_nan, shifted_softplus, multi_slice, multi_slice_mask, \ 2 | as_mask, _extend, variadic_log_softmax, variadic_softmax, variadic_sum, variadic_mean, variadic_max, \ 3 | variadic_cross_entropy, variadic_sort, variadic_topk, variadic_arange, variadic_randperm, variadic_sample,\ 4 | variadic_meshgrid, variadic_to_padded, padded_to_variadic, one_hot, clipped_policy_gradient_objective, \ 5 | policy_gradient_objective 6 | from .embedding import transe_score, distmult_score, complex_score, simple_score, rotate_score 7 | from .spmm import generalized_spmm, generalized_rspmm 8 | 9 | __all__ = [ 10 | "multinomial", "masked_mean", "mean_with_nan", "shifted_softplus", "multi_slice_mask", "as_mask", 11 | "variadic_log_softmax", "variadic_softmax", "variadic_sum", "variadic_mean", "variadic_max", 12 | "variadic_cross_entropy", "variadic_sort", "variadic_topk", "variadic_arange", "variadic_randperm", 13 | "variadic_sample", "variadic_meshgrid", "variadic_to_padded", "padded_to_variadic", 14 | "one_hot", "clipped_policy_gradient_objective", 
"policy_gradient_objective", 15 | "transe_score", "distmult_score", "complex_score", "simple_score", "rotate_score", 16 | "generalized_spmm", "generalized_rspmm", 17 | ] -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/layers/functional/extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/known_class/data_process/generate_SDF/torchdrug/torchdrug/layers/functional/extension/__init__.py -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/layers/functional/extension/util.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace at { 4 | 5 | const unsigned kFullMask = 0xFFFFFFFF; 6 | 7 | template 8 | __device__ scalar_t warp_reduce(scalar_t value) { 9 | #pragma unroll 10 | for (int delta = 1; delta < warpSize; delta *= 2) 11 | #if __CUDACC_VER_MAJOR__ >= 9 12 | value += __shfl_down_sync(kFullMask, value, delta); 13 | #else 14 | value += __shfl_down(value, delta); 15 | #endif 16 | return value; 17 | } 18 | 19 | template 20 | __device__ scalar_t warp_broadcast(scalar_t value, int lane_id) { 21 | #if __CUDACC_VER_MAJOR__ >= 9 22 | return __shfl_sync(kFullMask, value, lane_id); 23 | #else 24 | return __shfl(value, lane_id); 25 | #endif 26 | } 27 | 28 | } // namespace at -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/layers/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph import GraphConstruction, SpatialLineGraph 2 | from .function import BondEdge, KNNEdge, SpatialEdge, SequentialEdge, AlphaCarbonNode, \ 3 | 
IdentityNode, RandomEdgeMask, SubsequenceNode, SubspaceNode 4 | 5 | __all__ = [ 6 | "GraphConstruction", "SpatialLineGraph", 7 | "BondEdge", "KNNEdge", "SpatialEdge", "SequentialEdge", "AlphaCarbonNode", 8 | "IdentityNode", "RandomEdgeMask", "SubsequenceNode", "SubspaceNode" 9 | ] -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metric import area_under_roc, area_under_prc, r2, QED, logP, penalized_logP, SA, chemical_validity, \ 2 | accuracy, variadic_accuracy, matthews_corrcoef, pearsonr, spearmanr, \ 3 | variadic_area_under_prc, variadic_area_under_roc, variadic_top_precision, f1_max 4 | 5 | # alias 6 | AUROC = area_under_roc 7 | AUPRC = area_under_prc 8 | 9 | __all__ = [ 10 | "area_under_roc", "area_under_prc", "r2", "QED", "logP", "penalized_logP", "SA", "chemical_validity", 11 | "accuracy", "variadic_accuracy", "matthews_corrcoef", "pearsonr", "spearmanr", 12 | "variadic_area_under_prc", "variadic_area_under_roc", "variadic_top_precision", "f1_max", 13 | "AUROC", "AUPRC", 14 | ] -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/metrics/rdkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/known_class/data_process/generate_SDF/torchdrug/torchdrug/metrics/rdkit/__init__.py -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/tasks/task.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping, Sequence 2 | 3 | from torch import nn 4 | 5 | 6 | class Task(nn.Module): 
7 | 8 | _option_members = set() 9 | 10 | def _standarize_option(self, x, name): 11 | if x is None: 12 | x = {} 13 | elif isinstance(x, str): 14 | x = {x: 1} 15 | elif isinstance(x, Sequence): 16 | x = dict.fromkeys(x, 1) 17 | elif not isinstance(x, Mapping): 18 | raise ValueError("Invalid value `%s` for option member `%s`" % (x, name)) 19 | return x 20 | 21 | def __setattr__(self, key, value): 22 | if key in self._option_members: 23 | value = self._standarize_option(value, key) 24 | super(Task, self).__setattr__(key, value) 25 | 26 | def preprocess(self, train_set, valid_set, test_set): 27 | pass 28 | 29 | def predict_and_target(self, batch, all_loss=None, metric=None): 30 | return self.predict(batch, all_loss, metric), self.target(batch) 31 | 32 | def predict(self, batch, all_loss=None, metric=None): 33 | raise NotImplementedError 34 | 35 | def target(self, batch): 36 | raise NotImplementedError 37 | 38 | def evaluate(self, pred, target): 39 | raise NotImplementedError -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import NormalizeTarget, RemapAtomType, RandomBFSOrder, Shuffle, VirtualNode, \ 2 | VirtualAtom, TruncateProtein, ProteinView, Compose 3 | 4 | __all__ = [ 5 | "NormalizeTarget", "RemapAtomType", "RandomBFSOrder", "Shuffle", 6 | "VirtualNode", "VirtualAtom", "TruncateProtein", "ProteinView", "Compose", 7 | ] 8 | -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .io import input_choice, literal_eval, no_rdkit_log, capture_rdkit_log 2 | from .file import download, smart_open, extract, compute_md5, get_line_count 3 | from .torch import 
load_extension, cpu, cuda, detach, clone, mean, cat, stack, sparse_coo_tensor 4 | from .decorator import copy_args, cached_property, cached, deprecated_alias 5 | from . import pretty, comm, plot 6 | 7 | __all__ = [ 8 | "input_choice", "literal_eval", "no_rdkit_log", "capture_rdkit_log", 9 | "download", "smart_open", "extract", "compute_md5", "get_line_count", 10 | "load_extension", "cpu", "cuda", "detach", "clone", "mean", "cat", "stack", "sparse_coo_tensor", 11 | "copy_args", "cached_property", "cached", "deprecated_alias", 12 | "pretty", "comm", "plot", 13 | ] -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/utils/extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/known_class/data_process/generate_SDF/torchdrug/torchdrug/utils/extension/__init__.py -------------------------------------------------------------------------------- /code/known_class/data_process/generate_SDF/torchdrug/torchdrug/utils/extension/torch_ext.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace at { 4 | 5 | Tensor sparse_coo_tensor_unsafe(const Tensor &indices, const Tensor &values, IntArrayRef size) { 6 | return _sparse_coo_tensor_unsafe(indices, values, size, values.options().layout(kSparse)); 7 | } 8 | 9 | } 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("sparse_coo_tensor_unsafe", &at::sparse_coo_tensor_unsafe, 13 | "Construct sparse COO tensor without index check"); 14 | } -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/cmp_data/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | 
-------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/molecule-datasets/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/__init__.py: -------------------------------------------------------------------------------- 1 | from . import patch 2 | from .data.constant import * 3 | 4 | import sys 5 | import logging 6 | 7 | logger = logging.getLogger("") 8 | logger.setLevel(logging.INFO) 9 | format = logging.Formatter("%(asctime)-10s %(message)s", "%H:%M:%S") 10 | 11 | handler = logging.StreamHandler(sys.stdout) 12 | handler.setFormatter(format) 13 | logger.addHandler(handler) 14 | 15 | __version__ = "0.2.1" 16 | -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import _MetaContainer, Registry, Configurable, make_configurable 2 | from .engine import Engine 3 | from .meter import Meter 4 | from .logger import LoggerBase, LoggingLogger, WandbLogger 5 | 6 | __all__ = [ 7 | "_MetaContainer", "Registry", "Configurable", 8 | "Engine", "Meter", "LoggerBase", "LoggingLogger", "WandbLogger", 9 | ] 10 | -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dictionary import PerfectHash, Dictionary 2 | from .graph import Graph, PackedGraph, cat 3 | from .molecule import Molecule, PackedMolecule 4 | from .protein import Protein, PackedProtein 5 | from .dataset import MoleculeDataset, ReactionDataset, ProteinDataset, 
\ 6 | ProteinPairDataset, ProteinLigandDataset, \ 7 | NodeClassificationDataset, KnowledgeGraphDataset, SemiSupervised, \ 8 | semisupervised, key_split, scaffold_split, ordered_scaffold_split 9 | from .dataloader import DataLoader, graph_collate 10 | from . import constant 11 | from . import feature 12 | 13 | __all__ = [ 14 | "Graph", "PackedGraph", "Molecule", "PackedMolecule", "Protein", "PackedProtein", "PerfectHash", "Dictionary", 15 | "MoleculeDataset", "ReactionDataset", "NodeClassificationDataset", "KnowledgeGraphDataset", "SemiSupervised", 16 | "ProteinDataset", "ProteinPairDataset", "ProteinLigandDataset", 17 | "semisupervised", "key_split", "scaffold_split", "ordered_scaffold_split", 18 | "DataLoader", "graph_collate", "feature", "constant", 19 | ] 20 | -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/data/rdkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/known_class/data_process/get_dataset_1st_stage/torchdrug/data/rdkit/__init__.py -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/bace.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.BACE") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class BACE(data.MoleculeDataset): 10 | r""" 11 | Binary binding results for a set of inhibitors of human :math:`\beta`-secretase 1(BACE-1). 
12 | 13 | Statistics: 14 | - #Molecule: 1,513 15 | - #Classification task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/bace.csv" 24 | md5 = "ba7f8fa3fdf463a811fa7edea8c982c2" 25 | target_fields = ["Class"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="mol", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/bbbp.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.BBBP") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class BBBP(data.MoleculeDataset): 10 | """ 11 | Binary labels of blood-brain barrier penetration. 
12 | 13 | Statistics: 14 | - #Molecule: 2,039 15 | - #Classification task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/BBBP.csv" 24 | md5 = "66286cb9e6b148bd75d80c870df580fb" 25 | target_fields = ["p_np"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/cep.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.CEP") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class CEP(data.MoleculeDataset): 10 | """ 11 | Photovoltaic efficiency estimated by Havard clean energy project. 
12 | 13 | Statistics: 14 | - #Molecule: 20,000 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://raw.githubusercontent.com/HIPS/neural-fingerprint/master/data/2015-06-02-cep-pce/cep-processed.csv" 24 | md5 = "b6d257ff416917e4e6baa5e1103f3929" 25 | target_fields = ["PCE"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, self.path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) 37 | -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/chembl_filtered.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ChEMBLFiltered") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class ChEMBLFiltered(data.MoleculeDataset): 10 | """ 11 | Statistics: 12 | - #Molecule: 430,710 13 | - #Regression task: 1,310 14 | 15 | Parameters: 16 | path (str): path to store the dataset 17 | verbose (int, optional): output verbose level 18 | **kwargs 19 | """ 20 | 21 | url = "https://zenodo.org/record/5528681/files/chembl_filtered_torchdrug.csv.gz" 22 | md5 = "2fff04fecd6e697f28ebb127e8a37561" 23 | 24 | def __init__(self, path, verbose=1, **kwargs): 25 | path = os.path.expanduser(path) 26 | if not os.path.exists(path): 27 | os.makedirs(path) 28 | self.path = path 29 | 30 | zip_file = utils.download(self.url, path, md5=self.md5) 31 | csv_file = utils.extract(zip_file) 32 | 33 | 
self.target_fields = ["target_{}".format(i) for i in range(1310)] 34 | 35 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/citeseer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.CiteSeer") 8 | class CiteSeer(data.NodeClassificationDataset): 9 | """ 10 | A citation network of scientific publications with binary word features. 11 | 12 | Statistics: 13 | - #Node: 3,327 14 | - #Edge: 8,059 15 | - #Class: 6 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | """ 21 | 22 | url = "https://linqs-data.soe.ucsc.edu/public/lbc/citeseer.tgz" 23 | md5 = "c8ded8ed395b31899576bfd1e91e4d6e" 24 | 25 | def __init__(self, path, verbose=1): 26 | path = os.path.expanduser(path) 27 | if not os.path.exists(path): 28 | os.makedirs(path) 29 | self.path = path 30 | 31 | zip_file = utils.download(self.url, path, md5=self.md5) 32 | node_file = utils.extract(zip_file, "citeseer/citeseer.content") 33 | edge_file = utils.extract(zip_file, "citeseer/citeseer.cites") 34 | 35 | self.load_tsv(node_file, edge_file, verbose=verbose) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/clintox.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ClinTox") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class ClinTox(data.MoleculeDataset): 
10 | """ 11 | Qualitative data of drugs approved by the FDA and those that have failed clinical 12 | trials for toxicity reasons. 13 | 14 | Statistics: 15 | - #Molecule: 1,478 16 | - #Classification task: 2 17 | 18 | Parameters: 19 | path (str): path to store the dataset 20 | verbose (int, optional): output verbose level 21 | **kwargs 22 | """ 23 | 24 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/clintox.csv.gz" 25 | md5 = "db4f2df08be8ae92814e9d6a2d015284" 26 | target_fields = ["FDA_APPROVED", "CT_TOX"] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | zip_file = utils.download(self.url, path, md5=self.md5) 35 | csv_file = utils.extract(zip_file) 36 | 37 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 38 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/cora.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Cora") 8 | class Cora(data.NodeClassificationDataset): 9 | """ 10 | A citation network of scientific publications with binary word features. 
11 | 12 | Statistics: 13 | - #Node: 2,708 14 | - #Edge: 5,429 15 | - #Class: 7 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | """ 21 | 22 | url = "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz" 23 | md5 = "2fc040bee8ce3d920e4204effd1e9214" 24 | 25 | def __init__(self, path, verbose=1): 26 | path = os.path.expanduser(path) 27 | if not os.path.exists(path): 28 | os.makedirs(path) 29 | self.path = path 30 | 31 | zip_file = utils.download(self.url, path, md5=self.md5) 32 | node_file = utils.extract(zip_file, "cora/cora.content") 33 | edge_file = utils.extract(zip_file, "cora/cora.cites") 34 | 35 | self.load_tsv(node_file, edge_file, verbose=verbose) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/delaney.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Delaney") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class Delaney(data.MoleculeDataset): 10 | """ 11 | Log-scale water solubility of molecules. 
12 | 13 | Statistics: 14 | - #Molecule: 1,128 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/delaney-processed.csv" 24 | md5 = "0c90a51668d446b9e3ab77e67662bd1c" 25 | target_fields = ["measured log solubility in mols per litre"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, self.path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/freesolv.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.FreeSolv") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class FreeSolv(data.MoleculeDataset): 10 | """ 11 | Experimental and calculated hydration free energy of small molecules in water. 
12 | 13 | Statistics: 14 | - #Molecule: 642 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://s3-us-west-1.amazonaws.com/deepchem.io/datasets/molnet_publish/FreeSolv.zip" 24 | md5 = "8d681babd239b15e2f8b2d29f025577a" 25 | target_fields = ["expt"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, self.path, md5=self.md5) 34 | csv_file = utils.extract(zip_file, "SAMPL.csv") 35 | 36 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/hiv.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.HIV") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class HIV(data.MoleculeDataset): 10 | """ 11 | Experimentally measured abilities to inhibit HIV replication. 
12 | 13 | Statistics: 14 | - #Molecule: 41,127 15 | - #Classification task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/HIV.csv" 24 | md5 = "9ad10c88f82f1dac7eb5c52b668c30a7" 25 | target_fields = ["HIV_active"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/lipophilicity.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Lipophilicity") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class Lipophilicity(data.MoleculeDataset): 10 | """ 11 | Experimental results of octanol/water distribution coefficient (logD at pH 7.4). 
12 | 13 | Statistics: 14 | - #Molecule: 4,200 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/Lipophilicity.csv" 24 | md5 = "85a0e1cb8b38b0dfc3f96ff47a57f0ab" 25 | target_fields = ["exp"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, self.path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) 37 | -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/malaria.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Malaria") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class Malaria(data.MoleculeDataset): 10 | """ 11 | Half-maximal effective concentration (EC50) against a parasite that causes malaria. 
12 | 13 | Statistics: 14 | - #Molecule: 10,000 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://raw.githubusercontent.com/HIPS/neural-fingerprint/master/data/2015-06-03-malaria/" \ 24 | "malaria-processed.csv" 25 | md5 = "ef40ddfd164be0e5ed1bd3dd0cce9b88" 26 | target_fields = ["activity"] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | file_name = utils.download(self.url, self.path, md5=self.md5) 35 | 36 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/moses.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | 4 | from torch.utils import data as torch_data 5 | 6 | from torchdrug import data, utils 7 | from torchdrug.core import Registry as R 8 | 9 | 10 | @R.register("datasets.MOSES") 11 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 12 | class MOSES(data.MoleculeDataset): 13 | """ 14 | Subset of ZINC database for molecule generation. 15 | This dataset doesn't contain any label information. 
16 | 17 | Statistics: 18 | - #Molecule: 1,936,963 19 | 20 | Parameters: 21 | path (str): path for the CSV dataset 22 | verbose (int, optional): output verbose level 23 | **kwargs 24 | """ 25 | 26 | url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv" 27 | md5 = "6bdb0d9526ddf5fdeb87d6aa541df213" 28 | target_fields = ["SPLIT"] 29 | 30 | def __init__(self, path, verbose=1, **kwargs): 31 | path = os.path.expanduser(path) 32 | if not os.path.exists(path): 33 | os.makedirs(path) 34 | self.path = path 35 | 36 | file_name = utils.download(self.url, path, md5=self.md5) 37 | 38 | self.load_csv(file_name, smiles_field="SMILES", target_fields=self.target_fields, 39 | lazy=True, verbose=verbose, **kwargs) 40 | 41 | def split(self): 42 | indexes = defaultdict(list) 43 | for i, split in enumerate(self.targets["SPLIT"]): 44 | indexes[split].append(i) 45 | train_set = torch_data.Subset(self, indexes["train"]) 46 | valid_set = torch_data.Subset(self, indexes["valid"]) 47 | test_set = torch_data.Subset(self, indexes["test"]) 48 | return train_set, valid_set, test_set 49 | -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/muv.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.MUV") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class MUV(data.MoleculeDataset): 10 | """ 11 | Subset of PubChem BioAssay by applying a refined nearest neighbor analysis. 
12 | 13 | Statistics: 14 | - #Molecule: 93,087 15 | - #Classification task: 17 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/muv.csv.gz" 24 | md5 = "9c40bd41310991efd40f4d4868fa3ddf" 25 | target_fields = ["MUV-466", "MUV-548", "MUV-600", "MUV-644", "MUV-652", "MUV-689", "MUV-692", "MUV-712", "MUV-713", 26 | "MUV-733", "MUV-737", "MUV-810", "MUV-832", "MUV-846", "MUV-852", "MUV-858", "MUV-859"] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | zip_file = utils.download(self.url, path, md5=self.md5) 35 | csv_file = utils.extract(zip_file) 36 | 37 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 38 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/pcqm4m.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.PCQM4M") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class PCQM4M(data.MoleculeDataset): 10 | """ 11 | Quantum chemistry dataset originally curated under the PubChemQC of molecules. 
12 | 13 | Statistics: 14 | - #Molecule: 3,803,453 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip" 24 | md5 = "5144ebaa7c67d24da1a2acbe41f57f6a" 25 | target_fields = ["homolumogap"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, self.path, md5=self.md5) 34 | zip_file = utils.extract(zip_file, "pcqm4m_kddcup2021/raw/data.csv.gz") 35 | file_name = utils.extract(zip_file) 36 | 37 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 38 | lazy=True, verbose=verbose, **kwargs) 39 | -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/pubchem110m.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | from tqdm import tqdm 4 | 5 | from torchdrug import data, utils 6 | from torchdrug.core import Registry as R 7 | 8 | 9 | @R.register("datasets.PubChem110m") 10 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 11 | class PubChem110m(data.MoleculeDataset): 12 | """ 13 | PubChem. 14 | This dataset doesn't contain any label information. 15 | 16 | Statistics: 17 | - #Molecule: 18 | 19 | Parameters: 20 | path (str): 21 | verbose (int, optional): output verbose level 22 | **kwargs 23 | """ 24 | # TODO: download path & md5. Is it the statistics right? 
25 | 26 | target_fields = [] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | smiles_file = os.path.join(path, "CID-SMILES") 35 | 36 | with open(smiles_file, "r") as fin: 37 | reader = csv.reader(fin, delimiter="\t") 38 | if verbose: 39 | reader = iter(tqdm(reader, "Loading %s" % path, utils.get_line_count(smiles_file))) 40 | smiles_list = [] 41 | 42 | for values in reader: 43 | smiles = values[1] 44 | smiles_list.append(smiles) 45 | 46 | targets = {} 47 | self.load_smiles(smiles_list, targets, lazy=True, verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/sider.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.SIDER") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class SIDER(data.MoleculeDataset): 10 | """ 11 | Marketed drugs and adverse drug reactions (ADR) dataset, grouped into 27 system organ classes. 
12 | 13 | Statistics: 14 | - #Molecule: 1,427 15 | - #Classification task: 27 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/sider.csv.gz" 24 | md5 = "77c0ef421f7cc8ce963c5836c8761fd2" 25 | target_fields = None # pick all targets 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, path, md5=self.md5) 34 | csv_file = utils.extract(zip_file) 35 | 36 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/stability.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch.utils import data as torch_data 4 | 5 | from torchdrug import data, utils 6 | from torchdrug.core import Registry as R 7 | 8 | 9 | @R.register("datasets.Stability") 10 | @utils.copy_args(data.ProteinDataset.load_lmdbs, ignore=("target_fields",)) 11 | class Stability(data.ProteinDataset): 12 | """ 13 | The stability values of proteins under natural environment. 
14 | 15 | Statistics: 16 | - #Train: 53,571 17 | - #Valid: 2,512 18 | - #Test: 12,851 19 | 20 | Parameters: 21 | path (str): the path to store the dataset 22 | verbose (int, optional): output verbose level 23 | **kwargs 24 | """ 25 | 26 | url = "http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/stability.tar.gz" 27 | md5 = "aa1e06eb5a59e0ecdae581e9ea029675" 28 | splits = ["train", "valid", "test"] 29 | target_fields = ["stability_score"] 30 | 31 | def __init__(self, path, verbose=1, **kwargs): 32 | path = os.path.expanduser(path) 33 | if not os.path.exists(path): 34 | os.makedirs(path) 35 | self.path = path 36 | 37 | zip_file = utils.download(self.url, path, md5=self.md5) 38 | data_path = utils.extract(zip_file) 39 | lmdb_files = [os.path.join(data_path, "stability/stability_%s.lmdb" % split) 40 | for split in self.splits] 41 | 42 | self.load_lmdbs(lmdb_files, target_fields=self.target_fields, verbose=verbose, **kwargs) 43 | 44 | def split(self): 45 | offset = 0 46 | splits = [] 47 | for num_sample in self.num_samples: 48 | split = torch_data.Subset(self, range(offset, offset + num_sample)) 49 | splits.append(split) 50 | offset += num_sample 51 | return splits -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/tox21.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Tox21") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class Tox21(data.MoleculeDataset): 10 | """ 11 | Qualitative toxicity measurements on 12 biological targets, including nuclear receptors 12 | and stress response pathways. 
13 | 14 | Statistics: 15 | - #Molecule: 7,831 16 | - #Classification task: 12 17 | 18 | Parameters: 19 | path (str): path to store the dataset 20 | verbose (int, optional): output verbose level 21 | **kwargs 22 | """ 23 | 24 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/tox21.csv.gz" 25 | md5 = "2882d69e70bba0fec14995f26787cc25" 26 | target_fields = ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", 27 | "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"] 28 | 29 | def __init__(self, path, verbose=1, **kwargs): 30 | path = os.path.expanduser(path) 31 | if not os.path.exists(path): 32 | os.makedirs(path) 33 | self.path = path 34 | 35 | zip_file = utils.download(self.url, path, md5=self.md5) 36 | csv_file = utils.extract(zip_file) 37 | 38 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 39 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/toxcast.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ToxCast") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class ToxCast(data.MoleculeDataset): 10 | """ 11 | Toxicology data based on in vitro high-throughput screening. 
12 | 13 | Statistics: 14 | - #Molecule: 8,575 15 | - #Classification task: 617 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/toxcast_data.csv.gz" 24 | md5 = "92911bbf9c1e2ad85231014859388cd6" 25 | target_fields = None # pick all targets 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, path, md5=self.md5) 34 | csv_file = utils.extract(zip_file) 35 | 36 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/datasets/zinc250k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ZINC250k") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class ZINC250k(data.MoleculeDataset): 10 | """ 11 | Subset of ZINC compound database for virtual screening. 
12 | 13 | Statistics: 14 | - #Molecule: 498,910 15 | - #Regression task: 2 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/" \ 24 | "250k_rndm_zinc_drugs_clean_3.csv" 25 | md5 = "b59078b2b04c6e9431280e3dc42048d5" 26 | target_fields = ["logP", "qed"] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | file_name = utils.download(self.url, path, md5=self.md5) 35 | 36 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/layers/distribution.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections.abc import Sequence 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class IndependentGaussian(nn.Module): 9 | """ 10 | Independent Gaussian distribution. 11 | 12 | Parameters: 13 | mu (Tensor): mean of shape :math:`(N,)` 14 | sigma2 (Tensor): variance of shape :math:`(N,)` 15 | learnable (bool, optional): learnable parameters or not 16 | """ 17 | 18 | def __init__(self, mu, sigma2, learnable=False): 19 | super(IndependentGaussian, self).__init__() 20 | if learnable: 21 | self.mu = nn.Parameter(torch.as_tensor(mu)) 22 | self.sigma2 = nn.Parameter(torch.as_tensor(sigma2)) 23 | else: 24 | self.register_buffer("mu", torch.as_tensor(mu)) 25 | self.register_buffer("sigma2", torch.as_tensor(sigma2)) 26 | self.dim = len(mu) 27 | 28 | def forward(self, input): 29 | """ 30 | Compute the likelihood of input data. 
31 | 32 | Parameters: 33 | input (Tensor): input data of shape :math:`(..., N)` 34 | """ 35 | log_likelihood = -0.5 * (math.log(2 * math.pi) + self.sigma2.log() + (input - self.mu) ** 2 / self.sigma2) 36 | return log_likelihood 37 | 38 | def sample(self, *size): 39 | """ 40 | Draw samples from the distribution. 41 | 42 | Parameters: 43 | size (tuple of int): shape of the samples 44 | """ 45 | if len(size) == 1 and isinstance(size[0], Sequence): 46 | size = size[0] 47 | size = list(size) + [self.dim] 48 | 49 | sample = torch.randn(size, device=self.mu.device) * self.sigma2.sqrt() + self.mu 50 | return sample 51 | -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/layers/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .functional import multinomial, masked_mean, mean_with_nan, shifted_softplus, multi_slice, multi_slice_mask, \ 2 | as_mask, _extend, variadic_log_softmax, variadic_softmax, variadic_sum, variadic_mean, variadic_max, \ 3 | variadic_cross_entropy, variadic_sort, variadic_topk, variadic_arange, variadic_randperm, variadic_sample,\ 4 | variadic_meshgrid, variadic_to_padded, padded_to_variadic, one_hot, clipped_policy_gradient_objective, \ 5 | policy_gradient_objective 6 | from .embedding import transe_score, distmult_score, complex_score, simple_score, rotate_score 7 | from .spmm import generalized_spmm, generalized_rspmm 8 | 9 | __all__ = [ 10 | "multinomial", "masked_mean", "mean_with_nan", "shifted_softplus", "multi_slice_mask", "as_mask", 11 | "variadic_log_softmax", "variadic_softmax", "variadic_sum", "variadic_mean", "variadic_max", 12 | "variadic_cross_entropy", "variadic_sort", "variadic_topk", "variadic_arange", "variadic_randperm", 13 | "variadic_sample", "variadic_meshgrid", "variadic_to_padded", "padded_to_variadic", 14 | "one_hot", "clipped_policy_gradient_objective", 
"policy_gradient_objective", 15 | "transe_score", "distmult_score", "complex_score", "simple_score", "rotate_score", 16 | "generalized_spmm", "generalized_rspmm", 17 | ] -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/layers/functional/extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/known_class/data_process/get_dataset_1st_stage/torchdrug/layers/functional/extension/__init__.py -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/layers/functional/extension/util.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace at { 4 | 5 | const unsigned kFullMask = 0xFFFFFFFF; 6 | 7 | template 8 | __device__ scalar_t warp_reduce(scalar_t value) { 9 | #pragma unroll 10 | for (int delta = 1; delta < warpSize; delta *= 2) 11 | #if __CUDACC_VER_MAJOR__ >= 9 12 | value += __shfl_down_sync(kFullMask, value, delta); 13 | #else 14 | value += __shfl_down(value, delta); 15 | #endif 16 | return value; 17 | } 18 | 19 | template 20 | __device__ scalar_t warp_broadcast(scalar_t value, int lane_id) { 21 | #if __CUDACC_VER_MAJOR__ >= 9 22 | return __shfl_sync(kFullMask, value, lane_id); 23 | #else 24 | return __shfl(value, lane_id); 25 | #endif 26 | } 27 | 28 | } // namespace at -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/layers/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph import GraphConstruction, SpatialLineGraph 2 | from .function import BondEdge, KNNEdge, SpatialEdge, SequentialEdge, AlphaCarbonNode, \ 3 | 
IdentityNode, RandomEdgeMask, SubsequenceNode, SubspaceNode 4 | 5 | __all__ = [ 6 | "GraphConstruction", "SpatialLineGraph", 7 | "BondEdge", "KNNEdge", "SpatialEdge", "SequentialEdge", "AlphaCarbonNode", 8 | "IdentityNode", "RandomEdgeMask", "SubsequenceNode", "SubspaceNode" 9 | ] -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metric import area_under_roc, area_under_prc, r2, QED, logP, penalized_logP, SA, chemical_validity, \ 2 | accuracy, variadic_accuracy, matthews_corrcoef, pearsonr, spearmanr, \ 3 | variadic_area_under_prc, variadic_area_under_roc, variadic_top_precision, f1_max 4 | 5 | # alias 6 | AUROC = area_under_roc 7 | AUPRC = area_under_prc 8 | 9 | __all__ = [ 10 | "area_under_roc", "area_under_prc", "r2", "QED", "logP", "penalized_logP", "SA", "chemical_validity", 11 | "accuracy", "variadic_accuracy", "matthews_corrcoef", "pearsonr", "spearmanr", 12 | "variadic_area_under_prc", "variadic_area_under_roc", "variadic_top_precision", "f1_max", 13 | "AUROC", "AUPRC", 14 | ] -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/metrics/rdkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/known_class/data_process/get_dataset_1st_stage/torchdrug/metrics/rdkit/__init__.py -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/tasks/task.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping, Sequence 2 | 3 | from torch import nn 4 | 5 | 6 | class Task(nn.Module): 7 | 
8 | _option_members = set() 9 | 10 | def _standarize_option(self, x, name): 11 | if x is None: 12 | x = {} 13 | elif isinstance(x, str): 14 | x = {x: 1} 15 | elif isinstance(x, Sequence): 16 | x = dict.fromkeys(x, 1) 17 | elif not isinstance(x, Mapping): 18 | raise ValueError("Invalid value `%s` for option member `%s`" % (x, name)) 19 | return x 20 | 21 | def __setattr__(self, key, value): 22 | if key in self._option_members: 23 | value = self._standarize_option(value, key) 24 | super(Task, self).__setattr__(key, value) 25 | 26 | def preprocess(self, train_set, valid_set, test_set): 27 | pass 28 | 29 | def predict_and_target(self, batch, all_loss=None, metric=None): 30 | return self.predict(batch, all_loss, metric), self.target(batch) 31 | 32 | def predict(self, batch, all_loss=None, metric=None): 33 | raise NotImplementedError 34 | 35 | def target(self, batch): 36 | raise NotImplementedError 37 | 38 | def evaluate(self, pred, target): 39 | raise NotImplementedError -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import NormalizeTarget, RemapAtomType, RandomBFSOrder, Shuffle, VirtualNode, \ 2 | VirtualAtom, TruncateProtein, ProteinView, Compose 3 | 4 | __all__ = [ 5 | "NormalizeTarget", "RemapAtomType", "RandomBFSOrder", "Shuffle", 6 | "VirtualNode", "VirtualAtom", "TruncateProtein", "ProteinView", "Compose", 7 | ] 8 | -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .io import input_choice, literal_eval, no_rdkit_log, capture_rdkit_log 2 | from .file import download, smart_open, extract, compute_md5, get_line_count 3 | from .torch import 
load_extension, cpu, cuda, detach, clone, mean, cat, stack, sparse_coo_tensor 4 | from .decorator import copy_args, cached_property, cached, deprecated_alias 5 | from . import pretty, comm, plot 6 | 7 | __all__ = [ 8 | "input_choice", "literal_eval", "no_rdkit_log", "capture_rdkit_log", 9 | "download", "smart_open", "extract", "compute_md5", "get_line_count", 10 | "load_extension", "cpu", "cuda", "detach", "clone", "mean", "cat", "stack", "sparse_coo_tensor", 11 | "copy_args", "cached_property", "cached", "deprecated_alias", 12 | "pretty", "comm", "plot", 13 | ] -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/utils/extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/known_class/data_process/get_dataset_1st_stage/torchdrug/utils/extension/__init__.py -------------------------------------------------------------------------------- /code/known_class/data_process/get_dataset_1st_stage/torchdrug/utils/extension/torch_ext.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace at { 4 | 5 | Tensor sparse_coo_tensor_unsafe(const Tensor &indices, const Tensor &values, IntArrayRef size) { 6 | return _sparse_coo_tensor_unsafe(indices, values, size, values.options().layout(kSparse)); 7 | } 8 | 9 | } 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("sparse_coo_tensor_unsafe", &at::sparse_coo_tensor_unsafe, 13 | "Construct sparse COO tensor without index check"); 14 | } -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import patch 2 | from .data.constant import * 3 | 4 | import sys 5 | import logging 6 | 7 | logger = logging.getLogger("") 8 | logger.setLevel(logging.INFO) 9 | format = logging.Formatter("%(asctime)-10s %(message)s", "%H:%M:%S") 10 | 11 | handler = logging.StreamHandler(sys.stdout) 12 | handler.setFormatter(format) 13 | logger.addHandler(handler) 14 | 15 | __version__ = "0.1.3" -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import _MetaContainer, Registry, Configurable, make_configurable 2 | from .engine import Engine 3 | from .meter import Meter 4 | from .logger import LoggerBase, LoggingLogger, WandbLogger 5 | 6 | __all__ = [ 7 | "_MetaContainer", "Registry", "Configurable", 8 | "Engine", "Meter", "LoggerBase", "LoggingLogger", "WandbLogger", 9 | ] -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dictionary import PerfectHash, Dictionary 2 | from .graph import Graph, PackedGraph, cat 3 | from .molecule import Molecule, PackedMolecule 4 | from .dataset import MoleculeDataset, ReactionDataset, \ 5 | key_split, scaffold_split, ordered_scaffold_split 6 | from .dataloader import DataLoader, graph_collate 7 | from . import constant 8 | from . 
import feature 9 | 10 | __all__ = [ 11 | "Graph", "PackedGraph", "Molecule", "PackedMolecule", "PerfectHash", "Dictionary", 12 | "MoleculeDataset", "ReactionDataset", "key_split", "scaffold_split", "ordered_scaffold_split", 13 | "DataLoader", "graph_collate", "feature", "constant", 14 | ] 15 | -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/data/rdkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/known_class/stage1/torchdrug/data/rdkit/__init__.py -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .uspto50k import USPTO50k 2 | __all__ = ["USPTO50k"] 3 | -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/layers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .common import MultiLayerPerceptron 3 | from .conv import GraphConv, RelationalGraphConv 4 | from .readout import MeanReadout, SumReadout, MaxReadout, Softmax 5 | 6 | MLP = MultiLayerPerceptron 7 | GCNConv = GraphConv 8 | RGCNConv = RelationalGraphConv 9 | 10 | __all__ = [ 11 | "MultiLayerPerceptron", 12 | "GraphConv", "RelationalGraphConv", 13 | "MeanReadout", "SumReadout", "MaxReadout", "Softmax", 14 | "MLP", "GCNConv", "RGCNConv" 15 | ] -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/layers/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .functional import multinomial, masked_mean, mean_with_nan, shifted_softplus, multi_slice, multi_slice_mask, \ 2 | as_mask, 
_size_to_index, _extend, variadic_log_softmax, variadic_softmax, variadic_sum, variadic_mean, \ 3 | variadic_max, variadic_cross_entropy, variadic_sort, variadic_topk, variadic_arange, variadic_randperm, \ 4 | variadic_sample, variadic_meshgrid, variadic_to_padded, padded_to_variadic, one_hot, \ 5 | clipped_policy_gradient_objective, policy_gradient_objective, variadic_sample_distribution 6 | from .embedding import transe_score, distmult_score, complex_score, simple_score, rotate_score 7 | from .spmm import generalized_spmm, generalized_rspmm 8 | 9 | __all__ = [ 10 | "multinomial", "masked_mean", "mean_with_nan", "shifted_softplus", "multi_slice_mask", "as_mask", 11 | "variadic_log_softmax", "variadic_softmax", "variadic_sum", "variadic_mean", "variadic_max", 12 | "variadic_cross_entropy", "variadic_sort", "variadic_topk", "variadic_arange", "variadic_randperm", 13 | "variadic_sample", "variadic_meshgrid", "variadic_to_padded", "padded_to_variadic", 14 | "one_hot", "clipped_policy_gradient_objective", "policy_gradient_objective", 15 | "transe_score", "distmult_score", "complex_score", "simple_score", "rotate_score", 16 | "generalized_spmm", "generalized_rspmm", 17 | ] -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/layers/functional/extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/known_class/stage1/torchdrug/layers/functional/extension/__init__.py -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/layers/functional/extension/util.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace at { 4 | 5 | const unsigned kFullMask = 0xFFFFFFFF; 6 | 7 | template 8 | __device__ scalar_t warp_reduce(scalar_t value) { 9 | #pragma 
unroll 10 | for (int delta = 1; delta < warpSize; delta *= 2) 11 | #if __CUDACC_VER_MAJOR__ >= 9 12 | value += __shfl_down_sync(kFullMask, value, delta); 13 | #else 14 | value += __shfl_down(value, delta); 15 | #endif 16 | return value; 17 | } 18 | 19 | template 20 | __device__ scalar_t warp_broadcast(scalar_t value, int lane_id) { 21 | #if __CUDACC_VER_MAJOR__ >= 9 22 | return __shfl_sync(kFullMask, value, lane_id); 23 | #else 24 | return __shfl(value, lane_id); 25 | #endif 26 | } 27 | 28 | } // namespace at -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metric import accuracy, variadic_accuracy 2 | 3 | # alias 4 | 5 | 6 | __all__ = [ 7 | "accuracy", "variadic_accuracy" 8 | ] 9 | -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/metrics/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from torch_scatter import scatter_add, scatter_mean, scatter_max 4 | import networkx as nx 5 | from rdkit import Chem 6 | from rdkit.Chem import Descriptors 7 | 8 | from torchdrug import utils 9 | from torchdrug.layers import functional 10 | from torchdrug.core import Registry as R 11 | from torchdrug.metrics.rdkit import sascorer 12 | 13 | 14 | 15 | @R.register("metrics.accuracy") 16 | def accuracy(pred, target): 17 | """ 18 | Compute classification accuracy over sets with equal size. 19 | 20 | Suppose there are :math:`N` sets and :math:`C` categories. 
21 | 22 | Parameters: 23 | pred (Tensor): prediction of shape :math:`(N, C)` 24 | target (Tensor): target of shape :math:`(N,)` 25 | """ 26 | return (pred.argmax(dim=-1) == target).float().mean() 27 | 28 | 29 | 30 | @R.register("metrics.variadic_accuracy") 31 | def variadic_accuracy(input, target, size): 32 | """ 33 | Compute classification accuracy over variadic sizes of categories. 34 | 35 | Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math:`B`. 36 | 37 | Parameters: 38 | input (Tensor): prediction of shape :math:`(B,)` 39 | target (Tensor): target of shape :math:`(N,)`. Each target is a relative index in a sample. 40 | size (Tensor): number of categories of shape :math:`(N,)` 41 | """ 42 | index2graph = functional._size_to_index(size) 43 | 44 | input_class = scatter_max(input, index2graph)[1] 45 | target_index = target + size.cumsum(0) - size 46 | accuracy = (input_class == target_index).float() 47 | return accuracy 48 | -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/metrics/rdkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/known_class/stage1/torchdrug/metrics/rdkit/__init__.py -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gcn import GraphConvolutionalNetwork, RelationalGraphConvolutionalNetwork 2 | 3 | # alias 4 | GCN = GraphConvolutionalNetwork 5 | RGCN = RelationalGraphConvolutionalNetwork 6 | 7 | __all__ = [ 8 | "GraphConvolutionalNetwork", "RelationalGraphConvolutionalNetwork", 9 | "GCN", "RGCN" 10 | ] -------------------------------------------------------------------------------- 
/code/known_class/stage1/torchdrug/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .task import Task 2 | from .retrosynthesis import CenterIdentification, SynthonCompletion, Retrosynthesis 3 | 4 | 5 | _criterion_name = { 6 | "mse": "mean squared error", 7 | "mae": "mean absolute error", 8 | "bce": "binary cross entropy", 9 | "ce": "cross entropy", 10 | } 11 | 12 | _metric_name = { 13 | "mae": "mean absolute error", 14 | "mse": "mean squared error", 15 | "rmse": "root mean squared error", 16 | "acc": "accuracy", 17 | "mcc": "matthews correlation coefficient", 18 | } 19 | 20 | 21 | def _get_criterion_name(criterion): 22 | if criterion in _criterion_name: 23 | return _criterion_name[criterion] 24 | return "%s loss" % criterion 25 | 26 | 27 | def _get_metric_name(metric): 28 | if metric in _metric_name: 29 | return _metric_name[metric] 30 | return metric 31 | 32 | 33 | __all__ = ["CenterIdentification", "SynthonCompletion", "Retrosynthesis"] -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/tasks/task.py: -------------------------------------------------------------------------------- 1 | from collections import Mapping, Sequence 2 | 3 | from torch import nn 4 | 5 | 6 | class Task(nn.Module): 7 | 8 | _option_members = set() 9 | 10 | def _standarize_option(self, x, name): 11 | if x is None: 12 | x = {} 13 | elif isinstance(x, str): 14 | x = {x: 1} 15 | elif isinstance(x, Sequence): 16 | x = dict.fromkeys(x, 1) 17 | elif not isinstance(x, Mapping): 18 | raise ValueError("Invalid value `%s` for option member `%s`" % (x, name)) 19 | return x 20 | 21 | def __setattr__(self, key, value): 22 | if key in self._option_members: 23 | value = self._standarize_option(value, key) 24 | super(Task, self).__setattr__(key, value) 25 | 26 | def predict_and_target(self, batch, all_loss=None, metric=None): 27 | return self.predict(batch, all_loss, metric), 
self.target(batch) 28 | 29 | def predict(self, batch): 30 | raise NotImplementedError 31 | 32 | def target(self, batch): 33 | raise NotImplementedError 34 | 35 | def evaluate(self, pred, target): 36 | raise NotImplementedError -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import TargetNormalize, RemapAtomType, RandomBFSOrder, Shuffle, VirtualNode, VirtualAtom, Compose 2 | 3 | __all__ = [ 4 | "TargetNormalize", "RemapAtomType", "RandomBFSOrder", "Shuffle", 5 | "VirtualNode", "VirtualAtom", "Compose", 6 | ] 7 | -------------------------------------------------------------------------------- /code/known_class/stage1/torchdrug/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/home/anaconda3/lib/python3.7/site-packages') 3 | from .io import input_choice, literal_eval, no_rdkit_log, capture_rdkit_log 4 | from .file import download, smart_open, extract, compute_md5, get_line_count 5 | from .torch import load_extension, cpu, cuda, detach, clone, mean, cat, stack, sparse_coo_tensor 6 | #from .decorator import cached_property, cached, deprecated_alias 7 | #from . import pretty, comm, doc, plot 8 | 9 | from .decorator import copy_args, cached_property, cached, deprecated_alias 10 | from . 
separator = ">" * 30
line = "-" * 30


def time(seconds):
    """
    Format time as a string.

    Parameters:
        seconds (float): time in seconds
    """
    # largest unit whose threshold is strictly exceeded wins
    for threshold, unit in ((24 * 60 * 60, "days"), (60 * 60, "hours"), (60, "mins")):
        if seconds > threshold:
            return "%.2f %s" % (seconds / threshold, unit)
    return "%.2f secs" % seconds


def long_array(array, truncation=10, display=3):
    """
    Format an array as a string.

    Parameters:
        array (array_like): array-like data
        truncation (int, optional): truncate array if its length exceeds this threshold
        display (int, optional): number of elements to display at the beginning and the end in truncated mode
    """
    if len(array) <= truncation:
        return "%s" % array
    # drop the closing bracket of the head and the opening bracket of the
    # tail, then splice the two halves around an ellipsis
    head = str(array[:display])[:-1]
    tail = str(array[-display:])[1:]
    return "%s, ..., %s" % (head, tail)
import pandas as pd
import glob
from src.utils import *
import argparse

# CLI: where the per-part result files live and how many samples were drawn.
parser = argparse.ArgumentParser()
parser.add_argument('--used_path', type=str, default='./sample-small/uspto_final_test/sampled_size/uspto_size_gnnbest_epoch=199/latent_epoch=294')
parser.add_argument('--n_samples', default=300, type=int)
args = parser.parse_args()

def main():
    """Concatenate per-worker result files and merge them with the
    ground-truth test table into ./merged_result.csv."""
    # Gather result.txt from every part_* directory produced by the
    # parallel vis_get_result.py workers and stack them into one CSV.
    files = glob.glob(args.used_path + '/part_*/result.txt')
    dataframes = [pd.read_csv(file) for file in files]
    merged_df = pd.concat(dataframes, ignore_index=True)
    merged_df.to_csv('result.csv', index=False)

    current_path = '.'
    # ground-truth table for the USPTO test split
    path_csv_file_test = "./dataset_save_test/uspto_final_test_table.csv"
    n_samples = args.n_samples

    result_path = current_path + '/result.csv'
    save_merged_result_path = current_path + '/merged_result.csv'
    uspto_final_test_table = pd.read_csv(path_csv_file_test)

    # merge_res comes from src.utils (star import); presumably joins the
    # sampled results with the ground-truth table -- TODO confirm contract
    merge_res(current_path, n_samples, result_path, save_merged_result_path, uspto_final_test_table)

if __name__ == "__main__":
    main()
import os
import shutil
import argparse

def split_folders(parent_folder, num_parts):
    """Distribute the sub-directories of ``parent_folder`` over ``num_parts``
    directories named ``part_1`` ... ``part_<num_parts>``.

    Each part receives ``len(subfolders) // num_parts`` folders; the last
    part additionally receives the remainder. Sub-directories are copied,
    not moved, so ``parent_folder`` is left untouched.

    NOTE(review): the parts are created in ``os.path.dirname(parent_folder)``.
    With a trailing '/' (as run_get_results.sh passes), dirname() returns
    parent_folder itself, so parts land *inside* it; without a slash they
    are created *beside* it. Callers rely on the trailing-slash form.
    """
    subfolders = [f for f in os.listdir(parent_folder) if os.path.isdir(os.path.join(parent_folder, f))]
    pp_folder = os.path.dirname(parent_folder)
    folders_per_part = len(subfolders) // num_parts

    for i in range(num_parts):
        part_folder = os.path.join(pp_folder, f'part_{i + 1}')
        os.makedirs(part_folder, exist_ok=True)

        start_index = i * folders_per_part
        # the last part takes everything that is left (remainder included)
        end_index = (i + 1) * folders_per_part if i < num_parts - 1 else None

        for folder_name in subfolders[start_index:end_index]:
            source_path = os.path.join(parent_folder, folder_name)
            destination_path = os.path.join(part_folder, folder_name)
            shutil.copytree(source_path, destination_path, dirs_exist_ok=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Get Result Statistics for Stage 1.')
    parser.add_argument('--sample_path', type=str, default="./sample/",
                        help='Path to sampled xyz dir')
    # fixed: help text previously duplicated --sample_path's description
    parser.add_argument('--num_parts', type=int, default=32,
                        help='Number of parts to split the sampled xyz dir into')
    args = parser.parse_args()
    parent_folder = args.sample_path
    num_parts = args.num_parts
    split_folders(parent_folder, num_parts)
from . import patch
from .data.constant import *

import sys
import logging

# Configure the root logger so torchdrug messages go to stdout with timestamps.
logger = logging.getLogger("")
logger.setLevel(logging.INFO)
# renamed from `format`: the old name shadowed the built-in format()
formatter = logging.Formatter("%(asctime)-10s %(message)s", "%H:%M:%S")

handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
logger.addHandler(handler)

__version__ = "0.2.1"
import os

from torchdrug import data, utils
from torchdrug.core import Registry as R


@R.register("datasets.BACE")
@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields"))
class BACE(data.MoleculeDataset):
    r"""
    Binary binding results for a set of inhibitors of human :math:`\beta`-secretase 1(BACE-1).

    Statistics:
        - #Molecule: 1,513
        - #Classification task: 1

    Parameters:
        path (str): path to store the dataset
        verbose (int, optional): output verbose level
        **kwargs
    """

    url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/bace.csv"
    md5 = "ba7f8fa3fdf463a811fa7edea8c982c2"
    # single binary classification target
    target_fields = ["Class"]

    def __init__(self, path, verbose=1, **kwargs):
        # expand "~" so user-relative paths work
        path = os.path.expanduser(path)
        if not os.path.exists(path):
            os.makedirs(path)
        self.path = path

        # cached download; md5 detects corrupt or partial files
        file_name = utils.download(self.url, path, md5=self.md5)

        # SMILES strings live in the "mol" column of this CSV
        self.load_csv(file_name, smiles_field="mol", target_fields=self.target_fields,
                      verbose=verbose, **kwargs)
import os

from torchdrug import data, utils
from torchdrug.core import Registry as R


@R.register("datasets.CEP")
@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields"))
class CEP(data.MoleculeDataset):
    """
    Photovoltaic efficiency estimated by Harvard clean energy project.

    Statistics:
        - #Molecule: 20,000
        - #Regression task: 1

    Parameters:
        path (str): path to store the dataset
        verbose (int, optional): output verbose level
        **kwargs
    """

    url = "https://raw.githubusercontent.com/HIPS/neural-fingerprint/master/data/2015-06-02-cep-pce/cep-processed.csv"
    md5 = "b6d257ff416917e4e6baa5e1103f3929"
    # single regression target: power conversion efficiency
    target_fields = ["PCE"]

    def __init__(self, path, verbose=1, **kwargs):
        # expand "~" so user-relative paths work
        path = os.path.expanduser(path)
        if not os.path.exists(path):
            os.makedirs(path)
        self.path = path

        # cached download; md5 detects corrupt or partial files
        file_name = utils.download(self.url, self.path, md5=self.md5)

        self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields,
                      verbose=verbose, **kwargs)
import os

from torchdrug import data, utils
from torchdrug.core import Registry as R


@R.register("datasets.CiteSeer")
class CiteSeer(data.NodeClassificationDataset):
    """
    A citation network of scientific publications with binary word features.

    Statistics:
        - #Node: 3,327
        - #Edge: 8,059
        - #Class: 6

    Parameters:
        path (str): path to store the dataset
        verbose (int, optional): output verbose level
    """

    url = "https://linqs-data.soe.ucsc.edu/public/lbc/citeseer.tgz"
    md5 = "c8ded8ed395b31899576bfd1e91e4d6e"

    def __init__(self, path, verbose=1):
        # expand "~" so user-relative paths work
        path = os.path.expanduser(path)
        if not os.path.exists(path):
            os.makedirs(path)
        self.path = path

        # one archive provides both node features (.content) and edges (.cites)
        zip_file = utils.download(self.url, path, md5=self.md5)
        node_file = utils.extract(zip_file, "citeseer/citeseer.content")
        edge_file = utils.extract(zip_file, "citeseer/citeseer.cites")

        self.load_tsv(node_file, edge_file, verbose=verbose)
import os

from torchdrug import data, utils
from torchdrug.core import Registry as R


@R.register("datasets.Cora")
class Cora(data.NodeClassificationDataset):
    """
    A citation network of scientific publications with binary word features.

    Statistics:
        - #Node: 2,708
        - #Edge: 5,429
        - #Class: 7

    Parameters:
        path (str): path to store the dataset
        verbose (int, optional): output verbose level
    """

    url = "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz"
    md5 = "2fc040bee8ce3d920e4204effd1e9214"

    def __init__(self, path, verbose=1):
        # expand "~" so user-relative paths work
        path = os.path.expanduser(path)
        if not os.path.exists(path):
            os.makedirs(path)
        self.path = path

        # one archive provides both node features (.content) and edges (.cites)
        zip_file = utils.download(self.url, path, md5=self.md5)
        node_file = utils.extract(zip_file, "cora/cora.content")
        edge_file = utils.extract(zip_file, "cora/cora.cites")

        self.load_tsv(node_file, edge_file, verbose=verbose)
import os

from torchdrug import data, utils
from torchdrug.core import Registry as R


@R.register("datasets.FreeSolv")
@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields"))
class FreeSolv(data.MoleculeDataset):
    """
    Experimental and calculated hydration free energy of small molecules in water.

    Statistics:
        - #Molecule: 642
        - #Regression task: 1

    Parameters:
        path (str): path to store the dataset
        verbose (int, optional): output verbose level
        **kwargs
    """

    url = "https://s3-us-west-1.amazonaws.com/deepchem.io/datasets/molnet_publish/FreeSolv.zip"
    md5 = "8d681babd239b15e2f8b2d29f025577a"
    # single regression target: experimental hydration free energy
    target_fields = ["expt"]

    def __init__(self, path, verbose=1, **kwargs):
        # expand "~" so user-relative paths work
        path = os.path.expanduser(path)
        if not os.path.exists(path):
            os.makedirs(path)
        self.path = path

        # cached download; the zip stores the table under the name SAMPL.csv
        zip_file = utils.download(self.url, self.path, md5=self.md5)
        csv_file = utils.extract(zip_file, "SAMPL.csv")

        self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields,
                      verbose=verbose, **kwargs)
import os

from torchdrug import data, utils
from torchdrug.core import Registry as R


@R.register("datasets.Lipophilicity")
@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields"))
class Lipophilicity(data.MoleculeDataset):
    """
    Experimental results of octanol/water distribution coefficient (logD at pH 7.4).

    Statistics:
        - #Molecule: 4,200
        - #Regression task: 1

    Parameters:
        path (str): path to store the dataset
        verbose (int, optional): output verbose level
        **kwargs
    """

    url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/Lipophilicity.csv"
    md5 = "85a0e1cb8b38b0dfc3f96ff47a57f0ab"
    # single regression target: measured logD
    target_fields = ["exp"]

    def __init__(self, path, verbose=1, **kwargs):
        # expand "~" so user-relative paths work
        path = os.path.expanduser(path)
        if not os.path.exists(path):
            os.makedirs(path)
        self.path = path

        # cached download; md5 detects corrupt or partial files
        file_name = utils.download(self.url, self.path, md5=self.md5)

        self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields,
                      verbose=verbose, **kwargs)
import os
from collections import defaultdict

from torch.utils import data as torch_data

from torchdrug import data, utils
from torchdrug.core import Registry as R


@R.register("datasets.MOSES")
@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields"))
class MOSES(data.MoleculeDataset):
    """
    Subset of ZINC database for molecule generation.
    This dataset doesn't contain any label information.

    Statistics:
        - #Molecule: 1,936,963

    Parameters:
        path (str): path for the CSV dataset
        verbose (int, optional): output verbose level
        **kwargs
    """

    url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv"
    md5 = "6bdb0d9526ddf5fdeb87d6aa541df213"
    # the SPLIT column holds the official train/valid/test assignment
    target_fields = ["SPLIT"]

    def __init__(self, path, verbose=1, **kwargs):
        # expand "~" so user-relative paths work
        path = os.path.expanduser(path)
        if not os.path.exists(path):
            os.makedirs(path)
        self.path = path

        # cached download; md5 detects corrupt or partial files
        file_name = utils.download(self.url, path, md5=self.md5)

        # lazy=True: molecules are parsed on first access, not all up front
        self.load_csv(file_name, smiles_field="SMILES", target_fields=self.target_fields,
                      lazy=True, verbose=verbose, **kwargs)

    def split(self):
        """Return (train, valid, test) subsets following the SPLIT column."""
        # group sample indices by their split label
        indexes = defaultdict(list)
        for i, split in enumerate(self.targets["SPLIT"]):
            indexes[split].append(i)
        train_set = torch_data.Subset(self, indexes["train"])
        valid_set = torch_data.Subset(self, indexes["valid"])
        test_set = torch_data.Subset(self, indexes["test"])
        return train_set, valid_set, test_set
import os

from torchdrug import data, utils
from torchdrug.core import Registry as R


@R.register("datasets.PCQM4M")
@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields"))
class PCQM4M(data.MoleculeDataset):
    """
    Quantum chemistry dataset originally curated under the PubChemQC of molecules.

    Statistics:
        - #Molecule: 3,803,453
        - #Regression task: 1

    Parameters:
        path (str): path to store the dataset
        verbose (int, optional): output verbose level
        **kwargs
    """

    url = "https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip"
    md5 = "5144ebaa7c67d24da1a2acbe41f57f6a"
    # single regression target: HOMO-LUMO energy gap
    target_fields = ["homolumogap"]

    def __init__(self, path, verbose=1, **kwargs):
        # expand "~" so user-relative paths work
        path = os.path.expanduser(path)
        if not os.path.exists(path):
            os.makedirs(path)
        self.path = path

        # the archive nests a gzipped CSV: unzip first, then gunzip
        zip_file = utils.download(self.url, self.path, md5=self.md5)
        zip_file = utils.extract(zip_file, "pcqm4m_kddcup2021/raw/data.csv.gz")
        file_name = utils.extract(zip_file)

        # lazy=True: with 3.8M molecules, parse on access rather than up front
        self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields,
                      lazy=True, verbose=verbose, **kwargs)
25 | 26 | target_fields = [] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | smiles_file = os.path.join(path, "CID-SMILES") 35 | 36 | with open(smiles_file, "r") as fin: 37 | reader = csv.reader(fin, delimiter="\t") 38 | if verbose: 39 | reader = iter(tqdm(reader, "Loading %s" % path, utils.get_line_count(smiles_file))) 40 | smiles_list = [] 41 | 42 | for values in reader: 43 | smiles = values[1] 44 | smiles_list.append(smiles) 45 | 46 | targets = {} 47 | self.load_smiles(smiles_list, targets, lazy=True, verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/sider.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.SIDER") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class SIDER(data.MoleculeDataset): 10 | """ 11 | Marketed drugs and adverse drug reactions (ADR) dataset, grouped into 27 system organ classes. 
12 | 13 | Statistics: 14 | - #Molecule: 1,427 15 | - #Classification task: 27 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/sider.csv.gz" 24 | md5 = "77c0ef421f7cc8ce963c5836c8761fd2" 25 | target_fields = None # pick all targets 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, path, md5=self.md5) 34 | csv_file = utils.extract(zip_file) 35 | 36 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/stability.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch.utils import data as torch_data 4 | 5 | from torchdrug import data, utils 6 | from torchdrug.core import Registry as R 7 | 8 | 9 | @R.register("datasets.Stability") 10 | @utils.copy_args(data.ProteinDataset.load_lmdbs, ignore=("target_fields",)) 11 | class Stability(data.ProteinDataset): 12 | """ 13 | The stability values of proteins under natural environment. 
14 | 15 | Statistics: 16 | - #Train: 53,571 17 | - #Valid: 2,512 18 | - #Test: 12,851 19 | 20 | Parameters: 21 | path (str): the path to store the dataset 22 | verbose (int, optional): output verbose level 23 | **kwargs 24 | """ 25 | 26 | url = "http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/stability.tar.gz" 27 | md5 = "aa1e06eb5a59e0ecdae581e9ea029675" 28 | splits = ["train", "valid", "test"] 29 | target_fields = ["stability_score"] 30 | 31 | def __init__(self, path, verbose=1, **kwargs): 32 | path = os.path.expanduser(path) 33 | if not os.path.exists(path): 34 | os.makedirs(path) 35 | self.path = path 36 | 37 | zip_file = utils.download(self.url, path, md5=self.md5) 38 | data_path = utils.extract(zip_file) 39 | lmdb_files = [os.path.join(data_path, "stability/stability_%s.lmdb" % split) 40 | for split in self.splits] 41 | 42 | self.load_lmdbs(lmdb_files, target_fields=self.target_fields, verbose=verbose, **kwargs) 43 | 44 | def split(self): 45 | offset = 0 46 | splits = [] 47 | for num_sample in self.num_samples: 48 | split = torch_data.Subset(self, range(offset, offset + num_sample)) 49 | splits.append(split) 50 | offset += num_sample 51 | return splits -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/tox21.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Tox21") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class Tox21(data.MoleculeDataset): 10 | """ 11 | Qualitative toxicity measurements on 12 biological targets, including nuclear receptors 12 | and stress response pathways. 
13 | 14 | Statistics: 15 | - #Molecule: 7,831 16 | - #Classification task: 12 17 | 18 | Parameters: 19 | path (str): path to store the dataset 20 | verbose (int, optional): output verbose level 21 | **kwargs 22 | """ 23 | 24 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/tox21.csv.gz" 25 | md5 = "2882d69e70bba0fec14995f26787cc25" 26 | target_fields = ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", 27 | "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"] 28 | 29 | def __init__(self, path, verbose=1, **kwargs): 30 | path = os.path.expanduser(path) 31 | if not os.path.exists(path): 32 | os.makedirs(path) 33 | self.path = path 34 | 35 | zip_file = utils.download(self.url, path, md5=self.md5) 36 | csv_file = utils.extract(zip_file) 37 | 38 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 39 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/toxcast.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ToxCast") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class ToxCast(data.MoleculeDataset): 10 | """ 11 | Toxicology data based on in vitro high-throughput screening. 
12 | 13 | Statistics: 14 | - #Molecule: 8,575 15 | - #Classification task: 617 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/toxcast_data.csv.gz" 24 | md5 = "92911bbf9c1e2ad85231014859388cd6" 25 | target_fields = None # pick all targets 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, path, md5=self.md5) 34 | csv_file = utils.extract(zip_file) 35 | 36 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/datasets/zinc250k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ZINC250k") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class ZINC250k(data.MoleculeDataset): 10 | """ 11 | Subset of ZINC compound database for virtual screening. 
12 | 13 | Statistics: 14 | - #Molecule: 498,910 15 | - #Regression task: 2 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/" \ 24 | "250k_rndm_zinc_drugs_clean_3.csv" 25 | md5 = "b59078b2b04c6e9431280e3dc42048d5" 26 | target_fields = ["logP", "qed"] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | file_name = utils.download(self.url, path, md5=self.md5) 35 | 36 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/layers/distribution.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections.abc import Sequence 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class IndependentGaussian(nn.Module): 9 | """ 10 | Independent Gaussian distribution. 11 | 12 | Parameters: 13 | mu (Tensor): mean of shape :math:`(N,)` 14 | sigma2 (Tensor): variance of shape :math:`(N,)` 15 | learnable (bool, optional): learnable parameters or not 16 | """ 17 | 18 | def __init__(self, mu, sigma2, learnable=False): 19 | super(IndependentGaussian, self).__init__() 20 | if learnable: 21 | self.mu = nn.Parameter(torch.as_tensor(mu)) 22 | self.sigma2 = nn.Parameter(torch.as_tensor(sigma2)) 23 | else: 24 | self.register_buffer("mu", torch.as_tensor(mu)) 25 | self.register_buffer("sigma2", torch.as_tensor(sigma2)) 26 | self.dim = len(mu) 27 | 28 | def forward(self, input): 29 | """ 30 | Compute the likelihood of input data. 
31 | 32 | Parameters: 33 | input (Tensor): input data of shape :math:`(..., N)` 34 | """ 35 | log_likelihood = -0.5 * (math.log(2 * math.pi) + self.sigma2.log() + (input - self.mu) ** 2 / self.sigma2) 36 | return log_likelihood 37 | 38 | def sample(self, *size): 39 | """ 40 | Draw samples from the distribution. 41 | 42 | Parameters: 43 | size (tuple of int): shape of the samples 44 | """ 45 | if len(size) == 1 and isinstance(size[0], Sequence): 46 | size = size[0] 47 | size = list(size) + [self.dim] 48 | 49 | sample = torch.randn(size, device=self.mu.device) * self.sigma2.sqrt() + self.mu 50 | return sample 51 | -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/layers/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .functional import multinomial, masked_mean, mean_with_nan, shifted_softplus, multi_slice, multi_slice_mask, \ 2 | as_mask, _extend, variadic_log_softmax, variadic_softmax, variadic_sum, variadic_mean, variadic_max, \ 3 | variadic_cross_entropy, variadic_sort, variadic_topk, variadic_arange, variadic_randperm, variadic_sample,\ 4 | variadic_meshgrid, variadic_to_padded, padded_to_variadic, one_hot, clipped_policy_gradient_objective, \ 5 | policy_gradient_objective 6 | from .embedding import transe_score, distmult_score, complex_score, simple_score, rotate_score 7 | from .spmm import generalized_spmm, generalized_rspmm 8 | 9 | __all__ = [ 10 | "multinomial", "masked_mean", "mean_with_nan", "shifted_softplus", "multi_slice_mask", "as_mask", 11 | "variadic_log_softmax", "variadic_softmax", "variadic_sum", "variadic_mean", "variadic_max", 12 | "variadic_cross_entropy", "variadic_sort", "variadic_topk", "variadic_arange", "variadic_randperm", 13 | "variadic_sample", "variadic_meshgrid", "variadic_to_padded", "padded_to_variadic", 14 | "one_hot", "clipped_policy_gradient_objective", 
"policy_gradient_objective", 15 | "transe_score", "distmult_score", "complex_score", "simple_score", "rotate_score", 16 | "generalized_spmm", "generalized_rspmm", 17 | ] -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/layers/functional/extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/layers/functional/extension/__init__.py -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/layers/functional/extension/util.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace at { 4 | 5 | const unsigned kFullMask = 0xFFFFFFFF; 6 | 7 | template 8 | __device__ scalar_t warp_reduce(scalar_t value) { 9 | #pragma unroll 10 | for (int delta = 1; delta < warpSize; delta *= 2) 11 | #if __CUDACC_VER_MAJOR__ >= 9 12 | value += __shfl_down_sync(kFullMask, value, delta); 13 | #else 14 | value += __shfl_down(value, delta); 15 | #endif 16 | return value; 17 | } 18 | 19 | template 20 | __device__ scalar_t warp_broadcast(scalar_t value, int lane_id) { 21 | #if __CUDACC_VER_MAJOR__ >= 9 22 | return __shfl_sync(kFullMask, value, lane_id); 23 | #else 24 | return __shfl(value, lane_id); 25 | #endif 26 | } 27 | 28 | } // namespace at -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/layers/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph import GraphConstruction, SpatialLineGraph 2 | from .function import BondEdge, KNNEdge, SpatialEdge, SequentialEdge, AlphaCarbonNode, \ 
3 | IdentityNode, RandomEdgeMask, SubsequenceNode, SubspaceNode 4 | 5 | __all__ = [ 6 | "GraphConstruction", "SpatialLineGraph", 7 | "BondEdge", "KNNEdge", "SpatialEdge", "SequentialEdge", "AlphaCarbonNode", 8 | "IdentityNode", "RandomEdgeMask", "SubsequenceNode", "SubspaceNode" 9 | ] -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metric import area_under_roc, area_under_prc, r2, QED, logP, penalized_logP, SA, chemical_validity, \ 2 | accuracy, variadic_accuracy, matthews_corrcoef, pearsonr, spearmanr, \ 3 | variadic_area_under_prc, variadic_area_under_roc, variadic_top_precision, f1_max 4 | 5 | # alias 6 | AUROC = area_under_roc 7 | AUPRC = area_under_prc 8 | 9 | __all__ = [ 10 | "area_under_roc", "area_under_prc", "r2", "QED", "logP", "penalized_logP", "SA", "chemical_validity", 11 | "accuracy", "variadic_accuracy", "matthews_corrcoef", "pearsonr", "spearmanr", 12 | "variadic_area_under_prc", "variadic_area_under_roc", "variadic_top_precision", "f1_max", 13 | "AUROC", "AUPRC", 14 | ] -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/metrics/rdkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/metrics/rdkit/__init__.py -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/tasks/task.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping, Sequence 2 | 3 | from torch import nn 4 | 5 | 6 | class 
Task(nn.Module): 7 | 8 | _option_members = set() 9 | 10 | def _standarize_option(self, x, name): 11 | if x is None: 12 | x = {} 13 | elif isinstance(x, str): 14 | x = {x: 1} 15 | elif isinstance(x, Sequence): 16 | x = dict.fromkeys(x, 1) 17 | elif not isinstance(x, Mapping): 18 | raise ValueError("Invalid value `%s` for option member `%s`" % (x, name)) 19 | return x 20 | 21 | def __setattr__(self, key, value): 22 | if key in self._option_members: 23 | value = self._standarize_option(value, key) 24 | super(Task, self).__setattr__(key, value) 25 | 26 | def preprocess(self, train_set, valid_set, test_set): 27 | pass 28 | 29 | def predict_and_target(self, batch, all_loss=None, metric=None): 30 | return self.predict(batch, all_loss, metric), self.target(batch) 31 | 32 | def predict(self, batch, all_loss=None, metric=None): 33 | raise NotImplementedError 34 | 35 | def target(self, batch): 36 | raise NotImplementedError 37 | 38 | def evaluate(self, pred, target): 39 | raise NotImplementedError -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import NormalizeTarget, RemapAtomType, RandomBFSOrder, Shuffle, VirtualNode, \ 2 | VirtualAtom, TruncateProtein, ProteinView, Compose 3 | 4 | __all__ = [ 5 | "NormalizeTarget", "RemapAtomType", "RandomBFSOrder", "Shuffle", 6 | "VirtualNode", "VirtualAtom", "TruncateProtein", "ProteinView", "Compose", 7 | ] 8 | -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .io import input_choice, literal_eval, no_rdkit_log, capture_rdkit_log 2 | from .file import download, smart_open, extract, compute_md5, get_line_count 3 | 
from .torch import load_extension, cpu, cuda, detach, clone, mean, cat, stack, sparse_coo_tensor 4 | from .decorator import copy_args, cached_property, cached, deprecated_alias 5 | from . import pretty, comm, plot 6 | 7 | __all__ = [ 8 | "input_choice", "literal_eval", "no_rdkit_log", "capture_rdkit_log", 9 | "download", "smart_open", "extract", "compute_md5", "get_line_count", 10 | "load_extension", "cpu", "cuda", "detach", "clone", "mean", "cat", "stack", "sparse_coo_tensor", 11 | "copy_args", "cached_property", "cached", "deprecated_alias", 12 | "pretty", "comm", "plot", 13 | ] -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/utils/extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/utils/extension/__init__.py -------------------------------------------------------------------------------- /code/unknown_class/data_process/generate_SDF/torchdrug/torchdrug/utils/extension/torch_ext.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace at { 4 | 5 | Tensor sparse_coo_tensor_unsafe(const Tensor &indices, const Tensor &values, IntArrayRef size) { 6 | return _sparse_coo_tensor_unsafe(indices, values, size, values.options().layout(kSparse)); 7 | } 8 | 9 | } 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("sparse_coo_tensor_unsafe", &at::sparse_coo_tensor_unsafe, 13 | "Construct sparse COO tensor without index check"); 14 | } -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/cmp_data/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | 
-------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/molecule-datasets/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/__init__.py: -------------------------------------------------------------------------------- 1 | from . import patch 2 | from .data.constant import * 3 | 4 | import sys 5 | import logging 6 | 7 | logger = logging.getLogger("") 8 | logger.setLevel(logging.INFO) 9 | format = logging.Formatter("%(asctime)-10s %(message)s", "%H:%M:%S") 10 | 11 | handler = logging.StreamHandler(sys.stdout) 12 | handler.setFormatter(format) 13 | logger.addHandler(handler) 14 | 15 | __version__ = "0.2.1" 16 | -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import _MetaContainer, Registry, Configurable, make_configurable 2 | from .engine import Engine 3 | from .meter import Meter 4 | from .logger import LoggerBase, LoggingLogger, WandbLogger 5 | 6 | __all__ = [ 7 | "_MetaContainer", "Registry", "Configurable", 8 | "Engine", "Meter", "LoggerBase", "LoggingLogger", "WandbLogger", 9 | ] 10 | -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dictionary import PerfectHash, Dictionary 2 | from .graph import Graph, PackedGraph, cat 3 | from .molecule import Molecule, PackedMolecule 4 | from .protein import Protein, PackedProtein 5 | from .dataset import 
MoleculeDataset, ReactionDataset, ProteinDataset, \ 6 | ProteinPairDataset, ProteinLigandDataset, \ 7 | NodeClassificationDataset, KnowledgeGraphDataset, SemiSupervised, \ 8 | semisupervised, key_split, scaffold_split, ordered_scaffold_split 9 | from .dataloader import DataLoader, graph_collate 10 | from . import constant 11 | from . import feature 12 | 13 | __all__ = [ 14 | "Graph", "PackedGraph", "Molecule", "PackedMolecule", "Protein", "PackedProtein", "PerfectHash", "Dictionary", 15 | "MoleculeDataset", "ReactionDataset", "NodeClassificationDataset", "KnowledgeGraphDataset", "SemiSupervised", 16 | "ProteinDataset", "ProteinPairDataset", "ProteinLigandDataset", 17 | "semisupervised", "key_split", "scaffold_split", "ordered_scaffold_split", 18 | "DataLoader", "graph_collate", "feature", "constant", 19 | ] 20 | -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/data/rdkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/data/rdkit/__init__.py -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/bace.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.BACE") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class BACE(data.MoleculeDataset): 10 | r""" 11 | Binary binding results for a set of inhibitors of human :math:`\beta`-secretase 1(BACE-1). 
12 | 13 | Statistics: 14 | - #Molecule: 1,513 15 | - #Classification task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/bace.csv" 24 | md5 = "ba7f8fa3fdf463a811fa7edea8c982c2" 25 | target_fields = ["Class"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="mol", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/bbbp.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.BBBP") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class BBBP(data.MoleculeDataset): 10 | """ 11 | Binary labels of blood-brain barrier penetration. 
12 | 13 | Statistics: 14 | - #Molecule: 2,039 15 | - #Classification task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/BBBP.csv" 24 | md5 = "66286cb9e6b148bd75d80c870df580fb" 25 | target_fields = ["p_np"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/cep.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.CEP") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class CEP(data.MoleculeDataset): 10 | """ 11 | Photovoltaic efficiency estimated by Havard clean energy project. 
12 | 13 | Statistics: 14 | - #Molecule: 20,000 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://raw.githubusercontent.com/HIPS/neural-fingerprint/master/data/2015-06-02-cep-pce/cep-processed.csv" 24 | md5 = "b6d257ff416917e4e6baa5e1103f3929" 25 | target_fields = ["PCE"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, self.path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) 37 | -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/chembl_filtered.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ChEMBLFiltered") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class ChEMBLFiltered(data.MoleculeDataset): 10 | """ 11 | Statistics: 12 | - #Molecule: 430,710 13 | - #Regression task: 1,310 14 | 15 | Parameters: 16 | path (str): path to store the dataset 17 | verbose (int, optional): output verbose level 18 | **kwargs 19 | """ 20 | 21 | url = "https://zenodo.org/record/5528681/files/chembl_filtered_torchdrug.csv.gz" 22 | md5 = "2fff04fecd6e697f28ebb127e8a37561" 23 | 24 | def __init__(self, path, verbose=1, **kwargs): 25 | path = os.path.expanduser(path) 26 | if not os.path.exists(path): 27 | os.makedirs(path) 28 | self.path = path 29 | 30 | zip_file = utils.download(self.url, path, md5=self.md5) 31 | csv_file = utils.extract(zip_file) 32 | 33 | 
self.target_fields = ["target_{}".format(i) for i in range(1310)] 34 | 35 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/citeseer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.CiteSeer") 8 | class CiteSeer(data.NodeClassificationDataset): 9 | """ 10 | A citation network of scientific publications with binary word features. 11 | 12 | Statistics: 13 | - #Node: 3,327 14 | - #Edge: 8,059 15 | - #Class: 6 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | """ 21 | 22 | url = "https://linqs-data.soe.ucsc.edu/public/lbc/citeseer.tgz" 23 | md5 = "c8ded8ed395b31899576bfd1e91e4d6e" 24 | 25 | def __init__(self, path, verbose=1): 26 | path = os.path.expanduser(path) 27 | if not os.path.exists(path): 28 | os.makedirs(path) 29 | self.path = path 30 | 31 | zip_file = utils.download(self.url, path, md5=self.md5) 32 | node_file = utils.extract(zip_file, "citeseer/citeseer.content") 33 | edge_file = utils.extract(zip_file, "citeseer/citeseer.cites") 34 | 35 | self.load_tsv(node_file, edge_file, verbose=verbose) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/clintox.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ClinTox") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class 
ClinTox(data.MoleculeDataset): 10 | """ 11 | Qualitative data of drugs approved by the FDA and those that have failed clinical 12 | trials for toxicity reasons. 13 | 14 | Statistics: 15 | - #Molecule: 1,478 16 | - #Classification task: 2 17 | 18 | Parameters: 19 | path (str): path to store the dataset 20 | verbose (int, optional): output verbose level 21 | **kwargs 22 | """ 23 | 24 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/clintox.csv.gz" 25 | md5 = "db4f2df08be8ae92814e9d6a2d015284" 26 | target_fields = ["FDA_APPROVED", "CT_TOX"] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | zip_file = utils.download(self.url, path, md5=self.md5) 35 | csv_file = utils.extract(zip_file) 36 | 37 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 38 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/cora.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Cora") 8 | class Cora(data.NodeClassificationDataset): 9 | """ 10 | A citation network of scientific publications with binary word features. 
11 | 12 | Statistics: 13 | - #Node: 2,708 14 | - #Edge: 5,429 15 | - #Class: 7 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | """ 21 | 22 | url = "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz" 23 | md5 = "2fc040bee8ce3d920e4204effd1e9214" 24 | 25 | def __init__(self, path, verbose=1): 26 | path = os.path.expanduser(path) 27 | if not os.path.exists(path): 28 | os.makedirs(path) 29 | self.path = path 30 | 31 | zip_file = utils.download(self.url, path, md5=self.md5) 32 | node_file = utils.extract(zip_file, "cora/cora.content") 33 | edge_file = utils.extract(zip_file, "cora/cora.cites") 34 | 35 | self.load_tsv(node_file, edge_file, verbose=verbose) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/delaney.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Delaney") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class Delaney(data.MoleculeDataset): 10 | """ 11 | Log-scale water solubility of molecules. 
12 | 13 | Statistics: 14 | - #Molecule: 1,128 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/delaney-processed.csv" 24 | md5 = "0c90a51668d446b9e3ab77e67662bd1c" 25 | target_fields = ["measured log solubility in mols per litre"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, self.path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/freesolv.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.FreeSolv") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class FreeSolv(data.MoleculeDataset): 10 | """ 11 | Experimental and calculated hydration free energy of small molecules in water. 
12 | 13 | Statistics: 14 | - #Molecule: 642 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://s3-us-west-1.amazonaws.com/deepchem.io/datasets/molnet_publish/FreeSolv.zip" 24 | md5 = "8d681babd239b15e2f8b2d29f025577a" 25 | target_fields = ["expt"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, self.path, md5=self.md5) 34 | csv_file = utils.extract(zip_file, "SAMPL.csv") 35 | 36 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/hiv.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.HIV") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class HIV(data.MoleculeDataset): 10 | """ 11 | Experimentally measured abilities to inhibit HIV replication. 
12 | 13 | Statistics: 14 | - #Molecule: 41,127 15 | - #Classification task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/HIV.csv" 24 | md5 = "9ad10c88f82f1dac7eb5c52b668c30a7" 25 | target_fields = ["HIV_active"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/lipophilicity.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Lipophilicity") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class Lipophilicity(data.MoleculeDataset): 10 | """ 11 | Experimental results of octanol/water distribution coefficient (logD at pH 7.4). 
12 | 13 | Statistics: 14 | - #Molecule: 4,200 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/Lipophilicity.csv" 24 | md5 = "85a0e1cb8b38b0dfc3f96ff47a57f0ab" 25 | target_fields = ["exp"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | file_name = utils.download(self.url, self.path, md5=self.md5) 34 | 35 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 36 | verbose=verbose, **kwargs) 37 | -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/malaria.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Malaria") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class Malaria(data.MoleculeDataset): 10 | """ 11 | Half-maximal effective concentration (EC50) against a parasite that causes malaria. 
12 | 13 | Statistics: 14 | - #Molecule: 10,000 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://raw.githubusercontent.com/HIPS/neural-fingerprint/master/data/2015-06-03-malaria/" \ 24 | "malaria-processed.csv" 25 | md5 = "ef40ddfd164be0e5ed1bd3dd0cce9b88" 26 | target_fields = ["activity"] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | file_name = utils.download(self.url, self.path, md5=self.md5) 35 | 36 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/moses.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | 4 | from torch.utils import data as torch_data 5 | 6 | from torchdrug import data, utils 7 | from torchdrug.core import Registry as R 8 | 9 | 10 | @R.register("datasets.MOSES") 11 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 12 | class MOSES(data.MoleculeDataset): 13 | """ 14 | Subset of ZINC database for molecule generation. 15 | This dataset doesn't contain any label information. 
16 | 17 | Statistics: 18 | - #Molecule: 1,936,963 19 | 20 | Parameters: 21 | path (str): path for the CSV dataset 22 | verbose (int, optional): output verbose level 23 | **kwargs 24 | """ 25 | 26 | url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv" 27 | md5 = "6bdb0d9526ddf5fdeb87d6aa541df213" 28 | target_fields = ["SPLIT"] 29 | 30 | def __init__(self, path, verbose=1, **kwargs): 31 | path = os.path.expanduser(path) 32 | if not os.path.exists(path): 33 | os.makedirs(path) 34 | self.path = path 35 | 36 | file_name = utils.download(self.url, path, md5=self.md5) 37 | 38 | self.load_csv(file_name, smiles_field="SMILES", target_fields=self.target_fields, 39 | lazy=True, verbose=verbose, **kwargs) 40 | 41 | def split(self): 42 | indexes = defaultdict(list) 43 | for i, split in enumerate(self.targets["SPLIT"]): 44 | indexes[split].append(i) 45 | train_set = torch_data.Subset(self, indexes["train"]) 46 | valid_set = torch_data.Subset(self, indexes["valid"]) 47 | test_set = torch_data.Subset(self, indexes["test"]) 48 | return train_set, valid_set, test_set 49 | -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/muv.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.MUV") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class MUV(data.MoleculeDataset): 10 | """ 11 | Subset of PubChem BioAssay by applying a refined nearest neighbor analysis. 
12 | 13 | Statistics: 14 | - #Molecule: 93,087 15 | - #Classification task: 17 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/muv.csv.gz" 24 | md5 = "9c40bd41310991efd40f4d4868fa3ddf" 25 | target_fields = ["MUV-466", "MUV-548", "MUV-600", "MUV-644", "MUV-652", "MUV-689", "MUV-692", "MUV-712", "MUV-713", 26 | "MUV-733", "MUV-737", "MUV-810", "MUV-832", "MUV-846", "MUV-852", "MUV-858", "MUV-859"] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | zip_file = utils.download(self.url, path, md5=self.md5) 35 | csv_file = utils.extract(zip_file) 36 | 37 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 38 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/pcqm4m.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.PCQM4M") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class PCQM4M(data.MoleculeDataset): 10 | """ 11 | Quantum chemistry dataset originally curated under the PubChemQC of molecules. 
12 | 13 | Statistics: 14 | - #Molecule: 3,803,453 15 | - #Regression task: 1 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip" 24 | md5 = "5144ebaa7c67d24da1a2acbe41f57f6a" 25 | target_fields = ["homolumogap"] 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, self.path, md5=self.md5) 34 | zip_file = utils.extract(zip_file, "pcqm4m_kddcup2021/raw/data.csv.gz") 35 | file_name = utils.extract(zip_file) 36 | 37 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 38 | lazy=True, verbose=verbose, **kwargs) 39 | -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/pubchem110m.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | from tqdm import tqdm 4 | 5 | from torchdrug import data, utils 6 | from torchdrug.core import Registry as R 7 | 8 | 9 | @R.register("datasets.PubChem110m") 10 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 11 | class PubChem110m(data.MoleculeDataset): 12 | """ 13 | PubChem. 14 | This dataset doesn't contain any label information. 15 | 16 | Statistics: 17 | - #Molecule: 18 | 19 | Parameters: 20 | path (str): 21 | verbose (int, optional): output verbose level 22 | **kwargs 23 | """ 24 | # TODO: download path & md5. Is it the statistics right? 
25 | 26 | target_fields = [] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | smiles_file = os.path.join(path, "CID-SMILES") 35 | 36 | with open(smiles_file, "r") as fin: 37 | reader = csv.reader(fin, delimiter="\t") 38 | if verbose: 39 | reader = iter(tqdm(reader, "Loading %s" % path, utils.get_line_count(smiles_file))) 40 | smiles_list = [] 41 | 42 | for values in reader: 43 | smiles = values[1] 44 | smiles_list.append(smiles) 45 | 46 | targets = {} 47 | self.load_smiles(smiles_list, targets, lazy=True, verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/sider.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.SIDER") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class SIDER(data.MoleculeDataset): 10 | """ 11 | Marketed drugs and adverse drug reactions (ADR) dataset, grouped into 27 system organ classes. 
12 | 13 | Statistics: 14 | - #Molecule: 1,427 15 | - #Classification task: 27 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/sider.csv.gz" 24 | md5 = "77c0ef421f7cc8ce963c5836c8761fd2" 25 | target_fields = None # pick all targets 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, path, md5=self.md5) 34 | csv_file = utils.extract(zip_file) 35 | 36 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/tox21.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.Tox21") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class Tox21(data.MoleculeDataset): 10 | """ 11 | Qualitative toxicity measurements on 12 biological targets, including nuclear receptors 12 | and stress response pathways. 
13 | 14 | Statistics: 15 | - #Molecule: 7,831 16 | - #Classification task: 12 17 | 18 | Parameters: 19 | path (str): path to store the dataset 20 | verbose (int, optional): output verbose level 21 | **kwargs 22 | """ 23 | 24 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/tox21.csv.gz" 25 | md5 = "2882d69e70bba0fec14995f26787cc25" 26 | target_fields = ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", 27 | "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"] 28 | 29 | def __init__(self, path, verbose=1, **kwargs): 30 | path = os.path.expanduser(path) 31 | if not os.path.exists(path): 32 | os.makedirs(path) 33 | self.path = path 34 | 35 | zip_file = utils.download(self.url, path, md5=self.md5) 36 | csv_file = utils.extract(zip_file) 37 | 38 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 39 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/toxcast.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ToxCast") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class ToxCast(data.MoleculeDataset): 10 | """ 11 | Toxicology data based on in vitro high-throughput screening. 
12 | 13 | Statistics: 14 | - #Molecule: 8,575 15 | - #Classification task: 617 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/toxcast_data.csv.gz" 24 | md5 = "92911bbf9c1e2ad85231014859388cd6" 25 | target_fields = None # pick all targets 26 | 27 | def __init__(self, path, verbose=1, **kwargs): 28 | path = os.path.expanduser(path) 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | self.path = path 32 | 33 | zip_file = utils.download(self.url, path, md5=self.md5) 34 | csv_file = utils.extract(zip_file) 35 | 36 | self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/datasets/zinc250k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchdrug import data, utils 4 | from torchdrug.core import Registry as R 5 | 6 | 7 | @R.register("datasets.ZINC250k") 8 | @utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) 9 | class ZINC250k(data.MoleculeDataset): 10 | """ 11 | Subset of ZINC compound database for virtual screening. 
12 | 13 | Statistics: 14 | - #Molecule: 498,910 15 | - #Regression task: 2 16 | 17 | Parameters: 18 | path (str): path to store the dataset 19 | verbose (int, optional): output verbose level 20 | **kwargs 21 | """ 22 | 23 | url = "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/" \ 24 | "250k_rndm_zinc_drugs_clean_3.csv" 25 | md5 = "b59078b2b04c6e9431280e3dc42048d5" 26 | target_fields = ["logP", "qed"] 27 | 28 | def __init__(self, path, verbose=1, **kwargs): 29 | path = os.path.expanduser(path) 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | self.path = path 33 | 34 | file_name = utils.download(self.url, path, md5=self.md5) 35 | 36 | self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, 37 | verbose=verbose, **kwargs) -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/layers/distribution.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections.abc import Sequence 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class IndependentGaussian(nn.Module): 9 | """ 10 | Independent Gaussian distribution. 11 | 12 | Parameters: 13 | mu (Tensor): mean of shape :math:`(N,)` 14 | sigma2 (Tensor): variance of shape :math:`(N,)` 15 | learnable (bool, optional): learnable parameters or not 16 | """ 17 | 18 | def __init__(self, mu, sigma2, learnable=False): 19 | super(IndependentGaussian, self).__init__() 20 | if learnable: 21 | self.mu = nn.Parameter(torch.as_tensor(mu)) 22 | self.sigma2 = nn.Parameter(torch.as_tensor(sigma2)) 23 | else: 24 | self.register_buffer("mu", torch.as_tensor(mu)) 25 | self.register_buffer("sigma2", torch.as_tensor(sigma2)) 26 | self.dim = len(mu) 27 | 28 | def forward(self, input): 29 | """ 30 | Compute the likelihood of input data. 
31 | 32 | Parameters: 33 | input (Tensor): input data of shape :math:`(..., N)` 34 | """ 35 | log_likelihood = -0.5 * (math.log(2 * math.pi) + self.sigma2.log() + (input - self.mu) ** 2 / self.sigma2) 36 | return log_likelihood 37 | 38 | def sample(self, *size): 39 | """ 40 | Draw samples from the distribution. 41 | 42 | Parameters: 43 | size (tuple of int): shape of the samples 44 | """ 45 | if len(size) == 1 and isinstance(size[0], Sequence): 46 | size = size[0] 47 | size = list(size) + [self.dim] 48 | 49 | sample = torch.randn(size, device=self.mu.device) * self.sigma2.sqrt() + self.mu 50 | return sample 51 | -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/layers/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .functional import multinomial, masked_mean, mean_with_nan, shifted_softplus, multi_slice, multi_slice_mask, \ 2 | as_mask, _extend, variadic_log_softmax, variadic_softmax, variadic_sum, variadic_mean, variadic_max, \ 3 | variadic_cross_entropy, variadic_sort, variadic_topk, variadic_arange, variadic_randperm, variadic_sample,\ 4 | variadic_meshgrid, variadic_to_padded, padded_to_variadic, one_hot, clipped_policy_gradient_objective, \ 5 | policy_gradient_objective 6 | from .embedding import transe_score, distmult_score, complex_score, simple_score, rotate_score 7 | from .spmm import generalized_spmm, generalized_rspmm 8 | 9 | __all__ = [ 10 | "multinomial", "masked_mean", "mean_with_nan", "shifted_softplus", "multi_slice_mask", "as_mask", 11 | "variadic_log_softmax", "variadic_softmax", "variadic_sum", "variadic_mean", "variadic_max", 12 | "variadic_cross_entropy", "variadic_sort", "variadic_topk", "variadic_arange", "variadic_randperm", 13 | "variadic_sample", "variadic_meshgrid", "variadic_to_padded", "padded_to_variadic", 14 | "one_hot", "clipped_policy_gradient_objective", 
"policy_gradient_objective", 15 | "transe_score", "distmult_score", "complex_score", "simple_score", "rotate_score", 16 | "generalized_spmm", "generalized_rspmm", 17 | ] -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/layers/functional/extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/layers/functional/extension/__init__.py -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/layers/functional/extension/util.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace at { 4 | 5 | const unsigned kFullMask = 0xFFFFFFFF; 6 | 7 | template 8 | __device__ scalar_t warp_reduce(scalar_t value) { 9 | #pragma unroll 10 | for (int delta = 1; delta < warpSize; delta *= 2) 11 | #if __CUDACC_VER_MAJOR__ >= 9 12 | value += __shfl_down_sync(kFullMask, value, delta); 13 | #else 14 | value += __shfl_down(value, delta); 15 | #endif 16 | return value; 17 | } 18 | 19 | template 20 | __device__ scalar_t warp_broadcast(scalar_t value, int lane_id) { 21 | #if __CUDACC_VER_MAJOR__ >= 9 22 | return __shfl_sync(kFullMask, value, lane_id); 23 | #else 24 | return __shfl(value, lane_id); 25 | #endif 26 | } 27 | 28 | } // namespace at -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/layers/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph import GraphConstruction, SpatialLineGraph 2 | from .function import BondEdge, KNNEdge, SpatialEdge, 
SequentialEdge, AlphaCarbonNode, \ 3 | IdentityNode, RandomEdgeMask, SubsequenceNode, SubspaceNode 4 | 5 | __all__ = [ 6 | "GraphConstruction", "SpatialLineGraph", 7 | "BondEdge", "KNNEdge", "SpatialEdge", "SequentialEdge", "AlphaCarbonNode", 8 | "IdentityNode", "RandomEdgeMask", "SubsequenceNode", "SubspaceNode" 9 | ] -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metric import area_under_roc, area_under_prc, r2, QED, logP, penalized_logP, SA, chemical_validity, \ 2 | accuracy, variadic_accuracy, matthews_corrcoef, pearsonr, spearmanr, \ 3 | variadic_area_under_prc, variadic_area_under_roc, variadic_top_precision, f1_max 4 | 5 | # alias 6 | AUROC = area_under_roc 7 | AUPRC = area_under_prc 8 | 9 | __all__ = [ 10 | "area_under_roc", "area_under_prc", "r2", "QED", "logP", "penalized_logP", "SA", "chemical_validity", 11 | "accuracy", "variadic_accuracy", "matthews_corrcoef", "pearsonr", "spearmanr", 12 | "variadic_area_under_prc", "variadic_area_under_roc", "variadic_top_precision", "f1_max", 13 | "AUROC", "AUPRC", 14 | ] -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/metrics/rdkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/metrics/rdkit/__init__.py -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/tasks/task.py: -------------------------------------------------------------------------------- 1 | from collections.abc import 
Mapping, Sequence 2 | 3 | from torch import nn 4 | 5 | 6 | class Task(nn.Module): 7 | 8 | _option_members = set() 9 | 10 | def _standarize_option(self, x, name): 11 | if x is None: 12 | x = {} 13 | elif isinstance(x, str): 14 | x = {x: 1} 15 | elif isinstance(x, Sequence): 16 | x = dict.fromkeys(x, 1) 17 | elif not isinstance(x, Mapping): 18 | raise ValueError("Invalid value `%s` for option member `%s`" % (x, name)) 19 | return x 20 | 21 | def __setattr__(self, key, value): 22 | if key in self._option_members: 23 | value = self._standarize_option(value, key) 24 | super(Task, self).__setattr__(key, value) 25 | 26 | def preprocess(self, train_set, valid_set, test_set): 27 | pass 28 | 29 | def predict_and_target(self, batch, all_loss=None, metric=None): 30 | return self.predict(batch, all_loss, metric), self.target(batch) 31 | 32 | def predict(self, batch, all_loss=None, metric=None): 33 | raise NotImplementedError 34 | 35 | def target(self, batch): 36 | raise NotImplementedError 37 | 38 | def evaluate(self, pred, target): 39 | raise NotImplementedError -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import NormalizeTarget, RemapAtomType, RandomBFSOrder, Shuffle, VirtualNode, \ 2 | VirtualAtom, TruncateProtein, ProteinView, Compose 3 | 4 | __all__ = [ 5 | "NormalizeTarget", "RemapAtomType", "RandomBFSOrder", "Shuffle", 6 | "VirtualNode", "VirtualAtom", "TruncateProtein", "ProteinView", "Compose", 7 | ] 8 | -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .io import input_choice, literal_eval, no_rdkit_log, capture_rdkit_log 2 | 
from .file import download, smart_open, extract, compute_md5, get_line_count 3 | from .torch import load_extension, cpu, cuda, detach, clone, mean, cat, stack, sparse_coo_tensor 4 | from .decorator import copy_args, cached_property, cached, deprecated_alias 5 | from . import pretty, comm, plot 6 | 7 | __all__ = [ 8 | "input_choice", "literal_eval", "no_rdkit_log", "capture_rdkit_log", 9 | "download", "smart_open", "extract", "compute_md5", "get_line_count", 10 | "load_extension", "cpu", "cuda", "detach", "clone", "mean", "cat", "stack", "sparse_coo_tensor", 11 | "copy_args", "cached_property", "cached", "deprecated_alias", 12 | "pretty", "comm", "plot", 13 | ] -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/utils/extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/utils/extension/__init__.py -------------------------------------------------------------------------------- /code/unknown_class/data_process/get_dataset_1st_stage/torchdrug/torchdrug/utils/extension/torch_ext.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace at { 4 | 5 | Tensor sparse_coo_tensor_unsafe(const Tensor &indices, const Tensor &values, IntArrayRef size) { 6 | return _sparse_coo_tensor_unsafe(indices, values, size, values.options().layout(kSparse)); 7 | } 8 | 9 | } 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("sparse_coo_tensor_unsafe", &at::sparse_coo_tensor_unsafe, 13 | "Construct sparse COO tensor without index check"); 14 | } -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/__init__.py: 
-------------------------------------------------------------------------------- 1 | from . import patch 2 | from .data.constant import * 3 | 4 | import sys 5 | import logging 6 | 7 | logger = logging.getLogger("") 8 | logger.setLevel(logging.INFO) 9 | format = logging.Formatter("%(asctime)-10s %(message)s", "%H:%M:%S") 10 | 11 | handler = logging.StreamHandler(sys.stdout) 12 | handler.setFormatter(format) 13 | logger.addHandler(handler) 14 | 15 | __version__ = "0.1.3" -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import _MetaContainer, Registry, Configurable, make_configurable 2 | from .engine import Engine 3 | from .meter import Meter 4 | from .logger import LoggerBase, LoggingLogger, WandbLogger 5 | 6 | __all__ = [ 7 | "_MetaContainer", "Registry", "Configurable", 8 | "Engine", "Meter", "LoggerBase", "LoggingLogger", "WandbLogger", 9 | ] -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dictionary import PerfectHash, Dictionary 2 | from .graph import Graph, PackedGraph, cat 3 | from .molecule import Molecule, PackedMolecule 4 | from .dataset import MoleculeDataset, ReactionDataset, \ 5 | key_split, scaffold_split, ordered_scaffold_split 6 | from .dataloader import DataLoader, graph_collate 7 | from . import constant 8 | from . 
import feature 9 | 10 | __all__ = [ 11 | "Graph", "PackedGraph", "Molecule", "PackedMolecule", "PerfectHash", "Dictionary", 12 | "MoleculeDataset", "ReactionDataset", "key_split", "scaffold_split", "ordered_scaffold_split", 13 | "DataLoader", "graph_collate", "feature", "constant", 14 | ] 15 | -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/data/rdkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/unknown_class/stage1/torchdrug/data/rdkit/__init__.py -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .uspto50k import USPTO50k 2 | __all__ = ["USPTO50k"] 3 | -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/layers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .common import MultiLayerPerceptron 3 | from .conv import GraphConv, RelationalGraphConv 4 | from .readout import MeanReadout, SumReadout, MaxReadout, Softmax 5 | 6 | MLP = MultiLayerPerceptron 7 | GCNConv = GraphConv 8 | RGCNConv = RelationalGraphConv 9 | 10 | __all__ = [ 11 | "MultiLayerPerceptron", 12 | "GraphConv", "RelationalGraphConv", 13 | "MeanReadout", "SumReadout", "MaxReadout", "Softmax", 14 | "MLP", "GCNConv", "RGCNConv" 15 | ] -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/layers/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .functional import multinomial, masked_mean, mean_with_nan, shifted_softplus, multi_slice, multi_slice_mask, \ 2 | as_mask, 
_size_to_index, _extend, variadic_log_softmax, variadic_softmax, variadic_sum, variadic_mean, \ 3 | variadic_max, variadic_cross_entropy, variadic_sort, variadic_topk, variadic_arange, variadic_randperm, \ 4 | variadic_sample, variadic_meshgrid, variadic_to_padded, padded_to_variadic, one_hot, \ 5 | clipped_policy_gradient_objective, policy_gradient_objective, variadic_sample_distribution 6 | from .embedding import transe_score, distmult_score, complex_score, simple_score, rotate_score 7 | from .spmm import generalized_spmm, generalized_rspmm 8 | 9 | __all__ = [ 10 | "multinomial", "masked_mean", "mean_with_nan", "shifted_softplus", "multi_slice_mask", "as_mask", 11 | "variadic_log_softmax", "variadic_softmax", "variadic_sum", "variadic_mean", "variadic_max", 12 | "variadic_cross_entropy", "variadic_sort", "variadic_topk", "variadic_arange", "variadic_randperm", 13 | "variadic_sample", "variadic_meshgrid", "variadic_to_padded", "padded_to_variadic", 14 | "one_hot", "clipped_policy_gradient_objective", "policy_gradient_objective", 15 | "transe_score", "distmult_score", "complex_score", "simple_score", "rotate_score", 16 | "generalized_spmm", "generalized_rspmm", 17 | ] -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/layers/functional/extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/unknown_class/stage1/torchdrug/layers/functional/extension/__init__.py -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/layers/functional/extension/util.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace at { 4 | 5 | const unsigned kFullMask = 0xFFFFFFFF; 6 | 7 | template 8 | __device__ scalar_t warp_reduce(scalar_t value) { 9 | 
#pragma unroll 10 | for (int delta = 1; delta < warpSize; delta *= 2) 11 | #if __CUDACC_VER_MAJOR__ >= 9 12 | value += __shfl_down_sync(kFullMask, value, delta); 13 | #else 14 | value += __shfl_down(value, delta); 15 | #endif 16 | return value; 17 | } 18 | 19 | template 20 | __device__ scalar_t warp_broadcast(scalar_t value, int lane_id) { 21 | #if __CUDACC_VER_MAJOR__ >= 9 22 | return __shfl_sync(kFullMask, value, lane_id); 23 | #else 24 | return __shfl(value, lane_id); 25 | #endif 26 | } 27 | 28 | } // namespace at -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metric import accuracy, variadic_accuracy 2 | 3 | # alias 4 | 5 | 6 | __all__ = [ 7 | "accuracy", "variadic_accuracy" 8 | ] 9 | -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/metrics/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from torch_scatter import scatter_add, scatter_mean, scatter_max 4 | import networkx as nx 5 | from rdkit import Chem 6 | from rdkit.Chem import Descriptors 7 | 8 | from torchdrug import utils 9 | from torchdrug.layers import functional 10 | from torchdrug.core import Registry as R 11 | from torchdrug.metrics.rdkit import sascorer 12 | 13 | 14 | 15 | @R.register("metrics.accuracy") 16 | def accuracy(pred, target): 17 | """ 18 | Compute classification accuracy over sets with equal size. 19 | 20 | Suppose there are :math:`N` sets and :math:`C` categories. 
21 | 22 | Parameters: 23 | pred (Tensor): prediction of shape :math:`(N, C)` 24 | target (Tensor): target of shape :math:`(N,)` 25 | """ 26 | return (pred.argmax(dim=-1) == target).float().mean() 27 | 28 | 29 | 30 | @R.register("metrics.variadic_accuracy") 31 | def variadic_accuracy(input, target, size): 32 | """ 33 | Compute classification accuracy over variadic sizes of categories. 34 | 35 | Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math:`B`. 36 | 37 | Parameters: 38 | input (Tensor): prediction of shape :math:`(B,)` 39 | target (Tensor): target of shape :math:`(N,)`. Each target is a relative index in a sample. 40 | size (Tensor): number of categories of shape :math:`(N,)` 41 | """ 42 | index2graph = functional._size_to_index(size) 43 | 44 | input_class = scatter_max(input, index2graph)[1] 45 | target_index = target + size.cumsum(0) - size 46 | accuracy = (input_class == target_index).float() 47 | return accuracy 48 | -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/metrics/rdkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/unknown_class/stage1/torchdrug/metrics/rdkit/__init__.py -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gcn import GraphConvolutionalNetwork, RelationalGraphConvolutionalNetwork 2 | 3 | # alias 4 | GCN = GraphConvolutionalNetwork 5 | RGCN = RelationalGraphConvolutionalNetwork 6 | 7 | __all__ = [ 8 | "GraphConvolutionalNetwork", "RelationalGraphConvolutionalNetwork", 9 | "GCN", "RGCN" 10 | ] -------------------------------------------------------------------------------- 
/code/unknown_class/stage1/torchdrug/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .task import Task 2 | from .retrosynthesis import CenterIdentification, SynthonCompletion, Retrosynthesis 3 | 4 | 5 | _criterion_name = { 6 | "mse": "mean squared error", 7 | "mae": "mean absolute error", 8 | "bce": "binary cross entropy", 9 | "ce": "cross entropy", 10 | } 11 | 12 | _metric_name = { 13 | "mae": "mean absolute error", 14 | "mse": "mean squared error", 15 | "rmse": "root mean squared error", 16 | "acc": "accuracy", 17 | "mcc": "matthews correlation coefficient", 18 | } 19 | 20 | 21 | def _get_criterion_name(criterion): 22 | if criterion in _criterion_name: 23 | return _criterion_name[criterion] 24 | return "%s loss" % criterion 25 | 26 | 27 | def _get_metric_name(metric): 28 | if metric in _metric_name: 29 | return _metric_name[metric] 30 | return metric 31 | 32 | 33 | __all__ = ["CenterIdentification", "SynthonCompletion", "Retrosynthesis"] -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/tasks/task.py: -------------------------------------------------------------------------------- 1 | from collections import Mapping, Sequence 2 | 3 | from torch import nn 4 | 5 | 6 | class Task(nn.Module): 7 | 8 | _option_members = set() 9 | 10 | def _standarize_option(self, x, name): 11 | if x is None: 12 | x = {} 13 | elif isinstance(x, str): 14 | x = {x: 1} 15 | elif isinstance(x, Sequence): 16 | x = dict.fromkeys(x, 1) 17 | elif not isinstance(x, Mapping): 18 | raise ValueError("Invalid value `%s` for option member `%s`" % (x, name)) 19 | return x 20 | 21 | def __setattr__(self, key, value): 22 | if key in self._option_members: 23 | value = self._standarize_option(value, key) 24 | super(Task, self).__setattr__(key, value) 25 | 26 | def predict_and_target(self, batch, all_loss=None, metric=None): 27 | return self.predict(batch, all_loss, metric), 
self.target(batch) 28 | 29 | def predict(self, batch): 30 | raise NotImplementedError 31 | 32 | def target(self, batch): 33 | raise NotImplementedError 34 | 35 | def evaluate(self, pred, target): 36 | raise NotImplementedError -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import TargetNormalize, RemapAtomType, RandomBFSOrder, Shuffle, VirtualNode, VirtualAtom, Compose 2 | 3 | __all__ = [ 4 | "TargetNormalize", "RemapAtomType", "RandomBFSOrder", "Shuffle", 5 | "VirtualNode", "VirtualAtom", "Compose", 6 | ] 7 | -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./anaconda3/lib/python3.7/site-packages') 3 | from .io import input_choice, literal_eval, no_rdkit_log, capture_rdkit_log 4 | from .file import download, smart_open, extract, compute_md5, get_line_count 5 | from .torch import load_extension, cpu, cuda, detach, clone, mean, cat, stack, sparse_coo_tensor 6 | #from .decorator import cached_property, cached, deprecated_alias 7 | #from . import pretty, comm, doc, plot 8 | 9 | from .decorator import copy_args, cached_property, cached, deprecated_alias 10 | from . 
import pretty, comm, plot 11 | 12 | __all__ = [ 13 | "input_choice", "literal_eval", "no_rdkit_log", "capture_rdkit_log", 14 | "download", "smart_open", "extract", "compute_md5", "get_line_count", 15 | "load_extension", "cpu", "cuda", "detach", "clone", "mean", "cat", "stack", "sparse_coo_tensor", 16 | #"cached_property", "cached", "deprecated_alias", 17 | #"pretty", "comm", "doc", "plot", 18 | 19 | "copy_args", "cached_property", "cached", "deprecated_alias", 20 | "pretty", "comm", "plot" 21 | ] -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/utils/extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/unknown_class/stage1/torchdrug/utils/extension/__init__.py -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/utils/extension/torch_ext.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace at { 4 | 5 | Tensor sparse_coo_tensor_unsafe(const Tensor &indices, const Tensor &values, IntArrayRef size) { 6 | return _sparse_coo_tensor_unsafe(indices, values, size, values.options().layout(kSparse)); 7 | } 8 | 9 | } 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("sparse_coo_tensor_unsafe", &at::sparse_coo_tensor_unsafe, 13 | "Construct sparse COO tensor without index check"); 14 | } -------------------------------------------------------------------------------- /code/unknown_class/stage1/torchdrug/utils/pretty.py: -------------------------------------------------------------------------------- 1 | separator = ">" * 30 2 | line = "-" * 30 3 | 4 | def time(seconds): 5 | """ 6 | Format time as a string. 
7 | 8 | Parameters: 9 | seconds (float): time in seconds 10 | """ 11 | sec_per_min = 60 12 | sec_per_hour = 60 * 60 13 | sec_per_day = 24 * 60 * 60 14 | 15 | if seconds > sec_per_day: 16 | return "%.2f days" % (seconds / sec_per_day) 17 | elif seconds > sec_per_hour: 18 | return "%.2f hours" % (seconds / sec_per_hour) 19 | elif seconds > sec_per_min: 20 | return "%.2f mins" % (seconds / sec_per_min) 21 | else: 22 | return "%.2f secs" % seconds 23 | 24 | 25 | def long_array(array, truncation=10, display=3): 26 | """ 27 | Format an array as a string. 28 | 29 | Parameters: 30 | array (array_like): array-like data 31 | truncation (int, optional): truncate array if its length exceeds this threshold 32 | display (int, optional): number of elements to display at the beginning and the end in truncated mode 33 | """ 34 | if len(array) <= truncation: 35 | return "%s" % array 36 | return "%s, ..., %s" % (str(array[:display])[:-1], str(array[-display:])[1:]) -------------------------------------------------------------------------------- /code/unknown_class/stage2/configs/uspto_gdiffretro.yml: -------------------------------------------------------------------------------- 1 | exp_name: 'latent' 2 | data: './dataset' 3 | train_data_prefix: uspto_final_train 4 | val_data_prefix: uspto_final_eval 5 | #test_data_prefix: uspto_final_test 6 | checkpoints: models 7 | logs: logs 8 | device: gpu 9 | torch_device: 'cuda:1' 10 | log_iterations: null 11 | wandb_entity: null 12 | wandb_mode: 'offline' 13 | enable_progress_bar: True 14 | num_worker: 16 15 | model: egnn_dynamics 16 | lr: 2.0e-4 17 | batch_size: 64 18 | n_layers: 8 19 | n_epochs: 3000 20 | test_epochs: 10000 21 | n_stability_samples: 10 #idk what is that 22 | nf: 128 23 | activation: silu 24 | attention: False 25 | condition_time: True 26 | tanh: False 27 | norm_constant: 0.000001 28 | inv_sublayers: 2 29 | include_charges: False 30 | diffusion_loss_type: l2 31 | data_augmentation: False 32 | center_of_mass: fragments 33 | 
remove_anchors_context: False 34 | sin_embedding: False 35 | normalization_factor: 100 36 | normalize_factors: [1, 4, 10] 37 | aggregation_method: 'sum' 38 | normalization: batch_norm 39 | inpainting: False -------------------------------------------------------------------------------- /code/unknown_class/stage2/configs/uspto_sample.yml: -------------------------------------------------------------------------------- 1 | checkpoint: ./models/latent/latent_epoch=2999.ckpt 2 | samples: sample 3 | data: ./dataset_save_test 4 | prefix: uspto_final_test 5 | n_samples: 300 6 | device: cuda:6 7 | -------------------------------------------------------------------------------- /code/unknown_class/stage2/configs/uspto_size.yml: -------------------------------------------------------------------------------- 1 | logs: logs 2 | checkpoints: models 3 | data: './dataset' 4 | device: cuda:0 5 | normalization: batch_norm 6 | train_data_prefix: uspto_final_train 7 | val_data_prefix: uspto_final_eval 8 | test_data_prefix: uspto_final_test 9 | -------------------------------------------------------------------------------- /code/unknown_class/stage2/merge_result.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import glob 3 | from src.utils import * 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--used_path', type=str, default='./sample-small/uspto_final_test/sampled_size/uspto_size_gnnbest_epoch=199/latent_epoch=294') 8 | parser.add_argument('--n_samples', default=300, type=int) 9 | args = parser.parse_args() 10 | 11 | def main(): 12 | files = glob.glob(args.used_path + '/part_*/result.txt') 13 | dataframes = [pd.read_csv(file) for file in files] 14 | merged_df = pd.concat(dataframes, ignore_index=True) 15 | merged_df.to_csv('result.csv', index=False) 16 | 17 | current_path = '.' 
18 | path_csv_file_test = "./dataset_save_test/uspto_final_test_table.csv" 19 | n_samples = args.n_samples 20 | 21 | result_path = current_path + '/result.csv' 22 | save_merged_result_path = current_path + '/merged_result.csv' 23 | uspto_final_test_table = pd.read_csv(path_csv_file_test) 24 | 25 | merge_res(current_path, n_samples, result_path, save_merged_result_path, uspto_final_test_table) 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /code/unknown_class/stage2/run_get_results.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function handle_sigint() { 4 | echo "SIGINT received, killing all subprocesses..." 5 | pkill -P $$ 6 | } 7 | 8 | trap 'handle_sigint' SIGINT 9 | 10 | 11 | gnn_size_version=600 12 | diffusion_version=2999 13 | n_samples=300 14 | num_parts=32 15 | sample_path="./sample/uspto_final_test/sampled_size/uspto_size_gnnbest_epoch=${gnn_size_version}/latent_epoch=${diffusion_version}/" # Note that the last slash is neccessary cause we use os.path.dirname in xyz_split 16 | 17 | python sample.py \ 18 | --linker_size_model "./models/uspto_size_gnn/uspto_size_gnnbest_epoch=${gnn_size_version}.ckpt" \ 19 | --n_samples ${n_samples} \ 20 | --sample_seed 0 \ 21 | --n_steps 100 22 | 23 | python xyz_split.py --sample_path $sample_path --num_parts $num_parts 24 | 25 | for idx in $(seq 1 $num_parts) 26 | do 27 | python vis_get_result.py \ 28 | --used_path "${sample_path}part_${idx}" \ 29 | --n_samples ${n_samples} & 30 | done 31 | 32 | wait 33 | 34 | python merge_result.py --used_path ${sample_path} --n_samples ${n_samples} -------------------------------------------------------------------------------- /code/unknown_class/stage2/src/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/code/unknown_class/stage2/src/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /code/unknown_class/stage2/xyz_split.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | 5 | def split_folders(parent_folder, num_parts): 6 | subfolders = [f for f in os.listdir(parent_folder) if os.path.isdir(os.path.join(parent_folder, f))] 7 | pp_folder = os.path.dirname(parent_folder) 8 | folders_per_part = len(subfolders) // num_parts 9 | 10 | for i in range(num_parts): 11 | part_folder = os.path.join(pp_folder, f'part_{i + 1}') 12 | os.makedirs(part_folder, exist_ok=True) 13 | 14 | start_index = i * folders_per_part 15 | end_index = (i + 1) * folders_per_part if i < num_parts - 1 else None 16 | 17 | for folder_name in subfolders[start_index:end_index]: 18 | source_path = os.path.join(parent_folder, folder_name) 19 | destination_path = os.path.join(part_folder, folder_name) 20 | shutil.copytree(source_path, destination_path, dirs_exist_ok=True) 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser(description='Get Result Statistics for Stage 1.') 25 | parser.add_argument('--sample_path', type=str, default="./sample/", 26 | help='Path to sampled xyz dir') 27 | parser.add_argument('--num_parts', type=int, default=32, 28 | help='Path to sampled xyz dir') 29 | args = parser.parse_args() 30 | parent_folder = args.sample_path 31 | num_parts = args.num_parts 32 | split_folders(parent_folder, num_parts) -------------------------------------------------------------------------------- /fig/fig1_framework.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/fig/fig1_framework.pdf 
-------------------------------------------------------------------------------- /fig/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunshy-1/GDiffRetro/e72767853d05b12d6e2cb1e187e2c1b97c2a9239/fig/framework.png --------------------------------------------------------------------------------