├── .gitignore
├── .readthedocs.yml
├── CITATION.cff
├── LICENSE
├── MANIFEST.in
├── README.md
├── configs
├── _base_
│ ├── curators
│ │ ├── domains
│ │ │ ├── assay.py
│ │ │ ├── protein.py
│ │ │ ├── protein_family.py
│ │ │ ├── scaffold.py
│ │ │ └── size.py
│ │ ├── lbap_defaults.py
│ │ ├── noises
│ │ │ ├── core.py
│ │ │ ├── general.py
│ │ │ └── refined.py
│ │ └── sbap_defaults.py
│ ├── default_runtime.py
│ ├── models
│ │ └── resnet50.py
│ └── schedules
│ │ ├── classification.py
│ │ └── regression.py
├── algorithms
│ ├── augmentation
│ │ ├── lbap_core_ec50_assay_dann.py
│ │ ├── lbap_core_ec50_assay_mixup.py
│ │ ├── sbap_core_ec50_assay_dann.py
│ │ └── sbap_core_ec50_assay_mixup.py
│ ├── coral
│ │ ├── lbap_core_ec50_assay_coral.py
│ │ └── sbap_core_ec50_assay_coral.py
│ ├── erm
│ │ ├── lbap_core_ec50_assay_erm.py
│ │ └── sbap_core_ec50_assay_erm.py
│ ├── groupdro
│ │ ├── lbap_core_ec50_assay_groupdro.py
│ │ └── sbap_core_ec50_assay_groupdro.py
│ └── irm
│ │ ├── lbap_core_ec50_assay_irm.py
│ │ └── sbap_core_ec50_assay_irm.py
└── curators
│ ├── lbap_core_ec50_assay.py
│ ├── lbap_core_ec50_scaffold.py
│ ├── lbap_core_ec50_size.py
│ ├── lbap_core_ic50_assay.py
│ ├── lbap_core_ic50_scaffold.py
│ ├── lbap_core_ic50_size.py
│ ├── lbap_core_ki_assay.py
│ ├── lbap_core_ki_scaffold.py
│ ├── lbap_core_ki_size.py
│ ├── lbap_core_potency_assay.py
│ ├── lbap_core_potency_scaffold.py
│ ├── lbap_core_potency_size.py
│ ├── lbap_general_ec50_assay.py
│ ├── lbap_general_ec50_scaffold.py
│ ├── lbap_general_ec50_size.py
│ ├── lbap_general_ic50_assay.py
│ ├── lbap_general_ic50_scaffold.py
│ ├── lbap_general_ic50_size.py
│ ├── lbap_general_ki_assay.py
│ ├── lbap_general_ki_scaffold.py
│ ├── lbap_general_ki_size.py
│ ├── lbap_general_potency_assay.py
│ ├── lbap_general_potency_scaffold.py
│ ├── lbap_general_potency_size.py
│ ├── lbap_refined_ec50_assay.py
│ ├── lbap_refined_ec50_scaffold.py
│ ├── lbap_refined_ec50_size.py
│ ├── lbap_refined_ic50_assay.py
│ ├── lbap_refined_ic50_scaffold.py
│ ├── lbap_refined_ic50_size.py
│ ├── lbap_refined_ki_assay.py
│ ├── lbap_refined_ki_scaffold.py
│ ├── lbap_refined_ki_size.py
│ ├── lbap_refined_potency_assay.py
│ ├── lbap_refined_potency_scaffold.py
│ ├── lbap_refined_potency_size.py
│ ├── sbap_core_ec50_assay.py
│ ├── sbap_core_ec50_protein.py
│ ├── sbap_core_ec50_protein_family.py
│ ├── sbap_core_ec50_scaffold.py
│ ├── sbap_core_ec50_size.py
│ ├── sbap_core_ic50_assay.py
│ ├── sbap_core_ic50_protein.py
│ ├── sbap_core_ic50_protein_family.py
│ ├── sbap_core_ic50_scaffold.py
│ ├── sbap_core_ic50_size.py
│ ├── sbap_core_ki_assay.py
│ ├── sbap_core_ki_protein.py
│ ├── sbap_core_ki_protein_family.py
│ ├── sbap_core_ki_scaffold.py
│ ├── sbap_core_ki_size.py
│ ├── sbap_core_potency_assay.py
│ ├── sbap_core_potency_protein.py
│ ├── sbap_core_potency_protein_family.py
│ ├── sbap_core_potency_scaffold.py
│ ├── sbap_core_potency_size.py
│ ├── sbap_general_ec50_assay.py
│ ├── sbap_general_ec50_protein.py
│ ├── sbap_general_ec50_protein_family.py
│ ├── sbap_general_ec50_scaffold.py
│ ├── sbap_general_ec50_size.py
│ ├── sbap_general_ic50_assay.py
│ ├── sbap_general_ic50_protein.py
│ ├── sbap_general_ic50_protein_family.py
│ ├── sbap_general_ic50_scaffold.py
│ ├── sbap_general_ic50_size.py
│ ├── sbap_general_ki_assay.py
│ ├── sbap_general_ki_protein.py
│ ├── sbap_general_ki_protein_family.py
│ ├── sbap_general_ki_scaffold.py
│ ├── sbap_general_ki_size.py
│ ├── sbap_general_potency_assay.py
│ ├── sbap_general_potency_protein.py
│ ├── sbap_general_potency_protein_family.py
│ ├── sbap_general_potency_scaffold.py
│ ├── sbap_general_potency_size.py
│ ├── sbap_refined_ec50_assay.py
│ ├── sbap_refined_ec50_protein.py
│ ├── sbap_refined_ec50_protein_family.py
│ ├── sbap_refined_ec50_scaffold.py
│ ├── sbap_refined_ec50_size.py
│ ├── sbap_refined_ic50_assay.py
│ ├── sbap_refined_ic50_protein.py
│ ├── sbap_refined_ic50_protein_family.py
│ ├── sbap_refined_ic50_scaffold.py
│ ├── sbap_refined_ic50_size.py
│ ├── sbap_refined_ki_assay.py
│ ├── sbap_refined_ki_protein.py
│ ├── sbap_refined_ki_protein_family.py
│ ├── sbap_refined_ki_scaffold.py
│ ├── sbap_refined_ki_size.py
│ ├── sbap_refined_potency_assay.py
│ ├── sbap_refined_potency_protein.py
│ ├── sbap_refined_potency_protein_family.py
│ ├── sbap_refined_potency_scaffold.py
│ └── sbap_refined_potency_size.py
├── demo
├── INFO
├── demo.ipynb
└── sources
│ ├── curator.png
│ ├── drugood_dataset.png
│ └── overview.png
├── drugood.yaml
├── drugood
├── __init__.py
├── apis
│ ├── __init__.py
│ ├── curate.py
│ ├── inference.py
│ ├── test.py
│ └── train.py
├── core
│ ├── __init__.py
│ ├── evaluation
│ │ ├── __init__.py
│ │ ├── eval_hooks.py
│ │ ├── eval_metrics.py
│ │ ├── mean_average_precision.py
│ │ └── multilabel_eval_metrics.py
│ ├── fp16
│ │ ├── __init__.py
│ │ ├── decorators.py
│ │ ├── hooks.py
│ │ └── utils.py
│ ├── hooks
│ │ ├── __init__.py
│ │ └── irm_optimizer_hook.py
│ ├── runner
│ │ └── __init__.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── data_collect.py
│ │ ├── dist_utils.py
│ │ └── misc.py
│ └── visualization
│ │ ├── __init__.py
│ │ └── molecule.py
├── curators
│ ├── __init__.py
│ ├── chembl
│ │ ├── __init__.py
│ │ ├── filter.py
│ │ ├── protein_family.py
│ │ └── sql_exe.py
│ ├── curator.py
│ └── get_domain_info.py
├── datasets
│ ├── __init__.py
│ ├── base_dataset.py
│ ├── builder.py
│ ├── dataset_wrappers.py
│ ├── drugood_dataset.py
│ ├── grouper.py
│ ├── multi_label.py
│ ├── pipelines
│ │ ├── __init__.py
│ │ ├── compose.py
│ │ └── formating.py
│ ├── samplers
│ │ ├── __init__.py
│ │ └── distributed_sampler.py
│ └── utils.py
├── models
│ ├── __init__.py
│ ├── algorithms
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── builder.py
│ │ ├── coral.py
│ │ ├── dann.py
│ │ ├── erm.py
│ │ ├── groupdro.py
│ │ ├── irm.py
│ │ └── mixup.py
│ ├── backbones
│ │ ├── __init__.py
│ │ ├── attentivefp.py
│ │ ├── base_backbone.py
│ │ ├── bert.py
│ │ ├── gat.py
│ │ ├── gcn.py
│ │ ├── gin.py
│ │ ├── gta.py
│ │ ├── mgcn.py
│ │ ├── nf.py
│ │ ├── resnet.py
│ │ ├── schnet.py
│ │ └── weave.py
│ ├── builder.py
│ ├── heads
│ │ ├── __init__.py
│ │ ├── base_head.py
│ │ ├── cls_head.py
│ │ ├── linear_head.py
│ │ ├── multi_label_head.py
│ │ ├── multi_label_linear_head.py
│ │ └── reg_head.py
│ ├── losses
│ │ ├── __init__.py
│ │ ├── accuracy.py
│ │ ├── cross_entropy_loss.py
│ │ ├── error.py
│ │ ├── focal_loss.py
│ │ ├── label_smooth_loss.py
│ │ ├── mean_squared_error_loss.py
│ │ └── utils.py
│ ├── necks
│ │ ├── __init__.py
│ │ ├── cat.py
│ │ └── gap.py
│ ├── taskers
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── classifier.py
│ │ └── regressor.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── augment
│ │ ├── __init__.py
│ │ ├── augments.py
│ │ ├── builder.py
│ │ ├── cutmix.py
│ │ ├── identity.py
│ │ └── mixup.py
│ │ └── helpers.py
├── utils
│ ├── __init__.py
│ ├── collect_env.py
│ ├── comm.py
│ ├── logger.py
│ └── smile_to_dgl.py
└── version.py
├── requirements.txt
├── requirements
├── docs.txt
├── mminstall.txt
├── optional.txt
├── readthedocs.txt
├── runtime.txt
└── tests.txt
├── setup.cfg
├── setup.py
└── tools
├── __init__.py
├── analysis_tools
├── analyze_logs.py
├── analyze_results.py ├── eval_metric.py ├── get_flops.py └── parse_logs.py ├── curate.py ├── dist_test.sh ├── dist_train.sh ├── misc └── print_config.py ├── slurm_test.sh ├── slurm_train.sh ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | **/*.pyc 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # custom 108 | .vscode 109 | .idea 110 | *.pkl 111 | *.pkl.json 112 | *.log.json 113 | /work_dirs 114 | /mmcls/.mim 115 | 116 | # Pytorch 117 | *.pth 118 | 119 | .DS_Store 120 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | formats: all 4 | 5 | python: 6 | version: 3.7 7 | install: 8 | - requirements: requirements/docs.txt 9 | - requirements: requirements/readthedocs.txt 10 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | title: "DrugOOD: Out-of-Distribution (OOD) Dataset Curator and Benchmark for AI-aided Drug Discovery--A Focus on Affinity Prediction Problems with Noise Annotations" 4 | authors: 5 | - name: "Yuanfeng Ji, Lu Zhang, Jiaxiang Wu, Bingzhe Wu, Long-Kai Huang, Tingyang Xu, Yu Rong, Lanqing Li, Jie Ren, Ding Xue, Houtim Lai, Shaoyong Xu, Jing Feng, Wei Liu, Ping Luo, Shuigeng Zhou, Junzhou Huang, Peilin Zhao, Yatao Bian" 6 | version: 0.0.1 7 | date-released: 2020-07-09 8 | repository-code: "" 9 | license: Apache-2.0 10 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include mmcls/.mim/model-index.yml 2 | recursive-include mmcls/.mim/configs *.py *.yml 3 | recursive-include mmcls/.mim/tools *.py *.sh 4 | -------------------------------------------------------------------------------- /configs/_base_/curators/domains/assay.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | 3 | # domain split 4 | domain = dict( 5 | domain_generate_field="assay_id", 6 | domain_name="assay", 7 | sort_func="domain_capacity", 8 | sort_order='descend', 9 | protein_family_level=1 10 | ) 11 | -------------------------------------------------------------------------------- /configs/_base_/curators/domains/protein.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | 3 | # domain split 4 | domain = dict( 5 | domain_generate_field="protein", 6 | domain_name="protein", 7 | sort_func="domain_capacity", 8 | sort_order='descend', 9 | protein_family_level=1 10 | ) 11 | -------------------------------------------------------------------------------- /configs/_base_/curators/domains/protein_family.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | 3 | # domain split 4 | domain = dict( 5 | domain_generate_field="protein", 6 | domain_name="protein_family", 7 | sort_func="domain_capacity", 8 | sort_order='descend', 9 | protein_family_level=1 10 | ) 11 | -------------------------------------------------------------------------------- /configs/_base_/curators/domains/scaffold.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | 3 | # domain split 4 | domain = dict( 5 | domain_generate_field="smiles", 6 | domain_name="scaffold", 7 | sort_func="scaffold_size", 8 | sort_order='descend', 9 | protein_family_level=1 10 | ) 11 | -------------------------------------------------------------------------------- /configs/_base_/curators/domains/size.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 
2 | 3 | # domain split 4 | domain = dict( 5 | domain_generate_field="smiles", 6 | domain_name="size", 7 | sort_func="domain_value", 8 | sort_order='descend', 9 | protein_family_level=1 10 | ) 11 | -------------------------------------------------------------------------------- /configs/_base_/curators/lbap_defaults.py: -------------------------------------------------------------------------------- 1 | # data path 2 | path = dict( 3 | task=dict(type="lbap", subset="lbap_core_ec50_assay"), 4 | source_root="/apdcephfs/share_1364275/xluzhang/chembl_29_sqlite/chembl_29.db", 5 | target_root="data/" 6 | ) 7 | # filter 8 | # noise_filter = dict( 9 | # assay=dict( 10 | # measurement_type=["EC50"], 11 | # assay_value_units=["nM", "uM"], 12 | # molecules_number=[50, 3000], 13 | # confidence_score=9), 14 | # sample=dict( 15 | # filter_none=[], 16 | # smile_exist=[], 17 | # smile_legal=[], 18 | # value_relation=["=", "~"]) 19 | # ) 20 | # uncertainty 21 | uncertainty = dict(delta={'<': -1, '<=': -1, '>': 1, '>=': 1}) 22 | 23 | # adaptive cls label 24 | classification_threshold = dict( 25 | lower_bound=4, 26 | upper_bound=6, 27 | fix_value=5 28 | ) 29 | 30 | # train/val/test 31 | fractions = dict( 32 | train_fraction_ood=0.6, 33 | val_fraction_ood=0.2, 34 | iid_train_sample_fractions=0.6, 35 | iid_val_sample_fractions=0.2 36 | ) 37 | -------------------------------------------------------------------------------- /configs/_base_/curators/noises/core.py: -------------------------------------------------------------------------------- 1 | noise_filter = dict( 2 | assay=dict( 3 | measurement_type=["EC50"], 4 | assay_value_units=["nM", "uM"], 5 | molecules_number=[50, 3000], 6 | confidence_score=9), 7 | sample=dict( 8 | filter_none=[], 9 | smile_exist=[], 10 | smile_legal=[], 11 | value_relation=["=", "~"]) 12 | ) 13 | -------------------------------------------------------------------------------- /configs/_base_/curators/noises/general.py: -------------------------------------------------------------------------------- 1 | noise_filter = dict( 2 | assay=dict( 3 | measurement_type=["EC50"], 4 | assay_value_units=["nM", "uM"], 5 | molecules_number=[10, 5000], 6 | confidence_score=None), 7 | sample=dict( 8 | filter_none=[], 9 | smile_exist=[], 10 | smile_legal=[], 11 | value_relation=["=", ">=", "<=", "~", "<", ">"]) 12 | ) 13 | -------------------------------------------------------------------------------- /configs/_base_/curators/noises/refined.py: -------------------------------------------------------------------------------- 1 | noise_filter = dict( 2 | assay=dict( 3 | measurement_type=["EC50"], 4 | assay_value_units=["nM", "uM"], 5 | molecules_number=[32, 5000], 6 | confidence_score=3), 7 | sample=dict( 8 | filter_none=[], 9 | smile_exist=[], 10 | smile_legal=[], 11 | value_relation=["=", ">=", "<=", "~"]) 12 | ) 13 | -------------------------------------------------------------------------------- /configs/_base_/curators/sbap_defaults.py: -------------------------------------------------------------------------------- 1 | # data path 2 | path = dict( 3 | task=dict(type="sbap", subset="sbap_core_ec50_assay"), 4 | source_root="/apdcephfs/share_1364275/xluzhang/chembl_29_sqlite/chembl_29.db", 5 | target_root="data/" 6 | ) 7 | 8 | # uncertainty 9 | uncertainty = dict(delta={'<': -1, '<=': -1, '>': 1, '>=': 1}) 10 | 11 | # adaptive cls label 12 | classification_threshold = dict( 13 | lower_bound=4, 14 | upper_bound=6, 15 | fix_value=5 16 | ) 17 | 18 | # train/val/test 19 | fractions = dict( 20 | 
train_fraction_ood=0.6, 21 | val_fraction_ood=0.2, 22 | iid_train_sample_fractions=0.6, 23 | iid_val_sample_fractions=0.2 24 | ) 25 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # checkpoint saving 2 | checkpoint_config = dict(interval=5) 3 | # yapf:disable 4 | log_config = dict( 5 | interval=5, 6 | hooks=[ 7 | dict(type='TextLoggerHook'), 8 | dict(type='TensorboardLoggerHook') 9 | ] 10 | ) 11 | # yapf:enable 12 | dist_params = dict(backend='nccl') 13 | log_level = 'INFO' 14 | load_from = None 15 | resume_from = None 16 | workflow = [('train', 1)] 17 | -------------------------------------------------------------------------------- /configs/_base_/models/resnet50.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='ImageClassifier', 4 | backbone=dict( 5 | type='ResNet', 6 | depth=50, 7 | num_stages=4, 8 | out_indices=(3,), 9 | style='pytorch'), 10 | neck=dict(type='GlobalAveragePooling'), 11 | head=dict( 12 | type='LinearClsHead', 13 | num_classes=1000, 14 | in_channels=2048, 15 | loss=dict(type='CrossEntropyLoss', loss_weight=1.0), 16 | topk=(1, 5), 17 | )) 18 | -------------------------------------------------------------------------------- /configs/_base_/schedules/classification.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='AdamW', lr=1e-4, weight_decay=0) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict(policy='fixed') 6 | runner = dict(type='EpochBasedRunner', max_epochs=50) 7 | # evaluation config 8 | evaluation = dict(metric=['accuracy', 'auc']) 9 | -------------------------------------------------------------------------------- /configs/_base_/schedules/regression.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='AdamW', lr=1e-6, weight_decay=0) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict(policy='fixed') 6 | runner = dict(type='EpochBasedRunner', max_epochs=50) 7 | # evaluation config 8 | evaluation = dict(metric=['mae']) 9 | -------------------------------------------------------------------------------- /configs/algorithms/augmentation/lbap_core_ec50_assay_dann.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../../_base_/schedules/classification.py', '../../_base_/default_runtime.py'] 2 | 3 | # transform 4 | train_pipeline = [ 5 | dict( 6 | type="SmileToGraph", 7 | keys=["input"] 8 | ), 9 | dict( 10 | type='ToTensor', 11 | keys=['gt_label'] 12 | ), 13 | dict( 14 | type='Collect', 15 | keys=['input', 'gt_label', 'group'] 16 | ) 17 | ] 18 | test_pipeline = [ 19 | dict( 20 | type="SmileToGraph", 21 | keys=["input"] 22 | ), 23 | dict( 24 | type='Collect', 25 | keys=['input', 'gt_label', 'group'] 26 | )] 27 | 28 | # dataset 29 | dataset_type = "LBAPDataset" 30 | ann_file = 'data/lbap_core_ec50_assay.json' 31 | 32 | data = dict( 33 | samples_per_gpu=128, 34 | workers_per_gpu=4, 35 | train=dict( 36 | split="train", 37 | type=dataset_type, 38 | ann_file=ann_file, 39 | pipeline=train_pipeline, 40 | ), 41 | ood_val=dict( 42 | split="ood_val", 43 | type=dataset_type, 44 | ann_file=ann_file, 45 | pipeline=test_pipeline, 46 | rule="greater", 47 | save_best="accuracy" 48 
| ), 49 | iid_val=dict( 50 | split="iid_val", 51 | type=dataset_type, 52 | ann_file=ann_file, 53 | pipeline=test_pipeline, 54 | ), 55 | ood_test=dict( 56 | split="ood_test", 57 | type=dataset_type, 58 | ann_file=ann_file, 59 | pipeline=test_pipeline, 60 | ), 61 | iid_test=dict( 62 | split="iid_test", 63 | type=dataset_type, 64 | ann_file=ann_file, 65 | pipeline=test_pipeline, 66 | ), 67 | ) 68 | model = dict(type="DANN", 69 | tasker=dict( 70 | type='Classifier', 71 | backbone=dict( 72 | type='GIN', 73 | num_node_emb_list=[39], 74 | num_edge_emb_list=[10], 75 | num_layers=4, 76 | emb_dim=128, 77 | readout='sum', 78 | JK='last', 79 | dropout=0.1, 80 | ), 81 | head=dict( 82 | type='LinearClsHead', 83 | num_classes=2, 84 | in_channels=128, 85 | loss=dict( 86 | type='CrossEntropyLoss', 87 | ) 88 | ) 89 | ), 90 | 91 | dann_cfg=dict( 92 | alpha=0.2, 93 | aux_head=dict( 94 | type='LinearClsHead', 95 | num_classes=47, 96 | in_channels=128, 97 | loss=dict( 98 | type='CrossEntropyLoss', 99 | ) 100 | )) 101 | ) 102 | -------------------------------------------------------------------------------- /configs/algorithms/augmentation/lbap_core_ec50_assay_mixup.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../../_base_/schedules/classification.py', '../../_base_/default_runtime.py'] 2 | 3 | # transform 4 | train_pipeline = [ 5 | dict( 6 | type="SmileToGraph", 7 | keys=["input"] 8 | ), 9 | dict( 10 | type='Collect', 11 | keys=['input', 'gt_label', 'group'] 12 | ) 13 | ] 14 | test_pipeline = [ 15 | dict( 16 | type="SmileToGraph", 17 | keys=["input"] 18 | ), 19 | dict( 20 | type='Collect', 21 | keys=['input', 'gt_label', 'group'] 22 | )] 23 | 24 | # dataset 25 | dataset_type = "LBAPDataset" 26 | ann_file = 'data/lbap_core_ec50_assay.json' 27 | 28 | data = dict( 29 | samples_per_gpu=128, 30 | workers_per_gpu=4, 31 | train=dict( 32 | split="train", 33 | type=dataset_type, 34 | ann_file=ann_file, 35 | pipeline=train_pipeline, 36 | ), 37 | ood_val=dict( 38 | split="ood_val", 39 | type=dataset_type, 40 | ann_file=ann_file, 41 | pipeline=test_pipeline, 42 | rule="greater", 43 | save_best="accuracy" 44 | ), 45 | iid_val=dict( 46 | split="iid_val", 47 | type=dataset_type, 48 | ann_file=ann_file, 49 | pipeline=test_pipeline, 50 | ), 51 | ood_test=dict( 52 | split="ood_test", 53 | type=dataset_type, 54 | ann_file=ann_file, 55 | pipeline=test_pipeline, 56 | ), 57 | iid_test=dict( 58 | split="iid_test", 59 | type=dataset_type, 60 | ann_file=ann_file, 61 | pipeline=test_pipeline, 62 | ), 63 | ) 64 | 65 | model = dict( 66 | type="MixUp", 67 | tasker=dict( 68 | type='Classifier', 69 | backbone=dict( 70 | type='GIN', 71 | num_node_emb_list=[39], 72 | num_edge_emb_list=[10], 73 | num_layers=4, 74 | emb_dim=128, 75 | readout='sum', 76 | JK='last', 77 | dropout=0.1, 78 | ), 79 | head=dict( 80 | type='LinearClsHead', 81 | num_classes=2, 82 | in_channels=128, 83 | loss=dict( 84 | type='LabelSmoothLoss', 85 | label_smooth_val=0.1, 86 | num_classes=2, 87 | ) 88 | ) 89 | ), 90 | cfg=dict(type="BatchMixup", alpha=0.1, num_classes=2, prob=0.4) 91 | ) 92 | -------------------------------------------------------------------------------- /configs/algorithms/augmentation/sbap_core_ec50_assay_dann.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../../_base_/schedules/classification.py', '../../_base_/default_runtime.py'] 2 | 3 | # transform 4 | train_pipeline = [ 5 | dict( 6 | type="SmileToGraph", 7 | keys=["input"] 8 | ), 9 | dict( 10 
| type='ToTensor', 11 | keys=['gt_label'] 12 | ), 13 | dict( 14 | type='Collect', 15 | keys=['input', 'aux_input', 'gt_label', 'group'] 16 | ) 17 | ] 18 | test_pipeline = [ 19 | dict( 20 | type="SmileToGraph", 21 | keys=["input"] 22 | ), 23 | dict( 24 | type='Collect', 25 | keys=['input', 'aux_input', 'gt_label', 'group'] 26 | )] 27 | 28 | 29 | # dataset 30 | dataset_type = "SBAPDataset" 31 | ann_file = 'data/sbap_core_ec50_assay.json' 32 | 33 | tokenizer = dict( 34 | type="SeqToToken", 35 | model="bert-base-uncased" 36 | ) 37 | 38 | data = dict( 39 | samples_per_gpu=128, 40 | workers_per_gpu=4, 41 | train=dict( 42 | split="train", 43 | type=dataset_type, 44 | ann_file=ann_file, 45 | pipeline=train_pipeline, 46 | tokenizer=tokenizer, 47 | ), 48 | ood_val=dict( 49 | split="ood_val", 50 | type=dataset_type, 51 | ann_file=ann_file, 52 | pipeline=test_pipeline, 53 | tokenizer=tokenizer, 54 | rule="greater", 55 | save_best="accuracy" 56 | ), 57 | iid_val=dict( 58 | split="iid_val", 59 | type=dataset_type, 60 | ann_file=ann_file, 61 | pipeline=test_pipeline, 62 | tokenizer=tokenizer 63 | ), 64 | ood_test=dict( 65 | split="ood_test", 66 | type=dataset_type, 67 | ann_file=ann_file, 68 | pipeline=test_pipeline, 69 | tokenizer=tokenizer, 70 | ), 71 | iid_test=dict( 72 | split="iid_test", 73 | type=dataset_type, 74 | ann_file=ann_file, 75 | pipeline=test_pipeline, 76 | tokenizer=tokenizer 77 | ), 78 | ) 79 | 80 | model = dict(type="DANN", 81 | tasker=dict( 82 | type='Classifier', 83 | backbone=dict( 84 | type='GIN', 85 | num_node_emb_list=[39], 86 | num_edge_emb_list=[10], 87 | num_layers=4, 88 | emb_dim=128, 89 | readout='sum', 90 | JK='last', 91 | dropout=0.1, 92 | ), 93 | aux_backbone=dict( 94 | type='Bert', 95 | model="bert-base-uncased", 96 | ), 97 | neck=dict(type="Concatenate"), 98 | head=dict( 99 | type='LinearClsHead', 100 | num_classes=2, 101 | in_channels=128 + 768, 102 | loss=dict( 103 | type='CrossEntropyLoss', 104 | ) 105 | ) 106 | ), 107 | 108 | dann_cfg=dict( 109 | alpha=0.2, 110 | aux_head=dict( 111 | type='LinearClsHead', 112 | num_classes=55, 113 | in_channels=128 + 768, 114 | loss=dict( 115 | type='CrossEntropyLoss', 116 | ) 117 | )) 118 | ) 119 | -------------------------------------------------------------------------------- /configs/algorithms/augmentation/sbap_core_ec50_assay_mixup.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../../_base_/schedules/classification.py', '../../_base_/default_runtime.py'] 2 | 3 | # transform 4 | train_pipeline = [ 5 | dict( 6 | type="SmileToGraph", 7 | keys=["input"] 8 | ), 9 | dict( 10 | type='Collect', 11 | keys=['input', 'aux_input', 'gt_label', 'group'] 12 | ) 13 | ] 14 | test_pipeline = [ 15 | dict( 16 | type="SmileToGraph", 17 | keys=["input"] 18 | ), 19 | dict( 20 | type='Collect', 21 | keys=['input', 'aux_input', 'gt_label', 'group'] 22 | )] 23 | 24 | # dataset 25 | 26 | dataset_type = "SBAPDataset" 27 | ann_file = 'data/sbap_core_ec50_assay.json' 28 | 29 | tokenizer = dict( 30 | type="SeqToToken", 31 | model="bert-base-uncased" 32 | ) 33 | 34 | data = dict( 35 | samples_per_gpu=128, 36 | workers_per_gpu=4, 37 | train=dict( 38 | split="train", 39 | type=dataset_type, 40 | ann_file=ann_file, 41 | pipeline=train_pipeline, 42 | tokenizer=tokenizer 43 | 44 | ), 45 | ood_val=dict( 46 | split="ood_val", 47 | type=dataset_type, 48 | ann_file=ann_file, 49 | pipeline=test_pipeline, 50 | tokenizer=tokenizer, 51 | rule="greater", 52 | save_best="accuracy" 53 | ), 54 | iid_val=dict( 55 | 
split="iid_val", 56 | type=dataset_type, 57 | ann_file=ann_file, 58 | pipeline=test_pipeline, 59 | tokenizer=tokenizer 60 | ), 61 | ood_test=dict( 62 | split="ood_test", 63 | type=dataset_type, 64 | ann_file=ann_file, 65 | pipeline=test_pipeline, 66 | tokenizer=tokenizer, 67 | ), 68 | iid_test=dict( 69 | split="iid_test", 70 | type=dataset_type, 71 | ann_file=ann_file, 72 | pipeline=test_pipeline, 73 | tokenizer=tokenizer 74 | ), 75 | ) 76 | 77 | model = dict( 78 | type="MixUp", 79 | tasker=dict( 80 | type='Classifier', 81 | backbone=dict( 82 | type='GIN', 83 | num_node_emb_list=[39], 84 | num_edge_emb_list=[10], 85 | num_layers=4, 86 | emb_dim=128, 87 | readout='sum', 88 | JK='last', 89 | dropout=0.1, 90 | ), 91 | aux_backbone=dict( 92 | type='Bert', 93 | model="bert-base-uncased", 94 | ), 95 | neck=dict(type="Concatenate"), 96 | head=dict( 97 | type='LinearClsHead', 98 | num_classes=2, 99 | in_channels=128 + 768, 100 | loss=dict( 101 | type='LabelSmoothLoss', 102 | label_smooth_val=0.1, 103 | num_classes=2, 104 | ) 105 | ) 106 | ), 107 | cfg=dict(alpha=0.1, num_classes=2, prob=0.4) 108 | ) 109 | -------------------------------------------------------------------------------- /configs/algorithms/coral/lbap_core_ec50_assay_coral.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../../_base_/schedules/classification.py', '../../_base_/default_runtime.py'] 2 | 3 | # transform 4 | train_pipeline = [ 5 | dict( 6 | type="SmileToGraph", 7 | keys=["input"] 8 | ), 9 | dict( 10 | type='Collect', 11 | keys=['input', 'gt_label', 'group'] 12 | ) 13 | ] 14 | test_pipeline = [ 15 | dict( 16 | type="SmileToGraph", 17 | keys=["input"] 18 | ), 19 | dict( 20 | type='Collect', 21 | keys=['input', 'gt_label', 'group'] 22 | )] 23 | 24 | # dataset 25 | dataset_type = "LBAPDataset" 26 | ann_file = 'data/lbap_core_ec50_assay.json' 27 | 28 | data = dict( 29 | samples_per_gpu=128, 30 | workers_per_gpu=4, 31 | train=dict( 32 | split="train", 33 | type=dataset_type, 34 | ann_file=ann_file, 35 | pipeline=train_pipeline, 36 | sample_mode="group", 37 | sample_config=dict( 38 | uniform_over_groups=None, 39 | n_groups_per_batch=4, 40 | distinct_groups=True 41 | ) 42 | ), 43 | ood_val=dict( 44 | split="ood_val", 45 | type=dataset_type, 46 | ann_file=ann_file, 47 | pipeline=test_pipeline, 48 | rule="greater", 49 | save_best="accuracy" 50 | ), 51 | iid_val=dict( 52 | split="iid_val", 53 | type=dataset_type, 54 | ann_file=ann_file, 55 | pipeline=test_pipeline, 56 | ), 57 | ood_test=dict( 58 | split="ood_test", 59 | type=dataset_type, 60 | ann_file=ann_file, 61 | pipeline=test_pipeline, 62 | ), 63 | iid_test=dict( 64 | split="iid_test", 65 | type=dataset_type, 66 | ann_file=ann_file, 67 | pipeline=test_pipeline, 68 | ), 69 | ) 70 | 71 | # model 72 | model = dict(type="CORAL", 73 | tasker=dict( 74 | type='Classifier', 75 | backbone=dict( 76 | type='GIN', 77 | num_node_emb_list=[39], 78 | num_edge_emb_list=[10], 79 | num_layers=4, 80 | emb_dim=128, 81 | readout='sum', 82 | JK='last', 83 | dropout=0.1, 84 | ), 85 | head=dict( 86 | type='LinearClsHead', 87 | num_classes=2, 88 | in_channels=128, 89 | loss=dict( 90 | type='CrossEntropyLoss', 91 | reduction="none", 92 | ) 93 | ) 94 | ), 95 | coral_penalty_weight=0.001, 96 | ) 97 | -------------------------------------------------------------------------------- /configs/algorithms/coral/sbap_core_ec50_assay_coral.py: -------------------------------------------------------------------------------- 1 | _base_ = 
['../../_base_/schedules/classification.py', '../../_base_/default_runtime.py'] 2 | 3 | # transform 4 | train_pipeline = [ 5 | dict( 6 | type="SmileToGraph", 7 | keys=["input"] 8 | ), 9 | dict( 10 | type='Collect', 11 | keys=['input', 'aux_input', 'gt_label', 'group'] 12 | ) 13 | ] 14 | test_pipeline = [ 15 | dict( 16 | type="SmileToGraph", 17 | keys=["input"] 18 | ), 19 | dict( 20 | type='Collect', 21 | keys=['input', 'aux_input', 'gt_label', 'group'] 22 | )] 23 | 24 | 25 | # dataset 26 | dataset_type = "SBAPDataset" 27 | ann_file = 'data/sbap_core_ec50_assay.json' 28 | 29 | tokenizer = dict( 30 | type="SeqToToken", 31 | model="bert-base-uncased" 32 | ) 33 | 34 | data = dict( 35 | samples_per_gpu=128, 36 | workers_per_gpu=4, 37 | train=dict( 38 | split="train", 39 | type=dataset_type, 40 | ann_file=ann_file, 41 | pipeline=train_pipeline, 42 | tokenizer=tokenizer, 43 | sample_mode="group", 44 | sample_config=dict( 45 | uniform_over_groups=None, 46 | n_groups_per_batch=4, 47 | distinct_groups=True 48 | ) 49 | 50 | ), 51 | ood_val=dict( 52 | split="ood_val", 53 | type=dataset_type, 54 | ann_file=ann_file, 55 | pipeline=test_pipeline, 56 | tokenizer=tokenizer, 57 | rule="greater", 58 | save_best="accuracy" 59 | ), 60 | iid_val=dict( 61 | split="iid_val", 62 | type=dataset_type, 63 | ann_file=ann_file, 64 | pipeline=test_pipeline, 65 | tokenizer=tokenizer 66 | ), 67 | ood_test=dict( 68 | split="ood_test", 69 | type=dataset_type, 70 | ann_file=ann_file, 71 | pipeline=test_pipeline, 72 | tokenizer=tokenizer, 73 | ), 74 | iid_test=dict( 75 | split="iid_test", 76 | type=dataset_type, 77 | ann_file=ann_file, 78 | pipeline=test_pipeline, 79 | tokenizer=tokenizer 80 | ), 81 | ) 82 | 83 | model = dict( 84 | type="CORAL", 85 | tasker=dict( 86 | type='Classifier', 87 | backbone=dict( 88 | type='GIN', 89 | num_node_emb_list=[39], 90 | num_edge_emb_list=[10], 91 | num_layers=4, 92 | emb_dim=128, 93 | readout='sum', 94 | JK='last', 95 | dropout=0.1, 96 | ), 97 | aux_backbone=dict( 98 | type='Bert', 99 | model="bert-base-uncased", 100 | ), 101 | neck=dict(type="Concatenate"), 102 | head=dict( 103 | type='LinearClsHead', 104 | num_classes=2, 105 | in_channels=128 + 768, 106 | loss=dict( 107 | type='CrossEntropyLoss', 108 | reduction='none') 109 | ) 110 | ), 111 | coral_penalty_weight=0.001 112 | ) 113 | -------------------------------------------------------------------------------- /configs/algorithms/erm/lbap_core_ec50_assay_erm.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../../_base_/schedules/classification.py', '../../_base_/default_runtime.py'] 2 | 3 | # transform 4 | train_pipeline = [ 5 | dict( 6 | type="SmileToGraph", 7 | keys=["input"] 8 | ), 9 | dict( 10 | type='Collect', 11 | keys=['input', 'gt_label', 'group'] 12 | ) 13 | ] 14 | test_pipeline = [ 15 | dict( 16 | type="SmileToGraph", 17 | keys=["input"] 18 | ), 19 | dict( 20 | type='Collect', 21 | keys=['input', 'gt_label', 'group'] 22 | )] 23 | 24 | # dataset 25 | dataset_type = "LBAPDataset" 26 | ann_file = 'data/lbap_core_ec50_assay.json' 27 | 28 | data = dict( 29 | samples_per_gpu=128, 30 | workers_per_gpu=4, 31 | train=dict( 32 | split="train", 33 | type=dataset_type, 34 | ann_file=ann_file, 35 | pipeline=train_pipeline 36 | ), 37 | ood_val=dict( 38 | split="ood_val", 39 | type=dataset_type, 40 | ann_file=ann_file, 41 | pipeline=test_pipeline, 42 | rule="greater", 43 | save_best="accuracy" 44 | ), 45 | iid_val=dict( 46 | split="iid_val", 47 | type=dataset_type, 48 | ann_file=ann_file, 
49 | pipeline=test_pipeline, 50 | ), 51 | ood_test=dict( 52 | split="ood_test", 53 | type=dataset_type, 54 | ann_file=ann_file, 55 | pipeline=test_pipeline, 56 | ), 57 | iid_test=dict( 58 | split="iid_test", 59 | type=dataset_type, 60 | ann_file=ann_file, 61 | pipeline=test_pipeline, 62 | ), 63 | ) 64 | # model 65 | model = dict( 66 | type="ERM", 67 | tasker=dict( 68 | type='Classifier', 69 | backbone=dict( 70 | type='GIN', 71 | num_node_emb_list=[39], 72 | num_edge_emb_list=[10], 73 | num_layers=4, 74 | emb_dim=128, 75 | readout='sum', 76 | JK='last', 77 | dropout=0.1, 78 | ), 79 | head=dict( 80 | type='LinearClsHead', 81 | num_classes=2, 82 | in_channels=128, 83 | loss=dict( 84 | type='CrossEntropyLoss', 85 | ) 86 | ) 87 | ) 88 | ) 89 | -------------------------------------------------------------------------------- /configs/algorithms/erm/sbap_core_ec50_assay_erm.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../../_base_/schedules/classification.py', '../../_base_/default_runtime.py'] 2 | 3 | # transform 4 | train_pipeline = [ 5 | dict( 6 | type="SmileToGraph", 7 | keys=["input"] 8 | ), 9 | dict( 10 | type='Collect', 11 | keys=['input', 'aux_input', 'gt_label', 'group'] 12 | ) 13 | ] 14 | test_pipeline = [ 15 | dict( 16 | type="SmileToGraph", 17 | keys=["input"] 18 | ), 19 | dict( 20 | type='Collect', 21 | keys=['input', 'aux_input', 'gt_label', 'group'] 22 | )] 23 | 24 | # dataset 25 | 26 | dataset_type = "SBAPDataset" 27 | ann_file = 'data/sbap_core_ec50_assay.json' 28 | 29 | tokenizer = dict( 30 | type="SeqToToken", 31 | model="bert-base-uncased" 32 | ) 33 | 34 | data = dict( 35 | samples_per_gpu=128, 36 | workers_per_gpu=4, 37 | train=dict( 38 | split="train", 39 | type=dataset_type, 40 | ann_file=ann_file, 41 | pipeline=train_pipeline, 42 | tokenizer=tokenizer 43 | 44 | ), 45 | ood_val=dict( 46 | split="ood_val", 47 | type=dataset_type, 48 | ann_file=ann_file, 49 | pipeline=test_pipeline, 50 | tokenizer=tokenizer, 51 | rule="greater", 52 | save_best="accuracy" 53 | ), 54 | iid_val=dict( 55 | split="iid_val", 56 | type=dataset_type, 57 | ann_file=ann_file, 58 | pipeline=test_pipeline, 59 | tokenizer=tokenizer 60 | ), 61 | ood_test=dict( 62 | split="ood_test", 63 | type=dataset_type, 64 | ann_file=ann_file, 65 | pipeline=test_pipeline, 66 | tokenizer=tokenizer, 67 | ), 68 | iid_test=dict( 69 | split="iid_test", 70 | type=dataset_type, 71 | ann_file=ann_file, 72 | pipeline=test_pipeline, 73 | tokenizer=tokenizer 74 | ), 75 | ) 76 | 77 | # model 78 | 79 | model = dict( 80 | type="ERM", 81 | tasker=dict( 82 | type='Classifier', 83 | backbone=dict( 84 | type='GIN', 85 | num_node_emb_list=[39], 86 | num_edge_emb_list=[10], 87 | num_layers=4, 88 | emb_dim=128, 89 | readout='sum', 90 | JK='last', 91 | dropout=0.1, 92 | ), 93 | aux_backbone=dict( 94 | type='Bert', 95 | model="bert-base-uncased", 96 | ), 97 | neck=dict(type="Concatenate"), 98 | head=dict( 99 | type='LinearClsHead', 100 | num_classes=2, 101 | in_channels=128 + 768, 102 | loss=dict( 103 | type='CrossEntropyLoss', 104 | ) 105 | ) 106 | )) 107 | -------------------------------------------------------------------------------- /configs/algorithms/groupdro/lbap_core_ec50_assay_groupdro.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../../_base_/schedules/classification.py', '../../_base_/default_runtime.py'] 2 | 3 | # transform 4 | train_pipeline = [ 5 | dict( 6 | type="SmileToGraph", 7 | keys=["input"] 8 | ), 9 | dict( 10 | 
type='Collect', 11 | keys=['input', 'gt_label', 'group'] 12 | ) 13 | ] 14 | test_pipeline = [ 15 | dict( 16 | type="SmileToGraph", 17 | keys=["input"] 18 | ), 19 | dict( 20 | type='Collect', 21 | keys=['input', 'gt_label', 'group'] 22 | )] 23 | 24 | # dataset 25 | dataset_type = "LBAPDataset" 26 | ann_file = 'data/lbap_core_ec50_assay.json' 27 | 28 | data = dict( 29 | samples_per_gpu=128, 30 | workers_per_gpu=4, 31 | train=dict( 32 | split="train", 33 | type=dataset_type, 34 | ann_file=ann_file, 35 | pipeline=train_pipeline, 36 | sample_mode="group", 37 | sample_config=dict( 38 | uniform_over_groups=None, 39 | n_groups_per_batch=4, 40 | distinct_groups=True 41 | ) 42 | ), 43 | ood_val=dict( 44 | split="ood_val", 45 | type=dataset_type, 46 | ann_file=ann_file, 47 | pipeline=test_pipeline, 48 | rule="greater", 49 | save_best="accuracy" 50 | ), 51 | iid_val=dict( 52 | split="iid_val", 53 | type=dataset_type, 54 | ann_file=ann_file, 55 | pipeline=test_pipeline, 56 | ), 57 | ood_test=dict( 58 | split="ood_test", 59 | type=dataset_type, 60 | ann_file=ann_file, 61 | pipeline=test_pipeline, 62 | ), 63 | iid_test=dict( 64 | split="iid_test", 65 | type=dataset_type, 66 | ann_file=ann_file, 67 | pipeline=test_pipeline, 68 | ), 69 | ) 70 | # model 71 | model = dict( 72 | type="GroupDRO", 73 | tasker=dict( 74 | type='Classifier', 75 | backbone=dict( 76 | type='GIN', 77 | num_node_emb_list=[39], 78 | num_edge_emb_list=[10], 79 | num_layers=4, 80 | emb_dim=128, 81 | readout='sum', 82 | JK='last', 83 | dropout=0.1, 84 | ), 85 | head=dict( 86 | type='LinearClsHead', 87 | num_classes=2, 88 | in_channels=128, 89 | loss=dict( 90 | type='CrossEntropyLoss', 91 | reduction="none", 92 | ) 93 | ) 94 | ), 95 | group_dro_step_size=0.001, 96 | num_groups=47 97 | ) 98 | -------------------------------------------------------------------------------- /configs/algorithms/groupdro/sbap_core_ec50_assay_groupdro.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../../_base_/schedules/classification.py', '../../_base_/default_runtime.py'] 2 | 3 | # transform 4 | train_pipeline = [ 5 | dict( 6 | type="SmileToGraph", 7 | keys=["input"] 8 | ), 9 | dict( 10 | type='Collect', 11 | keys=['input', 'aux_input', 'gt_label', 'group'] 12 | ) 13 | ] 14 | test_pipeline = [ 15 | dict( 16 | type="SmileToGraph", 17 | keys=["input"] 18 | ), 19 | dict( 20 | type='Collect', 21 | keys=['input', 'aux_input', 'gt_label', 'group'] 22 | )] 23 | 24 | # dataset 25 | dataset_type = "SBAPDataset" 26 | ann_file = 'data/sbap_core_ec50_assay.json' 27 | 28 | tokenizer = dict( 29 | type="SeqToToken", 30 | model="bert-base-uncased" 31 | ) 32 | 33 | data = dict( 34 | samples_per_gpu=128, 35 | workers_per_gpu=4, 36 | train=dict( 37 | split="train", 38 | type=dataset_type, 39 | ann_file=ann_file, 40 | pipeline=train_pipeline, 41 | tokenizer=tokenizer, 42 | sample_mode="group", 43 | sample_config=dict( 44 | uniform_over_groups=None, 45 | n_groups_per_batch=4, 46 | distinct_groups=True 47 | ) 48 | 49 | ), 50 | ood_val=dict( 51 | split="ood_val", 52 | type=dataset_type, 53 | ann_file=ann_file, 54 | pipeline=test_pipeline, 55 | tokenizer=tokenizer, 56 | rule="greater", 57 | save_best="accuracy" 58 | ), 59 | iid_val=dict( 60 | split="iid_val", 61 | type=dataset_type, 62 | ann_file=ann_file, 63 | pipeline=test_pipeline, 64 | tokenizer=tokenizer 65 | ), 66 | ood_test=dict( 67 | split="ood_test", 68 | type=dataset_type, 69 | ann_file=ann_file, 70 | pipeline=test_pipeline, 71 | tokenizer=tokenizer, 72 | ), 73 | 
iid_test=dict( 74 | split="iid_test", 75 | type=dataset_type, 76 | ann_file=ann_file, 77 | pipeline=test_pipeline, 78 | tokenizer=tokenizer 79 | ), 80 | ) 81 | model = dict( 82 | type="GroupDRO", 83 | tasker=dict( 84 | type='Classifier', 85 | backbone=dict( 86 | type='GIN', 87 | num_node_emb_list=[39], 88 | num_edge_emb_list=[10], 89 | num_layers=4, 90 | emb_dim=128, 91 | readout='sum', 92 | JK='last', 93 | dropout=0.1, 94 | ), 95 | aux_backbone=dict( 96 | type='Bert', 97 | model="bert-base-uncased", 98 | ), 99 | neck=dict(type="Concatenate"), 100 | head=dict( 101 | type='LinearClsHead', 102 | num_classes=2, 103 | in_channels=128 + 768, 104 | loss=dict( 105 | type='CrossEntropyLoss', 106 | reduction='none', 107 | ) 108 | ) 109 | ), 110 | group_dro_step_size=0.001, 111 | num_groups=55, 112 | ) 113 | -------------------------------------------------------------------------------- /configs/algorithms/irm/lbap_core_ec50_assay_irm.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../../_base_/schedules/classification.py', '../../_base_/default_runtime.py'] 2 | 3 | # transform 4 | train_pipeline = [ 5 | dict( 6 | type="SmileToGraph", 7 | keys=["input"] 8 | ), 9 | dict( 10 | type='Collect', 11 | keys=['input', 'gt_label', 'group'] 12 | ) 13 | ] 14 | test_pipeline = [ 15 | dict( 16 | type="SmileToGraph", 17 | keys=["input"] 18 | ), 19 | dict( 20 | type='Collect', 21 | keys=['input', 'gt_label', 'group'] 22 | )] 23 | 24 | # dataset 25 | dataset_type = "LBAPDataset" 26 | ann_file = 'data/lbap_core_ec50_assay.json' 27 | 28 | data = dict( 29 | samples_per_gpu=128, 30 | workers_per_gpu=4, 31 | train=dict( 32 | split="train", 33 | type=dataset_type, 34 | ann_file=ann_file, 35 | pipeline=train_pipeline, 36 | sample_mode="group", 37 | sample_config=dict( 38 | uniform_over_groups=None, 39 | n_groups_per_batch=4, 40 | distinct_groups=True 41 | ) 42 | ), 43 | ood_val=dict( 44 | split="ood_val", 45 | type=dataset_type, 46 | ann_file=ann_file, 47 | pipeline=test_pipeline, 48 | rule="greater", 49 | save_best="accuracy" 50 | ), 51 | iid_val=dict( 52 | split="iid_val", 53 | type=dataset_type, 54 | ann_file=ann_file, 55 | pipeline=test_pipeline, 56 | ), 57 | ood_test=dict( 58 | split="ood_test", 59 | type=dataset_type, 60 | ann_file=ann_file, 61 | pipeline=test_pipeline, 62 | ), 63 | iid_test=dict( 64 | split="iid_test", 65 | type=dataset_type, 66 | ann_file=ann_file, 67 | pipeline=test_pipeline, 68 | ), 69 | ) 70 | 71 | model = dict(type="IRM", 72 | tasker=dict( 73 | type='Classifier', 74 | backbone=dict( 75 | type='GIN', 76 | num_node_emb_list=[39], 77 | num_edge_emb_list=[10], 78 | num_layers=4, 79 | emb_dim=128, 80 | readout='sum', 81 | JK='last', 82 | dropout=0.1, 83 | ), 84 | head=dict( 85 | type='LinearClsHead', 86 | num_classes=2, 87 | in_channels=128, 88 | loss=dict( 89 | type='CrossEntropyLoss', 90 | reduction='none', 91 | ) 92 | ), 93 | 94 | ), 95 | irm_lambda=10, 96 | irm_penalty_anneal_iters=500, 97 | ) 98 | 99 | optimizer_config = dict(type="IRMOptimizerHook", grad_clip=None) 100 | -------------------------------------------------------------------------------- /configs/algorithms/irm/sbap_core_ec50_assay_irm.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../../_base_/schedules/classification.py', '../../_base_/default_runtime.py'] 2 | 3 | # transform 4 | train_pipeline = [ 5 | dict( 6 | type="SmileToGraph", 7 | keys=["input"] 8 | ), 9 | dict( 10 | type='Collect', 11 | keys=['input', 'aux_input', 
'gt_label', 'group'] 12 | ) 13 | ] 14 | test_pipeline = [ 15 | dict( 16 | type="SmileToGraph", 17 | keys=["input"] 18 | ), 19 | dict( 20 | type='Collect', 21 | keys=['input', 'aux_input', 'gt_label', 'group'] 22 | )] 23 | 24 | # dataset 25 | dataset_type = "SBAPDataset" 26 | ann_file = 'data/sbap_core_ec50_assay.json' 27 | 28 | tokenizer = dict( 29 | type="SeqToToken", 30 | model="bert-base-uncased" 31 | ) 32 | 33 | data = dict( 34 | samples_per_gpu=128, 35 | workers_per_gpu=4, 36 | train=dict( 37 | split="train", 38 | type=dataset_type, 39 | ann_file=ann_file, 40 | pipeline=train_pipeline, 41 | tokenizer=tokenizer, 42 | sample_mode="group", 43 | sample_config=dict( 44 | uniform_over_groups=None, 45 | n_groups_per_batch=4, 46 | distinct_groups=True 47 | ) 48 | 49 | ), 50 | ood_val=dict( 51 | split="ood_val", 52 | type=dataset_type, 53 | ann_file=ann_file, 54 | pipeline=test_pipeline, 55 | tokenizer=tokenizer, 56 | rule="greater", 57 | save_best="accuracy" 58 | ), 59 | iid_val=dict( 60 | split="iid_val", 61 | type=dataset_type, 62 | ann_file=ann_file, 63 | pipeline=test_pipeline, 64 | tokenizer=tokenizer 65 | ), 66 | ood_test=dict( 67 | split="ood_test", 68 | type=dataset_type, 69 | ann_file=ann_file, 70 | pipeline=test_pipeline, 71 | tokenizer=tokenizer, 72 | ), 73 | iid_test=dict( 74 | split="iid_test", 75 | type=dataset_type, 76 | ann_file=ann_file, 77 | pipeline=test_pipeline, 78 | tokenizer=tokenizer 79 | ), 80 | ) 81 | 82 | model = dict( 83 | type="IRM", 84 | tasker=dict( 85 | type='Classifier', 86 | backbone=dict( 87 | type='GIN', 88 | num_node_emb_list=[39], 89 | num_edge_emb_list=[10], 90 | num_layers=4, 91 | emb_dim=128, 92 | readout='sum', 93 | JK='last', 94 | dropout=0.1, 95 | ), 96 | aux_backbone=dict( 97 | type='Bert', 98 | model="bert-base-uncased", 99 | ), 100 | neck=dict(type="Concatenate"), 101 | head=dict( 102 | type='LinearClsHead', 103 | num_classes=2, 104 | in_channels=128 + 768, 105 | loss=dict( 106 | type='CrossEntropyLoss', 107 | reduction='none', 108 | ) 109 | ) 110 | ), 111 | irm_lambda=10, 112 | irm_penalty_anneal_iters=500, 113 | ) 114 | 115 | optimizer_config = dict(type="IRMOptimizerHook", grad_clip=None) 116 | -------------------------------------------------------------------------------- /configs/curators/lbap_core_ec50_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/assay.py'] 4 | -------------------------------------------------------------------------------- /configs/curators/lbap_core_ec50_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="lbap_core_ec50_scaffold")) -------------------------------------------------------------------------------- /configs/curators/lbap_core_ec50_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="lbap_core_ec50_size")) -------------------------------------------------------------------------------- /configs/curators/lbap_core_ic50_assay.py: -------------------------------------------------------------------------------- 1 | _base_ 
= ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="lbap_core_ic50_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) 8 | -------------------------------------------------------------------------------- /configs/curators/lbap_core_ic50_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="lbap_core_ic50_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) -------------------------------------------------------------------------------- /configs/curators/lbap_core_ic50_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="lbap_core_ic50_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) -------------------------------------------------------------------------------- /configs/curators/lbap_core_ki_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="lbap_core_ki_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) -------------------------------------------------------------------------------- /configs/curators/lbap_core_ki_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="lbap_core_ki_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) -------------------------------------------------------------------------------- /configs/curators/lbap_core_ki_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="lbap_core_ki_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) -------------------------------------------------------------------------------- /configs/curators/lbap_core_potency_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="lbap_core_potency_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'])) 8 | -------------------------------------------------------------------------------- /configs/curators/lbap_core_potency_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="lbap_core_potency_scaffold")) 6 | 7 | noise_filter = 
dict(assay=dict(measurement_type=['Potency'])) -------------------------------------------------------------------------------- /configs/curators/lbap_core_potency_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="lbap_core_potency_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'])) -------------------------------------------------------------------------------- /configs/curators/lbap_general_ec50_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="lbap_general_ec50_assay")) 6 | -------------------------------------------------------------------------------- /configs/curators/lbap_general_ec50_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="lbap_general_ec50_scaffold")) -------------------------------------------------------------------------------- /configs/curators/lbap_general_ec50_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="lbap_general_ec50_size")) -------------------------------------------------------------------------------- /configs/curators/lbap_general_ic50_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="lbap_general_ic50_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) -------------------------------------------------------------------------------- /configs/curators/lbap_general_ic50_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="lbap_general_ic50_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) -------------------------------------------------------------------------------- /configs/curators/lbap_general_ic50_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="lbap_general_ic50_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) -------------------------------------------------------------------------------- /configs/curators/lbap_general_ki_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 |
'../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="lbap_general_ki_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) -------------------------------------------------------------------------------- /configs/curators/lbap_general_ki_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="lbap_general_ki_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) -------------------------------------------------------------------------------- /configs/curators/lbap_general_ki_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="lbap_general_ki_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) -------------------------------------------------------------------------------- /configs/curators/lbap_general_potency_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="lbap_general_potency_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'])) -------------------------------------------------------------------------------- /configs/curators/lbap_general_potency_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="lbap_general_potency_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'])) -------------------------------------------------------------------------------- /configs/curators/lbap_general_potency_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="lbap_general_potency_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'])) -------------------------------------------------------------------------------- /configs/curators/lbap_refined_ec50_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="lbap_refined_ec50_assay")) 6 | -------------------------------------------------------------------------------- /configs/curators/lbap_refined_ec50_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="lbap_refined_ec50_scaffold")) 6 | -------------------------------------------------------------------------------- 
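Every leaf file in configs/curators follows the same composition pattern: _base_ pulls in a task defaults file (lbap_defaults.py or sbap_defaults.py), one noise level (noises/core.py, general.py or refined.py) and one domain split (domains/*.py), and the leaf then only overrides path.task.subset plus, for non-EC50 endpoints, noise_filter.assay.measurement_type. Below is a minimal sketch of how such a leaf config resolves; it assumes the mmcv-style Config loader that this mm-based codebase is built around (the recursive dict merge shown is standard mmcv behaviour, and the attribute names come from the base files above):

from mmcv import Config

# mmcv merges the three _base_ files first (lbap_defaults.py,
# noises/core.py, domains/assay.py), then applies the leaf-level
# overrides on top of the merged dict.
cfg = Config.fromfile('configs/curators/lbap_core_ic50_assay.py')

print(cfg.path.task.subset)                     # lbap_core_ic50_assay
# noises/core.py sets measurement_type=["EC50"]; the leaf override swaps
# in ['IC50'] but keeps the other assay filters (value units, molecule
# count range, confidence score) because the merge is recursive.
print(cfg.noise_filter.assay.measurement_type)  # ['IC50']
print(cfg.domain.domain_name)                   # assay

Because the merge is recursive, an override replaces only the keys it names; everything it does not mention keeps its base value, which is why most leaf files here are only a few lines long.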
/configs/curators/lbap_refined_ec50_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="lbap_refined_ec50_size")) 6 | -------------------------------------------------------------------------------- /configs/curators/lbap_refined_ic50_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="lbap_refined_ic50_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) 8 | 9 | -------------------------------------------------------------------------------- /configs/curators/lbap_refined_ic50_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="lbap_refined_ic50_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) 8 | 9 | -------------------------------------------------------------------------------- /configs/curators/lbap_refined_ic50_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="lbap_refined_ic50_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) 8 | 9 | -------------------------------------------------------------------------------- /configs/curators/lbap_refined_ki_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="lbap_refined_ki_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) 8 | 9 | -------------------------------------------------------------------------------- /configs/curators/lbap_refined_ki_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="lbap_refined_ki_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) 8 | 9 | -------------------------------------------------------------------------------- /configs/curators/lbap_refined_ki_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="lbap_refined_ki_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) 8 | 9 | -------------------------------------------------------------------------------- /configs/curators/lbap_refined_potency_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 |
'../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="lbap_refined_potency_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'])) 8 | 9 | -------------------------------------------------------------------------------- /configs/curators/lbap_refined_potency_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="lbap_refined_potency_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'])) 8 | 9 | -------------------------------------------------------------------------------- /configs/curators/lbap_refined_potency_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/lbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="lbap_refined_potency_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'])) 8 | 9 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ec50_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ec50_assay")) 6 | 7 | noise_filter = dict(assay=dict(assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ec50_protein.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/protein.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ec50_protein")) 6 | 7 | noise_filter = dict(assay=dict(assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ec50_protein_family.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/protein_family.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ec50_protein_family")) 6 | 7 | noise_filter = dict(assay=dict(assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ec50_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ec50_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ec50_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ec50_size"))
6 | 7 | noise_filter = dict(assay=dict(assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ic50_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ic50_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ic50_protein.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/protein.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ic50_protein")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ic50_protein_family.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/protein_family.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ic50_protein_family")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ic50_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ic50_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ic50_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ic50_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ki_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ki_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ki_protein.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/protein.py'] 4 | 5 | path = 
dict(task=dict(subset="sbap_core_ki_protein")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ki_protein_family.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/protein_family.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ki_protein_family")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ki_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ki_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_ki_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_ki_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_potency_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_potency_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_potency_protein.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/protein.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_potency_protein")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_potency_protein_family.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/protein_family.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_potency_protein_family")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_potency_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = 
['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_potency_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_core_potency_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/core.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="sbap_core_potency_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'], assay_target_type=["SINGLE PROTEIN"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_ec50_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ec50_assay")) -------------------------------------------------------------------------------- /configs/curators/sbap_general_ec50_protein.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/protein.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ec50_protein")) -------------------------------------------------------------------------------- /configs/curators/sbap_general_ec50_protein_family.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/protein_family.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ec50_protein_family")) 6 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_ec50_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ec50_scaffold")) 6 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_ec50_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ec50_size")) 6 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_ic50_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ic50_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_ic50_protein.py: 
-------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/protein.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ic50_protein")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_ic50_protein_family.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/protein_family.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ic50_protein_family")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_ic50_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ic50_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_ic50_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ic50_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['IC50'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_ki_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ki_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_ki_protein.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/protein.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ki_protein")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_ki_protein_family.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/protein_family.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ki_protein_family")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_ki_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = 
['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ki_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_ki_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_ki_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Ki'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_potency_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_potency_assay")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'])) 8 | 9 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_potency_protein.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/protein.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_potency_protein")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_potency_protein_family.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/protein_family.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_potency_protein_family")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_potency_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_potency_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_general_potency_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/general.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="sbap_general_potency_size")) 6 | 7 | noise_filter = dict(assay=dict(measurement_type=['Potency'])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_ec50_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 
| '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ec50_assay")) 6 | 7 | noise_filter = dict(assay=dict(assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_ec50_protein.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/protein.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ec50_protein")) 6 | 7 | noise_filter = dict(assay=dict(assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_ec50_protein_family.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/protein_family.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ec50_protein_family")) 6 | 7 | noise_filter = dict(assay=dict(assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_ec50_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ec50_scaffold")) 6 | 7 | noise_filter = dict(assay=dict(assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_ec50_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ec50_size")) 6 | 7 | noise_filter = dict(assay=dict(assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 8 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_ic50_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ic50_assay")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['IC50'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_ic50_protein.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/protein.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ic50_protein")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['IC50'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | 
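These sbap_refined configs are consumed by the same curation entry point, `curate_data` in drugood/apis/curate.py further below. A hedged usage sketch, assuming the ChEMBL source database referenced by the sbap defaults is available locally:

    from mmcv import Config
    from drugood.apis import curate_data

    # Runs the full flow from curate.py: load -> filter noise ->
    # process uncertainty -> generate labels -> split -> save.
    cfg = Config.fromfile('configs/curators/sbap_refined_ic50_protein.py')
    curate_data(cfg)
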
-------------------------------------------------------------------------------- /configs/curators/sbap_refined_ic50_protein_family.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/protein_family.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ic50_protein_family")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['IC50'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_ic50_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ic50_scaffold")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['IC50'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_ic50_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ic50_size")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['IC50'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_ki_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ki_assay")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['Ki'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_ki_protein.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/protein.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ki_protein")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['Ki'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_ki_protein_family.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/protein_family.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ki_protein_family")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['Ki'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | 
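For the `*_scaffold` subsets, the domain key is the Bemis-Murcko scaffold, computed exactly as in `DomainInfo.scaffold` (drugood/curators/get_domain_info.py, shown later): molecules sharing a scaffold land in the same domain. An illustrative RDKit sketch; the aspirin SMILES is just an example input:

    from rdkit import Chem
    from rdkit.Chem.Scaffolds import MurckoScaffold

    smiles = 'CC(=O)Oc1ccccc1C(=O)O'  # aspirin, an illustrative input
    scaffold = MurckoScaffold.MurckoScaffoldSmiles(
        mol=Chem.MolFromSmiles(smiles), includeChirality=False)
    print(scaffold)  # 'c1ccccc1', the benzene core
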
-------------------------------------------------------------------------------- /configs/curators/sbap_refined_ki_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ki_scaffold")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['Ki'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_ki_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_ki_size")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['Ki'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_potency_assay.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/assay.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_potency_assay")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['Potency'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_potency_protein.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/protein.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_potency_protein")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['Potency'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_potency_protein_family.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/protein_family.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_potency_protein_family")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['Potency'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | -------------------------------------------------------------------------------- /configs/curators/sbap_refined_potency_scaffold.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/scaffold.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_potency_scaffold")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['Potency'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | 
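Each key under `noise_filter.assay` names a check method on `AssayFilter` (drugood/curators/chembl/filter.py, shown later), and an assay is kept only if every configured check returns True. A simplified stand-alone sketch of the measurement-type check; the `records` list, standing in for one assay's rows, is hypothetical:

    assay_cfg = dict(measurement_type=['Potency'],
                     assay_target_type=['SINGLE PROTEIN', 'PROTEIN COMPLEX',
                                        'PROTEIN FAMILY'])

    def measurement_type_ok(records):
        # Mirrors AssayFilter.measurement_type: every record's STANDARD_TYPE
        # must be one of the configured measurement types.
        return all(r['STANDARD_TYPE'] in assay_cfg['measurement_type']
                   for r in records)

    print(measurement_type_ok([{'STANDARD_TYPE': 'Potency'}]))  # True
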
-------------------------------------------------------------------------------- /configs/curators/sbap_refined_potency_size.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/curators/sbap_defaults.py', 2 | '../_base_/curators/noises/refined.py', 3 | '../_base_/curators/domains/size.py'] 4 | 5 | path = dict(task=dict(subset="sbap_refined_potency_size")) 6 | 7 | noise_filter = dict( 8 | assay=dict( 9 | measurement_type=['Potency'], 10 | assay_target_type=["SINGLE PROTEIN", "PROTEIN COMPLEX", "PROTEIN FAMILY"])) 11 | -------------------------------------------------------------------------------- /demo/sources/curator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/DrugOOD/61a70e4ad1fb227e4a264ed5ba87d4c78fdb4ae7/demo/sources/curator.png -------------------------------------------------------------------------------- /demo/sources/drugood_dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/DrugOOD/61a70e4ad1fb227e4a264ed5ba87d4c78fdb4ae7/demo/sources/drugood_dataset.png -------------------------------------------------------------------------------- /demo/sources/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/DrugOOD/61a70e4ad1fb227e4a264ed5ba87d4c78fdb4ae7/demo/sources/overview.png -------------------------------------------------------------------------------- /drugood/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 3 | The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). 4 | All Tencent Modifications are Copyright (C) THL A29 Limited. 5 | Copyright (c) OpenMMLab. All rights reserved. 6 | """ 7 | import warnings 8 | 9 | import mmcv 10 | from packaging.version import parse 11 | 12 | from .version import __version__ 13 | 14 | 15 | def digit_version(version_str: str, length: int = 4): 16 | """Convert a version string into a tuple of integers. 17 | 18 | This method is usually used for comparing two versions. For pre-release 19 | versions: alpha < beta < rc. 20 | 21 | Args: 22 | version_str (str): The version string. 23 | length (int): The maximum number of version levels. Default: 4. 24 | 25 | Returns: 26 | tuple[int]: The version info in digits (integers). 
27 | """ 28 | version = parse(version_str) 29 | assert version.release, f'failed to parse version {version_str}' 30 | release = list(version.release) 31 | release = release[:length] 32 | if len(release) < length: 33 | release = release + [0] * (length - len(release)) 34 | if version.is_prerelease: 35 | mapping = {'a': -3, 'b': -2, 'rc': -1} 36 | val = -4 37 | # version.pre can be None 38 | if version.pre: 39 | if version.pre[0] not in mapping: 40 | warnings.warn(f'unknown prerelease version {version.pre[0]}, ' 41 | 'version checking may go wrong') 42 | else: 43 | val = mapping[version.pre[0]] 44 | release.extend([val, version.pre[-1]]) 45 | else: 46 | release.extend([val, 0]) 47 | 48 | elif version.is_postrelease: 49 | release.extend([1, version.post]) 50 | else: 51 | release.extend([0, 0]) 52 | return tuple(release) 53 | 54 | 55 | MMCV_MINIMUM_VERSION = '1.3.8' 56 | MMCV_MAXIMUM_VERSION = '1.5.0' 57 | mmcv_version = digit_version(mmcv.__version__) 58 | 59 | assert (mmcv_version >= digit_version(MMCV_MINIMUM_VERSION) 60 | and mmcv_version <= digit_version(MMCV_MAXIMUM_VERSION)), \ 61 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 62 | f'Please install mmcv>={MMCV_MINIMUM_VERSION}, <={MMCV_MAXIMUM_VERSION}.' 63 | 64 | __all__ = ['__version__', 'digit_version'] 65 | -------------------------------------------------------------------------------- /drugood/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | # The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). 3 | # All Tencent Modifications are Copyright (C) THL A29 Limited. 4 | from .inference import inference_model, init_model, show_result_pyplot 5 | from .test import multi_gpu_test, single_gpu_test 6 | from .train import set_random_seed, train_model 7 | from .curate import curate_data 8 | 9 | __all__ = [ 10 | 'set_random_seed', 'train_model', 'init_model', 'inference_model', 11 | 'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot', 'curate_data' 12 | ] 13 | -------------------------------------------------------------------------------- /drugood/apis/curate.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | from mmcv import print_log 3 | from drugood.curators import GenericCurator 4 | 5 | 6 | def curate_data(cfg): 7 | print_log(f'Curator Config:\n{cfg.pretty_text}''\n' + '-' * 60) 8 | curator = GenericCurator(cfg) 9 | # Processing Flow 10 | data = curator.data_loading() 11 | data = curator.noise_filtering(data) 12 | data = curator.uncertainty_processing(data) 13 | data = curator.classification_label_generating(data) 14 | data = curator.data_splitting(data) 15 | curator.data_saving(data) 16 | curator.statistics_reporting() 17 | -------------------------------------------------------------------------------- /drugood/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 3 | # The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). 4 | # All Tencent Modifications are Copyright (C) THL A29 Limited. 
5 | from .evaluation import * # noqa: F401, F403 6 | from .fp16 import * # noqa: F401, F403 7 | from .hooks import * 8 | from .runner import * 9 | from .utils import * # noqa: F401, F403 10 | -------------------------------------------------------------------------------- /drugood/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | # Hooks for Evaluations 3 | from .eval_hooks import DistEvalHook, EvalHook 4 | # Metrics for Evaluations 5 | from .eval_metrics import calculate_confusion_matrix, f1_score, precision, \ 6 | precision_recall_f1, recall, support, auc 7 | from .mean_average_precision import average_precision, mean_average_precision 8 | from .multilabel_eval_metrics import average_performance 9 | 10 | __all__ = [ 11 | 'DistEvalHook', 'EvalHook', 'precision', 'recall', 'f1_score', 'support', 12 | 'average_precision', 'mean_average_precision', 'average_performance', 13 | 'calculate_confusion_matrix', 'precision_recall_f1', 'auc' 14 | ] 15 | -------------------------------------------------------------------------------- /drugood/core/evaluation/mean_average_precision.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def average_precision(pred, target): 7 | r"""Calculate the average precision for a single class. 8 | 9 | AP summarizes a precision-recall curve as the weighted mean of maximum 10 | precisions obtained for any r'>r, where r is the recall: 11 | 12 | .. math:: 13 | \text{AP} = \sum_n (R_n - R_{n-1}) P_n 14 | 15 | Note that no approximation is involved since the curve is piecewise 16 | constant. 17 | 18 | Args: 19 | pred (np.ndarray): The model prediction with shape (N, ). 20 | target (np.ndarray): The target of each prediction with shape (N, ). 21 | 22 | Returns: 23 | float: a single float as average precision value. 24 | """ 25 | eps = np.finfo(np.float32).eps 26 | 27 | # sort examples 28 | sort_inds = np.argsort(-pred) 29 | sort_target = target[sort_inds] 30 | 31 | # count true positive examples 32 | pos_inds = sort_target == 1 33 | tp = np.cumsum(pos_inds) 34 | total_pos = tp[-1] 35 | 36 | # count not difficult examples 37 | pn_inds = sort_target != -1 38 | pn = np.cumsum(pn_inds) 39 | 40 | tp[np.logical_not(pos_inds)] = 0 41 | precision = tp / np.maximum(pn, eps) 42 | ap = np.sum(precision) / np.maximum(total_pos, eps) 43 | return ap 44 | 45 | 46 | def mean_average_precision(pred, target, nan_reduce=True): 47 | """Calculate the mean average precision with respect of classes. 48 | 49 | Args: 50 | pred (torch.Tensor | np.ndarray): The model prediction with shape 51 | (N, C), where C is the number of classes. 52 | target (torch.Tensor | np.ndarray): The target of each prediction with 53 | shape (N, C), where C is the number of classes. 1 stands for 54 | positive examples, 0 stands for negative examples and -1 stands for 55 | difficult examples. 56 | 57 | Returns: 58 | float: A single float as mAP value. 
59 | """ 60 | if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor): 61 | pred = pred.detach().cpu().numpy() 62 | target = target.detach().cpu().numpy() 63 | elif not (isinstance(pred, np.ndarray) and isinstance(target, np.ndarray)): 64 | raise TypeError('pred and target should both be torch.Tensor or ' 65 | 'np.ndarray') 66 | 67 | assert pred.shape == \ 68 | target.shape, 'pred and target should be in the same shape.' 69 | 70 | if nan_reduce: 71 | target[np.isnan(target)] = -1 72 | num_classes = pred.shape[1] 73 | ap = np.zeros(num_classes) 74 | for k in range(num_classes): 75 | ap[k] = average_precision(pred[:, k], target[:, k]) 76 | mean_ap = ap.mean() * 100.0 77 | return mean_ap 78 | -------------------------------------------------------------------------------- /drugood/core/evaluation/multilabel_eval_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def average_performance(pred, target, thr=None, k=None): 9 | """Calculate class_precision, class_recall, class_f1, over_precision, 10 | over_recall and over_f1, where the class_ prefix denotes per-class (macro) 11 | averaging and the over_ prefix denotes overall (micro) averaging of 12 | precision, recall and F1-score. 13 | 14 | Args: 15 | pred (torch.Tensor | np.ndarray): The model prediction with shape 16 | (N, C), where C is the number of classes. 17 | target (torch.Tensor | np.ndarray): The target of each prediction with 18 | shape (N, C), where C is the number of classes. 1 stands for 19 | positive examples, 0 stands for negative examples and -1 stands for 20 | difficult examples. 21 | thr (float): The confidence threshold. Defaults to None. 22 | k (int): Top-k performance. Note that if thr and k are both given, k 23 | will be ignored. Defaults to None. 24 | 25 | Returns: 26 | tuple: (class_precision, class_recall, class_f1, over_precision, over_recall, over_f1) 27 | """ 28 | if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor): 29 | pred = pred.detach().cpu().numpy() 30 | target = target.detach().cpu().numpy() 31 | elif not (isinstance(pred, np.ndarray) and isinstance(target, np.ndarray)): 32 | raise TypeError('pred and target should both be torch.Tensor or ' 33 | 'np.ndarray') 34 | if thr is None and k is None: 35 | thr = 0.5 36 | warnings.warn('Neither thr nor k is given, set thr as 0.5 by ' 37 | 'default.') 38 | elif thr is not None and k is not None: 39 | warnings.warn('Both thr and k are given, use threshold in favor of ' 40 | 'top-k.') 41 | 42 | assert pred.shape == \ 43 | target.shape, 'pred and target should be in the same shape.'
44 | 45 | eps = np.finfo(np.float32).eps 46 | target[target == -1] = 0 47 | if thr is not None: 48 | # a label is predicted positive if the confidence is no lower than thr 49 | pos_inds = pred >= thr 50 | 51 | else: 52 | # top-k labels will be predicted positive for any example 53 | sort_inds = np.argsort(-pred, axis=1) 54 | sort_inds_ = sort_inds[:, :k] 55 | inds = np.indices(sort_inds_.shape) 56 | pos_inds = np.zeros_like(pred) 57 | pos_inds[inds[0], sort_inds_] = 1 58 | 59 | tp = (pos_inds * target) == 1 60 | fp = (pos_inds * (1 - target)) == 1 61 | fn = ((1 - pos_inds) * target) == 1 62 | 63 | precision_class = tp.sum(axis=0) / np.maximum( 64 | tp.sum(axis=0) + fp.sum(axis=0), eps) 65 | recall_class = tp.sum(axis=0) / np.maximum( 66 | tp.sum(axis=0) + fn.sum(axis=0), eps) 67 | class_precision = precision_class.mean() * 100.0 68 | class_recall = recall_class.mean() * 100.0 69 | class_f1 = 2 * class_precision * class_recall / np.maximum(class_precision + class_recall, eps) 70 | over_precision = tp.sum() / np.maximum(tp.sum() + fp.sum(), eps) * 100.0 71 | over_recall = tp.sum() / np.maximum(tp.sum() + fn.sum(), eps) * 100.0 72 | over_f1 = 2 * over_precision * over_recall / np.maximum(over_precision + over_recall, eps) 73 | return class_precision, class_recall, class_f1, over_precision, over_recall, over_f1 74 | -------------------------------------------------------------------------------- /drugood/core/fp16/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .decorators import auto_fp16, force_fp32 3 | from .hooks import Fp16OptimizerHook, wrap_fp16_model 4 | 5 | __all__ = ['auto_fp16', 'force_fp32', 'Fp16OptimizerHook', 'wrap_fp16_model'] 6 | -------------------------------------------------------------------------------- /drugood/core/fp16/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from collections import abc 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def cast_tensor_type(inputs, src_type, dst_type): 9 | if isinstance(inputs, torch.Tensor): 10 | return inputs.to(dst_type) 11 | elif isinstance(inputs, str): 12 | return inputs 13 | elif isinstance(inputs, np.ndarray): 14 | return inputs 15 | elif isinstance(inputs, abc.Mapping): 16 | return type(inputs)({ 17 | k: cast_tensor_type(v, src_type, dst_type) 18 | for k, v in inputs.items() 19 | }) 20 | elif isinstance(inputs, abc.Iterable): 21 | return type(inputs)( 22 | cast_tensor_type(item, src_type, dst_type) for item in inputs) 23 | else: 24 | return inputs 25 | -------------------------------------------------------------------------------- /drugood/core/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | from .irm_optimizer_hook import IRMOptimizerHook 3 | -------------------------------------------------------------------------------- /drugood/core/hooks/irm_optimizer_hook.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 
2 | import collections 3 | 4 | from mmcv import print_log 5 | from mmcv.runner import OptimizerHook 6 | from mmcv.runner.hooks import HOOKS 7 | 8 | from drugood.utils import get_root_logger 9 | 10 | 11 | @HOOKS.register_module() 12 | class IRMOptimizerHook(OptimizerHook): 13 | def after_train_iter(self, runner): 14 | if runner.model.module.update_count == runner.model.module.irm_penalty_anneal_iters: 15 | print_log("Hit IRM penalty anneal iters, Re-set optimizer", logger=get_root_logger()) 16 | self.reset_optimizer(runner) 17 | 18 | runner.optimizer.zero_grad() 19 | runner.outputs['loss'].backward() 20 | if self.grad_clip is not None: 21 | self.clip_grads(runner.model.parameters()) 22 | runner.optimizer.step() 23 | runner.model.module.update_count += 1 24 | 25 | def reset_optimizer(self, runner): 26 | runner.optimizer.state = collections.defaultdict(dict) # Reset state 27 | -------------------------------------------------------------------------------- /drugood/core/runner/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/DrugOOD/61a70e4ad1fb227e4a264ed5ba87d4c78fdb4ae7/drugood/core/runner/__init__.py -------------------------------------------------------------------------------- /drugood/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .dist_utils import DistOptimizerHook, allreduce_grads 3 | from .misc import multi_apply, move_to_device, make_dirs 4 | 5 | __all__ = ['allreduce_grads', 'DistOptimizerHook', 'multi_apply', 'move_to_device', 'make_dirs'] 6 | -------------------------------------------------------------------------------- /drugood/core/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from collections import OrderedDict 3 | 4 | import torch.distributed as dist 5 | from mmcv.runner import OptimizerHook 6 | from torch._utils import (_flatten_dense_tensors, _take_tensors, 7 | _unflatten_dense_tensors) 8 | 9 | 10 | def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): 11 | if bucket_size_mb > 0: 12 | bucket_size_bytes = bucket_size_mb * 1024 * 1024 13 | buckets = _take_tensors(tensors, bucket_size_bytes) 14 | else: 15 | buckets = OrderedDict() 16 | for tensor in tensors: 17 | tp = tensor.type() 18 | if tp not in buckets: 19 | buckets[tp] = [] 20 | buckets[tp].append(tensor) 21 | buckets = buckets.values() 22 | 23 | for bucket in buckets: 24 | flat_tensors = _flatten_dense_tensors(bucket) 25 | dist.all_reduce(flat_tensors) 26 | flat_tensors.div_(world_size) 27 | for tensor, synced in zip( 28 | bucket, _unflatten_dense_tensors(flat_tensors, bucket)): 29 | tensor.copy_(synced) 30 | 31 | 32 | def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): 33 | grads = [ 34 | param.grad.data for param in params 35 | if param.requires_grad and param.grad is not None 36 | ] 37 | world_size = dist.get_world_size() 38 | if coalesce: 39 | _allreduce_coalesced(grads, world_size, bucket_size_mb) 40 | else: 41 | for tensor in grads: 42 | dist.all_reduce(tensor.div_(world_size)) 43 | 44 | 45 | class DistOptimizerHook(OptimizerHook): 46 | 47 | def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1): 48 | self.grad_clip = grad_clip 49 | self.coalesce = coalesce 50 | self.bucket_size_mb = bucket_size_mb 51 | 52 | def after_train_iter(self, runner): 53 | runner.optimizer.zero_grad() 54 | runner.outputs['loss'].backward() 55 | if self.grad_clip is not None: 56 | self.clip_grads(runner.model.parameters()) 57 | runner.optimizer.step() 58 | -------------------------------------------------------------------------------- /drugood/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | import os 3 | from functools import partial 4 | 5 | import torch 6 | 7 | 8 | def multi_apply(func, *args, **kwargs): 9 | pfunc = partial(func, **kwargs) if kwargs else func 10 | map_results = map(pfunc, *args) 11 | return tuple(map(list, zip(*map_results))) 12 | 13 | 14 | def move_to_device(obj, device=None): 15 | if (device is None): 16 | device = torch.device('cuda') 17 | if isinstance(obj, dict): 18 | return {k: move_to_device(v, device) for k, v in obj.items()} 19 | elif isinstance(obj, list): 20 | return [move_to_device(v, device) for v in obj] 21 | elif isinstance(obj, float) or isinstance(obj, int): 22 | return obj 23 | else: 24 | return obj.to(device) 25 | 26 | 27 | def make_dirs(dir): 28 | if (not os.path.exists(dir)): 29 | try: 30 | os.makedirs(dir) 31 | except FileNotFoundError as e: 32 | print(str(e)) 33 | -------------------------------------------------------------------------------- /drugood/core/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | from .molecule import imshow_infos 3 | 4 | __all__ = ['imshow_infos'] 5 | -------------------------------------------------------------------------------- /drugood/core/visualization/molecule.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 
2 | import numpy as np 3 | 4 | # A small value 5 | EPS = 1e-2 6 | 7 | 8 | def imshow_infos(input): 9 | """Show a molecule with extra information. 10 | 11 | Args: 12 | input (str): The SMILES string to be displayed. 13 | 14 | Returns: 15 | np.ndarray: The image with extra information. 16 | """ 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /drugood/curators/__init__.py: -------------------------------------------------------------------------------- 1 | from .curator import GenericCurator -------------------------------------------------------------------------------- /drugood/curators/chembl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/DrugOOD/61a70e4ad1fb227e4a264ed5ba87d4c78fdb4ae7/drugood/curators/chembl/__init__.py -------------------------------------------------------------------------------- /drugood/curators/chembl/filter.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | from drugood.utils.smile_to_dgl import smile2graph 3 | 4 | 5 | class Filter(object): 6 | """Compose a filter pipeline with a sequence of sub filters. 7 | 8 | Args: 9 | cfg (dict): 10 | Config dicts of the sub filters to compose. 11 | """ 12 | 13 | def __init__(self, cfg): 14 | """ 15 | The base class of filters. 16 | Args: 17 | cfg: The config object. 18 | """ 19 | self.filter_func = [] 20 | for (filter, config) in cfg.items(): 21 | one_func = self.__getattribute__(filter) 22 | self.filter_func.append(one_func) 23 | 24 | def __call__(self, one_data): 25 | for func in self.filter_func: 26 | res = func(one_data) 27 | if not res: 28 | return False 29 | return True 30 | 31 | 32 | class AssayFilter(Filter): 33 | def __init__(self, cfg, sql_func): 34 | """ 35 | The Filter for assay. 36 | Args: 37 | cfg: The config object. 38 | sql_func: The object of SQLFunction.
39 | """ 40 | self.cfg = cfg 41 | self.sql_func = sql_func 42 | super(AssayFilter, self).__init__(cfg=cfg) 43 | 44 | def measurement_type(self, data): 45 | for case in data: 46 | index_type = case['STANDARD_TYPE'] 47 | if index_type not in self.cfg["measurement_type"]: 48 | return False 49 | return True 50 | 51 | def assay_value_units(self, data): 52 | for case in data: 53 | units = case['STANDARD_UNITS'] 54 | if units not in self.cfg["assay_value_units"]: 55 | return False 56 | return True 57 | 58 | def molecules_number(self, data): 59 | number_m = len(data) 60 | if self.cfg["molecules_number"][0] <= number_m <= self.cfg["molecules_number"][1]: 61 | return True 62 | else: 63 | return False 64 | 65 | def assay_target_type(self, data): 66 | assay_id = data[0]['ASSAY_ID'] 67 | target_type = self.sql_func.get_assay_target_type(assay_id) 68 | if target_type in self.cfg["assay_target_type"]: 69 | return True 70 | else: 71 | return False 72 | 73 | def confidence_score(self, data): 74 | assay_id = data[0]['ASSAY_ID'] 75 | confidence_score = self.sql_func.get_confidence_score_for_assay(assay_id) 76 | confidence_score = int(confidence_score) 77 | if self.cfg["confidence_score"] is None: 78 | return True 79 | elif confidence_score >= self.cfg["confidence_score"]: 80 | return True 81 | else: 82 | return False 83 | 84 | 85 | 86 | class SampleFilter(Filter): 87 | def __init__(self, cfg, mol_id_to_smile): 88 | """ 89 | The Filter for samples. 90 | Args: 91 | cfg: The config object. 92 | mol_id_to_smile: 93 | """ 94 | self.cfg = cfg 95 | self.mol_id_to_smile = mol_id_to_smile 96 | super(SampleFilter, self).__init__(cfg=cfg) 97 | 98 | def filter_none(self, case): 99 | for item in case.values(): 100 | if item is None: 101 | return False 102 | return True 103 | 104 | def smile_exist(self, case): 105 | molregno = case['MOLREGNO'] 106 | if molregno in self.mol_id_to_smile: 107 | return True 108 | else: 109 | return False 110 | 111 | def smile_legal(self, case): 112 | molregno = case['MOLREGNO'] 113 | smile = self.mol_id_to_smile[molregno] 114 | graph = smile2graph(smile) 115 | if graph is None or graph.num_edges() == 0 or graph.num_nodes() == 0: 116 | return False 117 | else: 118 | return True 119 | 120 | def value_relation(self, case): 121 | relation = case['STANDARD_RELATION'] 122 | if relation in self.cfg["value_relation"]: 123 | return True 124 | else: 125 | return False 126 | -------------------------------------------------------------------------------- /drugood/curators/chembl/protein_family.py: -------------------------------------------------------------------------------- 1 | class ProteinFamilyTree(): 2 | """query protein's level identifier at a specific family level 3 | Args: 4 | protein_family_level (int): 5 | Specific protein family level. 6 | sql_func: 7 | Queried Database 8 | """ 9 | 10 | def __init__(self, protein_family_level, sql_func): 11 | ''' 12 | Output the class of a protein in a protein family. 13 | Args: 14 | protein_family_level: Hyperparameter that controls which layer in the multi-level protein classification to output. 15 | sql_func: The object of SQLFunction. 
16 | ''' 17 | super(ProteinFamilyTree, self).__init__() 18 | self.id_target_level = protein_family_level 19 | link_nodes = sql_func.get_all_class_id_parent_of_protein() 20 | dict_id_to_parent_level = {} 21 | for item in link_nodes: 22 | cur_id, parent_id, level = item 23 | dict_id_to_parent_level[cur_id] = (parent_id, level) 24 | self.dict_id_to_parent_level = dict_id_to_parent_level 25 | 26 | dict_protein_seq_to_classid = {} 27 | for item in sql_func.get_all_protein_seq_to_id(): 28 | class_id, protein_seq = item 29 | dict_protein_seq_to_classid[protein_seq] = class_id 30 | self.dict_protein_seq_to_classid = dict_protein_seq_to_classid 31 | 32 | def get_target_level_class_id(self, class_id_cur_level): 33 | cur_level = self.dict_id_to_parent_level[class_id_cur_level][1] 34 | while True: 35 | if cur_level == self.id_target_level: 36 | break 37 | class_id_cur_level = self.dict_id_to_parent_level[class_id_cur_level][0] 38 | cur_level -= 1 39 | dict_level = self.dict_id_to_parent_level[class_id_cur_level][1] 40 | assert dict_level == self.id_target_level, \ 41 | 'dict_level:{}, target level:{}'.format(dict_level, self.id_target_level) 42 | return class_id_cur_level 43 | 44 | def __call__(self, protein_seq): 45 | class_id = self.dict_protein_seq_to_classid[protein_seq] 46 | target_level_class_id = self.get_target_level_class_id(class_id) 47 | return target_level_class_id 48 | -------------------------------------------------------------------------------- /drugood/curators/get_domain_info.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | from rdkit import Chem 3 | from rdkit.Chem.Scaffolds import MurckoScaffold 4 | 5 | from drugood.curators.chembl.protein_family import ProteinFamilyTree 6 | 7 | 8 | class DomainInfo(): 9 | def __init__(self, cfg, sql_func): 10 | """ 11 | Args: 12 | cfg: The config obj. 13 | sql_func: The SQLFunction obj to get sql information from ChemBL. 14 | DomainInfo convert the values in the raw data to specific domain description values. 15 | """ 16 | self.protein_family_getter = ProteinFamilyTree(cfg.get("protein_family_level"), sql_func) 17 | 18 | def scaffold(self, smile): 19 | """ 20 | Args: 21 | smile: The smile string in raw data. 22 | Returns: 23 | The scaffold string of a smile. 24 | """ 25 | try: 26 | scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=Chem.MolFromSmiles(smile), includeChirality=False) 27 | return scaffold 28 | except ValueError: 29 | print('get scaffold error') 30 | return 'C' 31 | 32 | def size(self, smile): 33 | """ 34 | Args: 35 | smile: The smile string in raw data. 36 | Returns: 37 | The number of atoms in a smile. 38 | """ 39 | mol = Chem.MolFromSmiles(smile) 40 | if (mol is None): 41 | print('GetNumAtoms error, smiles:{}'.format(smile)) 42 | return len(smile) 43 | number_atom = mol.GetNumAtoms() 44 | return number_atom 45 | 46 | def assay(self, assay): 47 | """ 48 | Args: 49 | assay: The assay ID. 50 | Returns: 51 | The assay ID. 52 | """ 53 | return assay 54 | 55 | def protein(self, protein_seq): 56 | """ 57 | 58 | Args: 59 | protein_seq: The protein sequence. 60 | 61 | Returns: 62 | The protein sequence. 63 | """ 64 | return protein_seq 65 | 66 | def protein_family(self, protein_seq): 67 | """ 68 | 69 | Args: 70 | protein_seq:The protein sequence. 71 | 72 | Returns: 73 | The class id of the protein. 
74 | """ 75 | class_id = self.protein_family_getter(protein_seq) 76 | return class_id 77 | 78 | 79 | class SortFunc(): 80 | def __init__(self, cfg, sql_func): 81 | ''' 82 | 83 | Args: 84 | cfg: The config object. 85 | sql_func: The SQLFunction obj to get sql information from ChemBL. 86 | Generate the description value of the domain to complete the sorting of the domain. 87 | ''' 88 | self.domain_info = DomainInfo(cfg, sql_func) 89 | 90 | def domain_value(self, item_domain_data): 91 | ''' 92 | 93 | Args: 94 | item_domain_data: Single domain data. 95 | 96 | Returns: 97 | The sorting of domains is done directly according to the value of domain. 98 | ''' 99 | return item_domain_data[0] 100 | 101 | def domain_capacity(self, item_domain_data): 102 | ''' 103 | 104 | Args: 105 | item_domain_data: Single domain data. 106 | 107 | Returns: 108 | Return the number of samples in the domain as the description value of the domain. 109 | ''' 110 | return len(item_domain_data[1]) 111 | 112 | def scaffold_size(self, item_domain_data): 113 | ''' 114 | 115 | Args: 116 | item_domain_data: Single domain data. 117 | 118 | Returns: 119 | Returns the size of the scaffold as the description value of the domain. 120 | ''' 121 | scaffold = item_domain_data[0] 122 | size_scaffold = self.domain_info.size(scaffold) 123 | return size_scaffold 124 | -------------------------------------------------------------------------------- /drugood/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) OpenMMLab. All rights reserved. 3 | """ 4 | from drugood.datasets.base_dataset import BaseDataset 5 | from drugood.datasets.builder import DATASETS, PIPELINES, build_dataloader, build_dataset 6 | from drugood.datasets.dataset_wrappers import ClassBalancedDataset, ConcatDataset, RepeatDataset 7 | from drugood.datasets.drugood_dataset import DrugOODDataset, LBAPDataset, SBAPDataset 8 | from drugood.datasets.multi_label import MultiLabelDataset 9 | from drugood.datasets.pipelines import Compose 10 | from drugood.datasets.samplers import DistributedSampler 11 | 12 | __all__ = [ 13 | 'BaseDataset', 'MultiLabelDataset', 14 | 'build_dataloader', 'build_dataset', 'Compose', 15 | 'DistributedSampler', 'ConcatDataset', 'RepeatDataset', 16 | 'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 17 | 'DrugOODDataset', 18 | ] 19 | -------------------------------------------------------------------------------- /drugood/datasets/drugood_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 
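# ---- Editor's aside: an illustrative sketch, not part of the repository ----
# Each `SortFunc` method above reduces one `(domain_value, samples)` pair to a
# single sortable key: the raw domain value, the number of samples the domain
# holds, or the scaffold size. A toy illustration of how such keys order
# domains (the scaffolds and sample ids below are made up):

domains = {
    "c1ccccc1": ["mol_1", "mol_2", "mol_3"],  # benzene scaffold, 3 samples
    "C1CCNCC1": ["mol_4"],                    # piperidine scaffold, 1 sample
    "c1ccncc1": ["mol_5", "mol_6"],           # pyridine scaffold, 2 samples
}

# `domain_capacity` style: order domains by how many samples they hold
by_capacity = sorted(domains.items(), key=lambda kv: len(kv[1]))
assert [k for k, _ in by_capacity] == ["C1CCNCC1", "c1ccncc1", "c1ccccc1"]

# `domain_value` style: order domains by the domain value itself
by_value = sorted(domains.items(), key=lambda kv: kv[0])
# ----------------------------------------------------------------------------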
2 | from functools import partial 3 | 4 | import mmcv 5 | import numpy as np 6 | import torch 7 | from mmcv import build_from_cfg 8 | 9 | from drugood.core.utils.data_collect import Collater 10 | from .base_dataset import BaseDataset 11 | from .builder import DATASETS, PIPELINES 12 | 13 | __all__ = ['DrugOODDataset', 'LBAPDataset', 'SBAPDataset'] 14 | 15 | 16 | @DATASETS.register_module() 17 | class DrugOODDataset(BaseDataset): 18 | def __init__(self, 19 | split="train", 20 | label_key="cls_label", 21 | **kwargs): 22 | self.split = split 23 | self.label_key = label_key 24 | 25 | super(DrugOODDataset, self).__init__(**kwargs) 26 | self.sort_domain() 27 | self.groups = self.get_groups() 28 | self._collate = self.initial_collater() 29 | 30 | def initial_collater(self): 31 | return Collater() 32 | 33 | def sort_domain(self): 34 | unique_domains = torch.unique(torch.FloatTensor([case['domain_id'] for case in self.data_infos])) 35 | for case in self.data_infos: 36 | case['domain_id'] = torch.searchsorted(unique_domains, case['domain_id']) 37 | 38 | def get_groups(self): 39 | groups = torch.FloatTensor([case['domain_id'] for case in self.data_infos]).long().unsqueeze(-1) 40 | return groups 41 | 42 | def get_gt_labels(self): 43 | gt_labels = np.array([int(data[self.label_key]) for data in self.data_infos]) 44 | return gt_labels 45 | 46 | def load_annotations(self): 47 | data = mmcv.load(self.ann_file) 48 | return data["split"][self.split] 49 | 50 | 51 | @DATASETS.register_module() 52 | class LBAPDataset(DrugOODDataset): 53 | def __init__(self, **kwargs): 54 | super(LBAPDataset, self).__init__(**kwargs) 55 | 56 | def prepare_data(self, idx): 57 | case = self.data_infos[idx] 58 | input = case["smiles"] 59 | results = {'input': input, 60 | 'gt_label': int(case[self.label_key]), 61 | 'group': case['domain_id']} 62 | return self.pipeline(results) 63 | 64 | 65 | @DATASETS.register_module() 66 | class SBAPDataset(DrugOODDataset): 67 | def __init__(self, tokenizer, **kwargs): 68 | self.tokenizer = build_from_cfg(tokenizer, PIPELINES) 69 | super(SBAPDataset, self).__init__(**kwargs) 70 | 71 | def prepare_data(self, idx): 72 | case = self.data_infos[idx] 73 | input = case["smiles"] 74 | aux_input = case["protein"] 75 | results = {'input': input, 76 | 'aux_input': aux_input, 77 | 'gt_label': int(case[self.label_key]), 78 | 'group': case['domain_id']} 79 | return self.pipeline(results) 80 | 81 | def initial_collater(self): 82 | return Collater(convert_fn=partial(self.tokenizer.__call__)) 83 | -------------------------------------------------------------------------------- /drugood/datasets/grouper.py: -------------------------------------------------------------------------------- 1 | # This code is modified based on 2 | # https://github.com/p-lambda/wilds/blob/a7a452c80cad311cf0aabfd59af8348cba1b9861/wilds/common/grouper.py 3 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 4 | # The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). 5 | # All Tencent Modifications are Copyright (C) THL A29 Limited. 6 | 7 | import warnings 8 | 9 | import numpy as np 10 | import torch 11 | from wilds.common.utils import get_counts 12 | 13 | 14 | class Grouper: 15 | """ 16 | Groupers group data points together based on their metadata. 17 | They are used for training and evaluation, 18 | e.g., to measure the accuracies of different groups of data. 
19 | """ 20 | 21 | def __init__(self): 22 | raise NotImplementedError 23 | 24 | @property 25 | def n_groups(self): 26 | """ 27 | The number of groups defined by this Grouper. 28 | """ 29 | return self._n_groups 30 | 31 | def metadata_to_group(self, metadata, return_counts=False): 32 | """ 33 | Args: 34 | - metadata (Tensor): An n x d matrix containing d metadata fields 35 | for n different points. 36 | - return_counts (bool): If True, return group counts as well. 37 | Output: 38 | - group (Tensor): An n-length vector of groups. 39 | - group_counts (Tensor): Optional, depending on return_counts. 40 | An n_group-length vector of integers containing the 41 | numbers of data points in each group in the metadata. 42 | """ 43 | raise NotImplementedError 44 | 45 | def group_str(self, group): 46 | """ 47 | Args: 48 | - group (int): A single integer representing a group. 49 | Output: 50 | - group_str (str): A string containing the pretty name of that group. 51 | """ 52 | raise NotImplementedError 53 | 54 | def group_field_str(self, group): 55 | """ 56 | Args: 57 | - group (int): A single integer representing a group. 58 | Output: 59 | - group_str (str): A string containing the name of that group. 60 | """ 61 | raise NotImplementedError 62 | 63 | 64 | class CombinatorialGrouper(Grouper): 65 | def __init__(self, dataset): 66 | grouped_metadata = dataset.groups 67 | if not isinstance(grouped_metadata, torch.LongTensor): 68 | grouped_metadata_long = grouped_metadata.long() 69 | if not torch.all(grouped_metadata == grouped_metadata_long): 70 | warnings.warn(f'CombinatorialGrouper: converting metadata into long') 71 | grouped_metadata = grouped_metadata_long 72 | 73 | self.cardinality = 1 + torch.max( 74 | grouped_metadata, dim=0)[0] 75 | cumprod = torch.cumprod(self.cardinality, dim=0) 76 | self._n_groups = cumprod[-1].item() 77 | self.factors_np = np.concatenate(([1], cumprod[:-1])) 78 | self.factors = torch.from_numpy(self.factors_np) 79 | 80 | def metadata_to_group(self, metadata, return_counts=False): 81 | groups = metadata[:, ].long() @ self.factors 82 | if return_counts: 83 | group_counts = get_counts(groups, self._n_groups) 84 | return groups, group_counts 85 | else: 86 | return groups 87 | 88 | def group_str(self, group): 89 | NotImplemented 90 | 91 | def group_field_str(self, group): 92 | return self.group_str(group).replace('=', ':').replace(',', '_').replace(' ', '') 93 | -------------------------------------------------------------------------------- /drugood/datasets/multi_label.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | import numpy as np 5 | 6 | from drugood.core import average_performance, mean_average_precision 7 | from .base_dataset import BaseDataset 8 | 9 | 10 | class MultiLabelDataset(BaseDataset): 11 | """Multi-label Dataset.""" 12 | 13 | def get_cat_ids(self, idx): 14 | """Get category ids by index. 15 | 16 | Args: 17 | idx (int): Index of data. 18 | 19 | Returns: 20 | np.ndarray: Image categories of specified index. 21 | """ 22 | gt_labels = self.data_infos[idx]['gt_label'] 23 | cat_ids = np.where(gt_labels == 1)[0] 24 | return cat_ids 25 | 26 | def evaluate(self, 27 | results, 28 | metric='mAP', 29 | metric_options=None, 30 | logger=None, 31 | **deprecated_kwargs): 32 | """Evaluate the dataset. 33 | 34 | Args: 35 | results (list): Testing results of the dataset. 36 | metric (str | list[str]): Metrics to be evaluated. 37 | Default value is 'mAP'. 
Options are 'mAP', 'CP', 'CR', 'CF1', 38 | 'OP', 'OR' and 'OF1'. 39 | metric_options (dict, optional): Options for calculating metrics. 40 | Allowed keys are 'k' and 'thr'. Defaults to None 41 | logger (logging.Logger | str, optional): Logger used for printing 42 | related information during evaluation. Defaults to None. 43 | deprecated_kwargs (dict): Used for containing deprecated arguments. 44 | 45 | Returns: 46 | dict: evaluation results 47 | """ 48 | if metric_options is None: 49 | metric_options = {'thr': 0.5} 50 | 51 | if deprecated_kwargs != {}: 52 | warnings.warn('Option arguments for metrics has been changed to ' 53 | '`metric_options`.') 54 | metric_options = {**deprecated_kwargs} 55 | 56 | if isinstance(metric, str): 57 | metrics = [metric] 58 | else: 59 | metrics = metric 60 | allowed_metrics = ['mAP', 'CP', 'CR', 'CF1', 'OP', 'OR', 'OF1'] 61 | eval_results = {} 62 | results = np.vstack(results) 63 | gt_labels = self.get_gt_labels() 64 | num_imgs = len(results) 65 | assert len(gt_labels) == num_imgs, 'dataset testing results should ' \ 66 | 'be of the same length as gt_labels.' 67 | 68 | invalid_metrics = set(metrics) - set(allowed_metrics) 69 | if len(invalid_metrics) != 0: 70 | raise ValueError(f'metric {invalid_metrics} is not supported.') 71 | 72 | if 'mAP' in metrics: 73 | map_value = mean_average_precision(results, gt_labels) 74 | eval_results['mAP'] = map_value 75 | if len(set(metrics) - {'mAP'}) != 0: 76 | performance_keys = ['CP', 'CR', 'CF1', 'OP', 'OR', 'OF1'] 77 | performance_values = average_performance(results, gt_labels, 78 | **metric_options) 79 | for k, v in zip(performance_keys, performance_values): 80 | if k in metrics: 81 | eval_results[k] = v 82 | 83 | return eval_results 84 | -------------------------------------------------------------------------------- /drugood/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | # Compose 3 | from .compose import Compose 4 | # Data Formatting 5 | from .formating import Collect, ToTensor, to_tensor, Warp, SmileToGraph, SeqToToken 6 | 7 | __all__ = [ 8 | 'Compose', 'to_tensor', 'ToTensor', 'Collect', 9 | 'Warp', "SmileToGraph", "SeqToToken" 10 | ] 11 | -------------------------------------------------------------------------------- /drugood/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from collections.abc import Sequence 3 | 4 | from mmcv.utils import build_from_cfg 5 | 6 | from ..builder import PIPELINES 7 | 8 | 9 | @PIPELINES.register_module() 10 | class Compose(object): 11 | """Compose a data pipeline with a sequence of transforms. 12 | 13 | Args: 14 | transforms (list[dict | callable]): 15 | Either config dicts of transforms or transform objects. 
16 | """ 17 | 18 | def __init__(self, transforms): 19 | assert isinstance(transforms, Sequence) 20 | self.transforms = [] 21 | for transform in transforms: 22 | if isinstance(transform, dict): 23 | transform = build_from_cfg(transform, PIPELINES) 24 | self.transforms.append(transform) 25 | elif callable(transform): 26 | self.transforms.append(transform) 27 | else: 28 | raise TypeError('transform must be callable or a dict, but got' 29 | f' {type(transform)}') 30 | 31 | def __call__(self, data): 32 | for t in self.transforms: 33 | data = t(data) 34 | if data is None: 35 | return None 36 | return data 37 | 38 | def __repr__(self): 39 | format_string = self.__class__.__name__ + '(' 40 | for t in self.transforms: 41 | format_string += f'\n {t}' 42 | format_string += '\n)' 43 | return format_string 44 | -------------------------------------------------------------------------------- /drugood/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .distributed_sampler import DistributedSampler 3 | 4 | __all__ = ['DistributedSampler'] 5 | -------------------------------------------------------------------------------- /drugood/datasets/samplers/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 3 | # The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). 4 | # All Tencent Modifications are Copyright (C) THL A29 Limited. 5 | import torch 6 | from torch.utils.data import DistributedSampler as _DistributedSampler 7 | 8 | 9 | class DistributedSampler(_DistributedSampler): 10 | 11 | def __init__(self, 12 | dataset, 13 | num_replicas=None, 14 | rank=None, 15 | shuffle=True, 16 | round_up=True): 17 | super().__init__(dataset, num_replicas=num_replicas, rank=rank) 18 | self.shuffle = shuffle 19 | self.round_up = round_up 20 | if self.round_up: 21 | self.total_size = self.num_samples * self.num_replicas 22 | else: 23 | self.total_size = len(self.dataset) 24 | 25 | def __iter__(self): 26 | # deterministically shuffle based on epoch 27 | if self.shuffle: 28 | g = torch.Generator() 29 | g.manual_seed(self.epoch) 30 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 31 | else: 32 | indices = torch.arange(len(self.dataset)).tolist() 33 | 34 | # add extra samples to make it evenly divisible 35 | if self.round_up: 36 | indices = ( 37 | indices * 38 | int(self.total_size / len(indices) + 1))[:self.total_size] 39 | assert len(indices) == self.total_size 40 | 41 | # subsample 42 | indices = indices[self.rank:self.total_size:self.num_replicas] 43 | if self.round_up: 44 | assert len(indices) == self.num_samples 45 | 46 | return iter(indices) 47 | -------------------------------------------------------------------------------- /drugood/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) OpenMMLab. All rights reserved. 
3 | """ 4 | from .algorithms import * 5 | from .backbones import * # noqa: F401,F403 6 | from .builder import (BACKBONES, CLASSIFIERS, HEADS, LOSSES, NECKS, 7 | build_backbone, build_head, build_losses, build_model, 8 | build_tasker, build_neck) 9 | from .heads import * # noqa: F401,F403 10 | from .losses import * # noqa: F401,F403 11 | from .necks import * # noqa: F401,F403 12 | from .taskers import * 13 | 14 | __all__ = [ 15 | 'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'CLASSIFIERS', 'build_backbone', 16 | 'build_head', 'build_neck', 'build_losses', 'build_tasker' 17 | ] 18 | -------------------------------------------------------------------------------- /drugood/models/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. 2 | # All rights reserved. 3 | from .coral import CORAL 4 | from .dann import DANN 5 | from .erm import ERM 6 | from .groupdro import GroupDRO 7 | from .irm import IRM 8 | from .mixup import MixUp 9 | -------------------------------------------------------------------------------- /drugood/models/algorithms/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | # The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). 3 | # All Tencent Modifications are Copyright (C) THL A29 Limited. 4 | from abc import ABCMeta, abstractmethod 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch_geometric 10 | import transformers 11 | from dgl import DGLGraph 12 | from mmcv.runner import BaseModule 13 | 14 | 15 | class BaseAlgorithm(BaseModule, metaclass=ABCMeta): 16 | def __init__(self, init_cfg=None): 17 | super(BaseAlgorithm, self).__init__(init_cfg) 18 | 19 | @abstractmethod 20 | def forward_train(self, input, group, **kwargs): 21 | """Placeholder for Forward function for training.""" 22 | pass 23 | 24 | @abstractmethod 25 | def simple_test(self, input, group, **kwargs): 26 | """Placeholder for single case test.""" 27 | pass 28 | 29 | def forward_test(self, input, group, **kwargs): 30 | return self.simple_test(input, group, **kwargs) 31 | 32 | def forward(self, input, group, return_loss=True, **kwargs): 33 | if return_loss: 34 | return self.forward_train(input, group, **kwargs) 35 | else: 36 | return self.forward_test(input, group, **kwargs) 37 | 38 | def train_step(self, data_batch, optimizer): 39 | losses = self(**data_batch) 40 | loss, log_vars = self._parse_losses(losses) 41 | 42 | outputs = dict( 43 | loss=loss, 44 | log_vars=log_vars, 45 | num_samples=self.get_batch_num(data_batch)) 46 | 47 | return outputs 48 | 49 | def _parse_losses(self, losses): 50 | log_vars = OrderedDict() 51 | for loss_name, loss_value in losses.items(): 52 | if isinstance(loss_value, torch.Tensor): 53 | log_vars[loss_name] = loss_value.mean() 54 | elif isinstance(loss_value, list): 55 | log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) 56 | elif isinstance(loss_value, dict): 57 | for name, value in loss_value.items(): 58 | log_vars[name] = value 59 | else: 60 | raise TypeError( 61 | f'{loss_name} is not a tensor or list of tensors') 62 | 63 | loss = sum(_value for _key, _value in log_vars.items() 64 | if 'loss' in _key) 65 | 66 | log_vars['loss'] = loss 67 | for loss_name, loss_value in log_vars.items(): 68 | # reduce loss when distributed 
training 69 | if dist.is_available() and dist.is_initialized(): 70 | loss_value = loss_value.data.clone() 71 | dist.all_reduce(loss_value.div_(dist.get_world_size())) 72 | log_vars[loss_name] = loss_value.item() 73 | 74 | return loss, log_vars 75 | 76 | def get_batch_num(self, batch): 77 | if isinstance(batch["input"], torch.Tensor): 78 | return len(batch["input"].data) 79 | elif isinstance(batch["input"], torch_geometric.data.Data): 80 | return batch["input"].num_graphs 81 | elif isinstance(batch['input'], DGLGraph): 82 | return batch['input'].batch_size 83 | elif isinstance(batch['input'], transformers.BatchEncoding): 84 | return len(batch['input']) 85 | else: 86 | raise NotImplementedError 87 | -------------------------------------------------------------------------------- /drugood/models/algorithms/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | # The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). 3 | # All Tencent Modifications are Copyright (C) THL A29 Limited. 4 | from mmcv.utils import Registry 5 | 6 | ALGORITHMS = Registry('algorithms') 7 | 8 | 9 | def build_algorithm(cfg): 10 | return ALGORITHMS.build(cfg) 11 | -------------------------------------------------------------------------------- /drugood/models/algorithms/coral.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | 3 | import torch 4 | from wilds.common.utils import split_into_groups 5 | 6 | from drugood.models.algorithms.base import BaseAlgorithm 7 | from ..builder import MODELS, build_tasker 8 | 9 | 10 | @MODELS.register_module() 11 | class CORAL(BaseAlgorithm): 12 | """ 13 | Deep CORAL. 14 | This algorithm was originally proposed as an unsupervised domain adaptation algorithm. 15 | Original paper: 16 | @inproceedings{sun2016deep, 17 | title={Deep CORAL: Correlation alignment for deep domain adaptation}, 18 | author={Sun, Baochen and Saenko, Kate}, 19 | booktitle={European Conference on Computer Vision}, 20 | pages={443--450}, 21 | year={2016}, 22 | organization={Springer} 23 | } 24 | The original CORAL loss is the distance between second-order statistics (covariances) 25 | of the source and target features. 
26 |     The CORAL implementation below is adapted from Wilds's implementation:
27 |     https://github.com/p-lambda/wilds/blob/a7a452c80cad311cf0aabfd59af8348cba1b9861/examples/algorithms/deepCORAL.py
28 |     """
29 | 
30 |     def __init__(self,
31 |                  tasker,
32 |                  coral_penalty_weight=0.1
33 |                  ):
34 |         super().__init__()
35 |         self.tasker = build_tasker(tasker)
36 |         # set CORAL-specific variables
37 |         self.coral_penalty_weight = coral_penalty_weight
38 | 
39 |     def init_weights(self):
40 |         pass
41 | 
42 |     def encode(self, input, **kwargs):
43 |         feats = self.tasker.extract_feat(input, **kwargs)
44 |         return feats
45 | 
46 |     def decode(self, feats, gt_label=None, return_loss=False):
47 |         if return_loss:
48 |             return self.tasker.head.forward_train(feats, gt_label)
49 |         else:
50 |             return self.tasker.head.forward_test(feats)
51 | 
52 |     def forward_train(self, input, group, gt_label, **kwargs):
53 |         feats = self.encode(input, **kwargs)
54 |         losses = self.decode(feats, gt_label, return_loss=True)
55 |         unique_groups, group_indices, _ = split_into_groups(group)
56 |         coral_penalty = []
57 |         n_groups_per_batch = unique_groups.numel()
58 |         for i_group in range(n_groups_per_batch):
59 |             for j_group in range(i_group + 1, n_groups_per_batch):
60 |                 coral_penalty.append(self.coral_penalty(feats[group_indices[i_group]], feats[group_indices[j_group]]))
61 |         # a batch with a single group yields no pair to align; skip the penalty
62 |         if coral_penalty:
63 |             losses.update({"coral_loss": torch.vstack(coral_penalty) * self.coral_penalty_weight})
64 |         return losses
65 | 
66 |     def simple_test(self, input, group, **kwargs):
67 |         feats = self.encode(input, **kwargs)
68 |         logits = self.decode(feats)
69 |         preds = self.tasker.head.post_process(logits)
70 |         return preds
71 | 
72 |     def coral_penalty(self, x, y):
73 |         if x.dim() > 2:
74 |             x = x.view(-1, x.size(-1))
75 |             y = y.view(-1, y.size(-1))
76 | 
77 |         mean_x = x.mean(0, keepdim=True)
78 |         mean_y = y.mean(0, keepdim=True)
79 |         cent_x = x - mean_x
80 |         cent_y = y - mean_y
81 |         cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
82 |         cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)
83 | 
84 |         mean_diff = (mean_x - mean_y).pow(2).mean()
85 |         cova_diff = (cova_x - cova_y).pow(2).mean()
86 | 
87 |         return mean_diff + cova_diff
88 | 
-------------------------------------------------------------------------------- /drugood/models/algorithms/dann.py: --------------------------------------------------------------------------------
1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
2 | 
3 | from torch.autograd import Function
4 | 
5 | from drugood.models.algorithms.base import BaseAlgorithm
6 | from ..builder import MODELS, build_tasker
7 | from ..builder import build_head
8 | 
9 | 
10 | class GradientReverseLayerF(Function):
11 |     @staticmethod
12 |     def forward(ctx, x, alpha):
13 |         ctx.alpha = alpha
14 |         return x.view_as(x)
15 | 
16 |     @staticmethod
17 |     def backward(ctx, grad_output):
18 |         output = grad_output.neg() * ctx.alpha
19 |         return output, None
20 | 
21 | 
22 | @MODELS.register_module()
23 | class DANN(BaseAlgorithm):
24 |     """
25 |     DANN.
26 |     This algorithm was originally proposed as an unsupervised domain adaptation algorithm.
27 |     Original paper:
28 |         @article{ganin2016domain,
29 |             title={Domain-adversarial training of neural networks},
30 |             author={Ganin, Yaroslav and Ustinova, Evgeniya and Ajakan, Hana and Germain, Pascal and Larochelle,
31 |             Hugo and Laviolette, Fran{\c{c}}ois and Marchand, Mario and Lempitsky, Victor},
32 |             journal={The journal of machine learning research},
33 |             volume={17},
34 |             number={1},
35 |             pages={2096--2030},
36 |             year={2016},
37 |             publisher={JMLR.
org} 38 | } 39 | """ 40 | 41 | def __init__(self, tasker, dann_cfg=None): 42 | super().__init__() 43 | self.tasker = build_tasker(tasker) 44 | assert dann_cfg is not None 45 | self.alpha = dann_cfg.get("alpha") 46 | self.aux_head = build_head(dann_cfg.get("aux_head")) 47 | 48 | def init_weights(self): 49 | pass 50 | 51 | def encode(self, input, **kwargs): 52 | feats = self.tasker.extract_feat(input, **kwargs) 53 | return feats 54 | 55 | def decode(self, feats, gt_label=None, return_loss=False): 56 | if return_loss: 57 | return self.tasker.head.forward_train(feats, gt_label) 58 | else: 59 | return self.tasker.head.forward_test(feats) 60 | 61 | def forward_train(self, input, group, gt_label, **kwargs): 62 | feats = self.encode(input, **kwargs) 63 | losses = self.decode(feats, gt_label, return_loss=True) 64 | 65 | _feature = GradientReverseLayerF.apply(feats, self.alpha) 66 | aux_losses = self.aux_head.forward_train(_feature, group) 67 | losses.update({f"aux_{key}": val for key, val in aux_losses.items()}) 68 | return losses 69 | 70 | def simple_test(self, input, group, **kwargs): 71 | feats = self.encode(input, **kwargs) 72 | logits = self.decode(feats) 73 | preds = self.tasker.head.post_process(logits) 74 | return preds 75 | -------------------------------------------------------------------------------- /drugood/models/algorithms/erm.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | from drugood.models.algorithms.base import BaseAlgorithm 3 | from ..builder import MODELS, build_tasker 4 | 5 | 6 | @MODELS.register_module() 7 | class ERM(BaseAlgorithm): 8 | def __init__(self, tasker): 9 | super().__init__() 10 | self.tasker = build_tasker(tasker) 11 | 12 | def init_weights(self): 13 | pass 14 | 15 | def encode(self, input, **kwargs): 16 | feats = self.tasker.extract_feat(input, **kwargs) 17 | return feats 18 | 19 | def decode(self, feats, gt_label=None, return_loss=False): 20 | if return_loss: 21 | return self.tasker.head.forward_train(feats, gt_label) 22 | else: 23 | return self.tasker.head.forward_test(feats) 24 | 25 | def forward_train(self, input, group, gt_label, **kwargs): 26 | feats = self.encode(input, **kwargs) 27 | losses = self.decode(feats, gt_label, return_loss=True) 28 | return losses 29 | 30 | def simple_test(self, input, group, **kwargs): 31 | feats = self.encode(input, **kwargs) 32 | logits = self.decode(feats) 33 | preds = self.tasker.head.post_process(logits) 34 | return preds 35 | -------------------------------------------------------------------------------- /drugood/models/algorithms/groupdro.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | import torch 3 | import torch_scatter 4 | 5 | from drugood.models.algorithms.base import BaseAlgorithm 6 | from ..builder import MODELS, build_tasker 7 | 8 | 9 | @MODELS.register_module() 10 | class GroupDRO(BaseAlgorithm): 11 | """ 12 | Group distributionally robust optimization. 
13 | Original paper: 14 | @inproceedings{sagawa2019distributionally, 15 | title={Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization}, 16 | author={Sagawa, Shiori and Koh, Pang Wei and Hashimoto, Tatsunori B and Liang, Percy}, 17 | booktitle={International Conference on Learning Representations}, 18 | year={2019} 19 | } 20 | The GroupDRO implementation below is adapted from Wilds's implementation: 21 | https://github.com/p-lambda/wilds/blob/a7a452c80cad311cf0aabfd59af8348cba1b9861/examples/algorithms/groupDRO.py 22 | """ 23 | 24 | def __init__(self, 25 | tasker, 26 | num_groups=44930, 27 | group_dro_step_size=0.01, 28 | ): 29 | super().__init__() 30 | self.tasker = build_tasker(tasker) 31 | self.num_groups = num_groups 32 | # set GroupDRO-specific variables 33 | self.group_weights_step_size = group_dro_step_size 34 | self.group_weights = torch.ones(num_groups) 35 | self.group_weights = self.group_weights / self.group_weights.sum() 36 | 37 | def init_weights(self): 38 | pass 39 | 40 | def encode(self, input, group, **kwargs): 41 | feats = self.tasker.extract_feat(input, **kwargs) 42 | return feats 43 | 44 | def decode(self, feats, gt_label=None, return_loss=False): 45 | if return_loss: 46 | return self.tasker.head.forward_train(feats, gt_label) 47 | else: 48 | return self.tasker.head.forward_test(feats) 49 | 50 | def forward_train(self, input, group, gt_label, **kwargs): 51 | feats = self.encode(input, group, **kwargs) 52 | losses = self.decode(feats, gt_label, return_loss=True) 53 | losses = sum(_value for _key, _value in losses.items() if 'loss' in _key) 54 | # 55 | batch_idx = torch.where(~torch.isnan(gt_label))[0] 56 | group_idx = group[batch_idx] 57 | group_losses = torch_scatter.scatter(src=losses, index=group_idx, dim_size=self.num_groups, 58 | reduce='mean') 59 | # 60 | if self.group_weights.device != group_losses.device: 61 | self.group_weights = self.group_weights.to(device=group_losses.device) 62 | self.group_weights = self.group_weights * torch.exp(self.group_weights_step_size * group_losses.data) 63 | self.group_weights = (self.group_weights / (self.group_weights.sum())) 64 | losses = {"groupdro_loss": group_losses @ self.group_weights} 65 | return losses 66 | 67 | def simple_test(self, input, group, **kwargs): 68 | feats = self.encode(input, group, **kwargs) 69 | logits = self.decode(feats) 70 | preds = self.tasker.head.post_process(logits) 71 | return preds 72 | -------------------------------------------------------------------------------- /drugood/models/algorithms/irm.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | import torch 3 | from torch import autograd 4 | from wilds.common.utils import split_into_groups 5 | 6 | from drugood.models.algorithms.base import BaseAlgorithm 7 | from ..builder import MODELS, build_tasker 8 | 9 | 10 | @MODELS.register_module() 11 | class IRM(BaseAlgorithm): 12 | """ 13 | Invariant risk minimization. 14 | Original paper: 15 | @article{arjovsky2019invariant, 16 | title={Invariant risk minimization}, 17 | author={Arjovsky, Martin and Bottou, L{\'e}on and Gulrajani, Ishaan and Lopez-Paz, David}, 18 | journal={arXiv preprint arXiv:1907.02893}, 19 | year={2019} 20 | } 21 | The IRM penalty function below is adapted from the code snippet 22 | provided in the above paper. 
23 | """ 24 | 25 | def __init__(self, 26 | tasker, 27 | irm_lambda=1, 28 | irm_penalty_anneal_iters=500 29 | ): 30 | super().__init__() 31 | self.tasker = build_tasker(tasker) 32 | # set IRM-specific variables 33 | self.irm_lambda = irm_lambda 34 | self.irm_penalty_anneal_iters = irm_penalty_anneal_iters 35 | self.scale = torch.nn.Parameter(torch.tensor(1.)) 36 | self.update_count = 0 37 | 38 | def init_weights(self): 39 | pass 40 | 41 | def encode(self, input, group, **kwargs): 42 | feats = self.tasker.extract_feat(input, **kwargs) 43 | return feats 44 | 45 | def decode(self, feats, gt_label=None, return_loss=False): 46 | if return_loss: 47 | return self.tasker.head.forward_train(feats, gt_label) 48 | else: 49 | return self.tasker.head.forward_test(feats) 50 | 51 | def forward_train(self, input, group, gt_label, **kwargs): 52 | feats = self.encode(input, group, **kwargs) 53 | cls_score = self.tasker.head.fc(feats) 54 | _, group_indices, _ = split_into_groups(group) 55 | main_losses = [] 56 | irm_penalty = [] 57 | for i_group in group_indices: 58 | group_losses_dict = self.tasker.head.loss(self.scale * cls_score[i_group], gt_label[i_group]) 59 | group_losses = sum(_value for _key, _value in group_losses_dict.items() if 'loss' in _key) 60 | if group_losses.numel() > 0: 61 | main_losses.append(group_losses.mean()) 62 | irm_penalty.append(self.irm_penalty(group_losses)) 63 | losses = {"main_loss": torch.vstack(main_losses), "irm_loss": torch.vstack(irm_penalty)} 64 | return losses 65 | 66 | def simple_test(self, input, group, **kwargs): 67 | feats = self.encode(input, group, **kwargs) 68 | logits = self.decode(feats) 69 | preds = self.tasker.head.post_process(logits) 70 | return preds 71 | 72 | def irm_penalty(self, losses): 73 | grad_1 = autograd.grad(losses[0::2].mean(), [self.scale], create_graph=True)[0] 74 | grad_2 = autograd.grad(losses[1::2].mean(), [self.scale], create_graph=True)[0] 75 | result = torch.sum(grad_1 * grad_2) 76 | return result 77 | -------------------------------------------------------------------------------- /drugood/models/algorithms/mixup.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 
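# ---- Editor's aside: an illustrative sketch, not part of the repository ----
# `irm_penalty` above follows the IRMv1 trick: the per-sample losses of a
# group are split into two disjoint halves, and the product of the two
# half-gradients w.r.t. the dummy scale estimates the squared gradient norm
# that IRM penalizes. A self-contained sketch with random data:

import torch
import torch.nn.functional as F
from torch import autograd

scale = torch.nn.Parameter(torch.tensor(1.))
logits = torch.randn(8, 2)
labels = torch.randint(0, 2, (8,))
losses = F.cross_entropy(scale * logits, labels, reduction='none')

# gradients of the even- and odd-indexed halves w.r.t. the dummy scale
grad_1 = autograd.grad(losses[0::2].mean(), [scale], create_graph=True)[0]
grad_2 = autograd.grad(losses[1::2].mean(), [scale], create_graph=True)[0]
penalty = torch.sum(grad_1 * grad_2)  # added to the loss, weighted by irm_lambda
# ----------------------------------------------------------------------------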
2 | from drugood.models.algorithms.base import BaseAlgorithm 3 | from ..builder import MODELS, build_tasker 4 | from ..utils import Augments 5 | 6 | 7 | @MODELS.register_module() 8 | class MixUp(BaseAlgorithm): 9 | """ 10 | Mixup 11 | Original paper: 12 | @article{zhang2017mixup, 13 | title={mixup: Beyond empirical risk minimization}, 14 | author={Zhang, Hongyi and Cisse, Moustapha and Dauphin, Yann N and Lopez-Paz, David}, 15 | journal={arXiv preprint arXiv:1710.09412}, 16 | year={2017} 17 | } 18 | Note that we adopt the feature-level mixup strategy 19 | """ 20 | 21 | def __init__(self, tasker, cfg=None): 22 | super().__init__() 23 | self.tasker = build_tasker(tasker) 24 | self.augment = None 25 | if cfg is not None: 26 | cfg['type'] = 'BatchMixup' 27 | if cfg.get('prob', None) is None: 28 | cfg['prob'] = 1.0 29 | self.augment = Augments(cfg) 30 | 31 | def init_weights(self): 32 | pass 33 | 34 | def encode(self, input, **kwargs): 35 | feats = self.tasker.extract_feat(input, **kwargs) 36 | return feats 37 | 38 | def decode(self, feats, gt_label=None, return_loss=False): 39 | if return_loss: 40 | return self.tasker.head.forward_train(feats, gt_label) 41 | else: 42 | return self.tasker.head.forward_test(feats) 43 | 44 | def forward_train(self, input, group, gt_label, **kwargs): 45 | feats = self.encode(input, **kwargs) 46 | if self.augment is not None: 47 | feats, gt_label = self.augment(feats, gt_label) 48 | losses = self.decode(feats, gt_label, return_loss=True) 49 | return losses 50 | 51 | def simple_test(self, input, group, **kwargs): 52 | feats = self.encode(input, **kwargs) 53 | logits = self.decode(feats) 54 | preds = self.tasker.head.post_process(logits) 55 | return preds 56 | -------------------------------------------------------------------------------- /drugood/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | # The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). 3 | # All Tencent Modifications are Copyright (C) THL A29 Limited. 4 | # Copyright (c) OpenMMLab. All rights reserved. 5 | from .attentivefp import AttentiveFPGNN 6 | from .bert import Bert 7 | from .gat import GAT 8 | from .gcn import GCN 9 | from .gin import GIN 10 | from .mgcn import MGCN 11 | from .nf import NF 12 | from .resnet import ResNet, ResNetV1d 13 | from .schnet import SchNet 14 | from .weave import Weave 15 | 16 | __all__ = [ 17 | 'ResNet', 18 | "AttentiveFPGNN", "GAT", "GCN", "MGCN", "SchNet", "NF", "Weave", "GIN", 19 | "Bert" 20 | ] 21 | -------------------------------------------------------------------------------- /drugood/models/backbones/attentivefp.py: -------------------------------------------------------------------------------- 1 | # The model implementation is adopted from the dgllife library 2 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 3 | # The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). 4 | # All Tencent Modifications are Copyright (C) THL A29 Limited. 
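# ---- Editor's aside: an illustrative sketch, not part of the repository ----
# The MixUp algorithm above interpolates at the feature level: after encoding,
# the `Augments` wrapper (configured as 'BatchMixup') convexly combines each
# feature vector with a shuffled partner and mixes the labels in the same
# proportion. The core operation, sketched with made-up tensors:

import torch

feats = torch.randn(4, 128)                         # encoded batch features
one_hot = torch.eye(2)[torch.tensor([0, 1, 1, 0])]  # one-hot labels
lam = torch.distributions.Beta(1.0, 1.0).sample().item()
perm = torch.randperm(feats.size(0))

mixed_feats = lam * feats + (1 - lam) * feats[perm]
mixed_labels = lam * one_hot + (1 - lam) * one_hot[perm]
# ----------------------------------------------------------------------------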
5 | from dgllife.model import AttentiveFPGNN as AttentiveFPGNN_DGL 6 | from dgllife.model import AttentiveFPReadout 7 | 8 | from ..builder import BACKBONES 9 | from ...core import move_to_device 10 | 11 | 12 | @BACKBONES.register_module() 13 | class AttentiveFPGNN(AttentiveFPGNN_DGL): 14 | def __init__(self, num_timesteps, get_node_weight=False, **kwargs): 15 | super(AttentiveFPGNN, self).__init__(**kwargs) 16 | self.get_node_weight = get_node_weight 17 | self.readout = AttentiveFPReadout( 18 | num_timesteps=num_timesteps, 19 | feat_size=kwargs.get("graph_feat_size"), 20 | dropout=kwargs.get("dropout")) 21 | 22 | def forward(self, input): 23 | input = move_to_device(input) 24 | node_feats = input.ndata["x"] 25 | edge_feats = input.edata["x"] 26 | node_feats = super().forward(input, node_feats, edge_feats) 27 | if self.get_node_weight: 28 | graph_feats, _ = self.readout(input, node_feats, self.get_node_weight) 29 | else: 30 | graph_feats = self.readout(input, node_feats, self.get_node_weight) 31 | return graph_feats 32 | -------------------------------------------------------------------------------- /drugood/models/backbones/base_backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from mmcv.runner import BaseModule 5 | 6 | 7 | class BaseBackbone(BaseModule, metaclass=ABCMeta): 8 | """Base backbone. 9 | 10 | This class defines the basic functions of a backbone. Any backbone that 11 | inherits this class should at least define its own `forward` function. 12 | """ 13 | 14 | def __init__(self, init_cfg=None): 15 | super(BaseBackbone, self).__init__(init_cfg) 16 | 17 | @abstractmethod 18 | def forward(self, x): 19 | """Forward computation. 20 | 21 | Args: 22 | x (tensor | tuple[tensor]): x could be a Torch.tensor or a tuple of 23 | Torch.tensor, containing input data for forward computation. 24 | """ 25 | pass 26 | 27 | def train(self, mode=True): 28 | """Set module status before forward computation. 29 | 30 | Args: 31 | mode (bool): Whether it is train_mode or test_mode 32 | """ 33 | super(BaseBackbone, self).train(mode) 34 | -------------------------------------------------------------------------------- /drugood/models/backbones/bert.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | import torch.nn as nn 3 | from transformers import BertModel 4 | 5 | from ..builder import BACKBONES 6 | from ...core import move_to_device 7 | 8 | 9 | @BACKBONES.register_module() 10 | class Bert(nn.Module): 11 | def __init__(self, model="data/berts/bert-base-uncased"): 12 | super(Bert, self).__init__() 13 | self.model = BertModel.from_pretrained(model) # gradient_checkpointing = True 14 | 15 | def forward(self, input): 16 | input = move_to_device(input) 17 | feats = self.model(**input) 18 | return feats["pooler_output"] 19 | -------------------------------------------------------------------------------- /drugood/models/backbones/mgcn.py: -------------------------------------------------------------------------------- 1 | # The model implementation is adopted from the dgllife library 2 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 3 | # The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). 4 | # All Tencent Modifications are Copyright (C) THL A29 Limited. 
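# ---- Editor's aside: an illustrative sketch, not part of the repository ----
# The `Bert` backbone above returns the pooled [CLS] representation of a
# tokenized sequence. A usage sketch; the checkpoint name is an assumption
# (the default points at a local path), and a CUDA device is required because
# `move_to_device` defaults to `torch.device('cuda')`:

from transformers import BertTokenizer

from drugood.models.backbones.bert import Bert

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
batch = tokenizer(['MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ'], return_tensors='pt')
model = Bert(model='bert-base-uncased').cuda().eval()
feats = model(batch)  # pooled features of shape (1, 768)
# ----------------------------------------------------------------------------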
5 | 
6 | import torch.nn as nn
7 | from dgllife.model import MGCNGNN, MLPNodeReadout
8 | 
9 | from ..builder import BACKBONES
10 | from ...core import move_to_device
11 | 
12 | 
13 | @BACKBONES.register_module()
14 | class MGCN(MGCNGNN):
15 |     def __init__(self,
16 |                  feats=128, n_layers=3, classifier_hidden_feats=64,
17 |                  n_tasks=1, num_node_types=100, num_edge_types=3000,
18 |                  cutoff=5.0, gap=1.0, predictor_hidden_feats=64):
19 | 
20 |         if predictor_hidden_feats == 64 and classifier_hidden_feats != 64:
21 |             print('classifier_hidden_feats is deprecated and will be removed in the future, '
22 |                   'use predictor_hidden_feats instead')
23 |             predictor_hidden_feats = classifier_hidden_feats
24 | 
25 |         # start the MRO at MGCN itself and do not pass `self` explicitly
26 |         super(MGCN, self).__init__(feats, n_layers, num_node_types,
27 |                                    num_edge_types, cutoff, gap)
28 | 
29 |         self.readout = MLPNodeReadout(node_feats=(n_layers + 1) * feats,
30 |                                       hidden_feats=predictor_hidden_feats,
31 |                                       graph_feats=n_tasks,
32 |                                       activation=nn.Softplus(beta=1, threshold=20))
33 | 
34 |     def forward(self, input):
35 |         input = move_to_device(input)
36 |         node_feats = input.ndata["x"]
37 |         node_feats = super().forward(g=input, feats=node_feats)
38 |         # MLPNodeReadout takes only the graph and the node features; the
39 |         # node-weight branch was copy-paste residue from AttentiveFPGNN
40 |         graph_feats = self.readout(input, node_feats)
41 |         return graph_feats
42 | 
-------------------------------------------------------------------------------- /drugood/models/builder.py: --------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
3 | # The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications").
4 | # All Tencent Modifications are Copyright (C) THL A29 Limited.
5 | 
6 | from mmcv.cnn import MODELS as MMCV_MODELS
7 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
8 | from mmcv.utils import Registry
9 | 
10 | MODELS = Registry('models', parent=MMCV_MODELS)
11 | 
12 | BACKBONES = MODELS
13 | NECKS = MODELS
14 | HEADS = MODELS
15 | LOSSES = MODELS
16 | CLASSIFIERS = MODELS
17 | 
18 | ATTENTION = Registry('attention', parent=MMCV_ATTENTION)
19 | TASKERS = Registry('taskers')
20 | 
21 | 
22 | def build_backbone(cfg):
23 |     """Build backbone."""
24 |     return BACKBONES.build(cfg)
25 | 
26 | 
27 | def build_neck(cfg):
28 |     """Build neck."""
29 |     return NECKS.build(cfg)
30 | 
31 | 
32 | def build_head(cfg):
33 |     """Build head."""
34 |     return HEADS.build(cfg)
35 | 
36 | 
37 | def build_losses(cfg):
38 |     """Build loss."""
39 |     if not isinstance(cfg, list):
40 |         cfg = [cfg]
41 |     return [LOSSES.build(_cfg) for _cfg in cfg]
42 | 
43 | 
44 | def build_tasker(cfg):
45 |     return TASKERS.build(cfg)
46 | 
47 | 
48 | def build_model(cfg):
49 |     """Build model."""
50 |     return MODELS.build(cfg)
51 | 
-------------------------------------------------------------------------------- /drugood/models/heads/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
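# ---- Editor's aside: an illustrative sketch, not part of the repository ----
# Every component in builder.py above is resolved through an mmcv `Registry`:
# a config dict names the class via its 'type' key, and the remaining keys
# become constructor arguments. A sketch of building a head from config (the
# channel sizes are arbitrary):

from drugood.models.builder import build_head

head_cfg = dict(
    type='LinearClsHead',   # looked up in the HEADS (= MODELS) registry
    num_classes=2,
    in_channels=128,
    loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
)
head = build_head(head_cfg)  # equivalent to HEADS.build(head_cfg)
# ----------------------------------------------------------------------------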
2 | # Classification 3 | from .cls_head import ClsHead 4 | # Linear Classification 5 | from .linear_head import LinearClsHead 6 | # Multi-label Classification 7 | from .multi_label_head import MultiLabelClsHead 8 | from .multi_label_linear_head import MultiLabelLinearClsHead 9 | # Regression and Linear Regression 10 | from .reg_head import RegHead, LinearRegHead 11 | 12 | __all__ = [ 13 | 'ClsHead', 'LinearClsHead', 'MultiLabelClsHead', 14 | 'MultiLabelLinearClsHead', 15 | 'RegHead', 'LinearRegHead', 16 | ] 17 | -------------------------------------------------------------------------------- /drugood/models/heads/base_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from mmcv.runner import BaseModule 5 | 6 | 7 | class BaseHead(BaseModule, metaclass=ABCMeta): 8 | """Base head.""" 9 | 10 | def __init__(self, init_cfg=None): 11 | super(BaseHead, self).__init__(init_cfg) 12 | 13 | @abstractmethod 14 | def forward_train(self, x, gt_label, **kwargs): 15 | pass 16 | -------------------------------------------------------------------------------- /drugood/models/heads/cls_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from drugood.models.losses import Accuracy 6 | from .base_head import BaseHead 7 | from ..builder import HEADS, build_losses 8 | from ..utils import is_tracing 9 | 10 | 11 | @HEADS.register_module() 12 | class ClsHead(BaseHead): 13 | """classification head. 14 | 15 | Args: 16 | loss (dict): Config of classification loss. 17 | topk (int | tuple): Top-k accuracy. 18 | cal_acc (bool): Whether to calculate accuracy during training. 19 | If you use Mixup/CutMix or something like that during training, 20 | it is not reasonable to calculate accuracy. Defaults to False. 
21 | """ 22 | 23 | def __init__(self, 24 | loss=None, 25 | topk=(1,), 26 | cal_acc=False, 27 | init_cfg=None): 28 | if (loss is None): 29 | loss = dict(type='CrossEntropyLoss', loss_weight=1.0) 30 | super(ClsHead, self).__init__(init_cfg=init_cfg) 31 | 32 | assert isinstance(loss, (dict, list)) 33 | assert isinstance(topk, (int, tuple)) 34 | if isinstance(topk, int): 35 | topk = (topk,) 36 | for _topk in topk: 37 | assert _topk > 0, 'Top-k should be larger than 0' 38 | self.topk = topk 39 | self.losses = build_losses(loss) 40 | self.compute_accuracy = Accuracy(topk=self.topk) 41 | self.cal_acc = cal_acc 42 | 43 | def loss(self, cls_score, gt_label): 44 | num_samples = len(cls_score) 45 | losses = dict() 46 | # compute loss 47 | for _loss in self.losses: 48 | name = _loss.__class__.__name__.replace("Loss", "_loss").lower() 49 | losses[name] = _loss(cls_score, gt_label, avg_factor=num_samples) 50 | if self.cal_acc: 51 | # compute accuracy 52 | acc = self.compute_accuracy(cls_score, gt_label) 53 | assert len(acc) == len(self.topk) 54 | losses['accuracy'] = { 55 | f'top-{k}': a 56 | for k, a in zip(self.topk, acc) 57 | } 58 | return losses 59 | 60 | def forward_test(self, cls_score): 61 | if isinstance(cls_score, tuple): 62 | cls_score = cls_score[-1] 63 | if isinstance(cls_score, list): 64 | cls_score = sum(cls_score) / float(len(cls_score)) 65 | pred = F.softmax(cls_score, dim=1) if cls_score is not None else None 66 | return pred 67 | 68 | def forward_train(self, cls_score, gt_label): 69 | if isinstance(cls_score, tuple): 70 | cls_score = cls_score[-1] 71 | losses = self.loss(cls_score, gt_label) 72 | return losses 73 | 74 | def simple_test(self, cls_score): 75 | if isinstance(cls_score, tuple): 76 | cls_score = cls_score[-1] 77 | if isinstance(cls_score, list): 78 | cls_score = sum(cls_score) / float(len(cls_score)) 79 | pred = F.softmax(cls_score, dim=1) if cls_score is not None else None 80 | return self.post_process(pred) 81 | 82 | def post_process(self, pred): 83 | on_trace = is_tracing() 84 | if torch.onnx.is_in_onnx_export() or on_trace: 85 | return pred 86 | pred = list(pred.detach().cpu().numpy()) 87 | return pred 88 | -------------------------------------------------------------------------------- /drugood/models/heads/linear_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .cls_head import ClsHead 6 | from ..builder import HEADS 7 | 8 | 9 | @HEADS.register_module() 10 | class LinearClsHead(ClsHead): 11 | """Linear classifier head. 12 | Args: 13 | num_classes (int): Number of categories excluding the background 14 | category. 15 | in_channels (int): Number of channels in the input feature map. 16 | init_cfg (dict | optional): The extra init config of layers. 17 | Defaults to use dict(type='Normal', layer='Linear', std=0.01). 
18 |     """
19 | 
20 |     def __init__(self,
21 |                  num_classes,
22 |                  in_channels,
23 |                  init_cfg=None,
24 |                  *args,
25 |                  **kwargs):
26 |         if (init_cfg is None):
27 |             init_cfg = dict(type='Normal', layer='Linear', std=0.01)
28 |         super(LinearClsHead, self).__init__(init_cfg=init_cfg, *args, **kwargs)
29 |         self.in_channels = in_channels
30 |         self.num_classes = num_classes
31 | 
32 |         if self.num_classes <= 0:
33 |             raise ValueError(
34 |                 f'num_classes={num_classes} must be a positive integer')
35 | 
36 |         self.fc = nn.Linear(self.in_channels, self.num_classes)
37 | 
38 |     def simple_test(self, x):
39 |         """Test without augmentation."""
40 |         if isinstance(x, tuple):
41 |             x = x[-1]
42 |         cls_score = self.fc(x)
43 |         if isinstance(cls_score, list):
44 |             cls_score = sum(cls_score) / float(len(cls_score))
45 |         return self.post_process(cls_score)
46 | 
47 |     def forward_train(self, x, gt_label):
48 |         if isinstance(x, tuple):
49 |             x = x[-1]
50 |         cls_score = self.fc(x)
51 |         losses = self.loss(cls_score, gt_label)
52 |         return losses
53 | 
54 |     def forward_test(self, x):
55 |         if isinstance(x, tuple):
56 |             x = x[-1]
57 |         logits = self.fc(x)
58 |         return logits
59 | 
60 |     def post_process(self, logits):
61 |         pred = F.softmax(logits, dim=1)
62 |         pred = list(pred.detach().cpu().numpy())
63 |         return pred
64 | 
-------------------------------------------------------------------------------- /drugood/models/heads/multi_label_head.py: --------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn.functional as F
4 | 
5 | from .base_head import BaseHead
6 | from ..builder import HEADS, build_losses
7 | from ..utils import is_tracing
8 | 
9 | 
10 | @HEADS.register_module()
11 | class MultiLabelClsHead(BaseHead):
12 |     """Classification head for multilabel tasks.
13 | 
14 |     Args:
15 |         loss (dict): Config of classification loss.
16 | """ 17 | 18 | def __init__(self, 19 | loss=None, 20 | init_cfg=None): 21 | if (loss is None): 22 | loss = dict( 23 | type='CrossEntropyLoss', 24 | use_sigmoid=True, 25 | reduction='mean', 26 | loss_weight=1.0) 27 | super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg) 28 | 29 | assert isinstance(loss, dict) 30 | 31 | self.losses = build_losses(loss) 32 | 33 | def loss(self, cls_score, gt_label): 34 | gt_label = gt_label.type_as(cls_score) 35 | num_samples = len(cls_score) 36 | losses = dict() 37 | 38 | # map difficult examples to positive ones 39 | _gt_label = torch.abs(gt_label) 40 | # compute loss 41 | for _loss in self.losses: 42 | losses[_loss.__class__.__name__] = _loss(cls_score, gt_label, avg_factor=num_samples) 43 | return losses 44 | 45 | def forward_train(self, cls_score, gt_label): 46 | if isinstance(cls_score, tuple): 47 | cls_score = cls_score[-1] 48 | gt_label = gt_label.type_as(cls_score) 49 | losses = self.loss(cls_score, gt_label) 50 | return losses 51 | 52 | def simple_test(self, x): 53 | if isinstance(x, tuple): 54 | x = x[-1] 55 | if isinstance(x, list): 56 | x = sum(x) / float(len(x)) 57 | pred = F.sigmoid(x) if x is not None else None 58 | 59 | return self.post_process(pred) 60 | 61 | def post_process(self, pred): 62 | pred = F.sigmoid(pred) 63 | on_trace = is_tracing() 64 | if torch.onnx.is_in_onnx_export() or on_trace: 65 | return pred 66 | pred = list(pred.detach().cpu().numpy()) 67 | return pred 68 | -------------------------------------------------------------------------------- /drugood/models/heads/multi_label_linear_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .multi_label_head import MultiLabelClsHead 6 | from ..builder import HEADS 7 | 8 | 9 | @HEADS.register_module() 10 | class MultiLabelLinearClsHead(MultiLabelClsHead): 11 | """Linear classification head for multilabel tasks. 12 | 13 | Args: 14 | num_classes (int): Number of categories. 15 | in_channels (int): Number of channels in the input feature map. 16 | loss (dict): Config of classification loss. 17 | init_cfg (dict | optional): The extra init config of layers. 18 | Defaults to use dict(type='Normal', layer='Linear', std=0.01). 
19 | """ 20 | 21 | def __init__(self, 22 | num_classes, 23 | in_channels, 24 | loss=None, 25 | init_cfg=None): 26 | if (loss is None): 27 | loss = dict( 28 | type='CrossEntropyLoss', 29 | use_sigmoid=True, 30 | reduction='mean', 31 | loss_weight=1.0) 32 | if (init_cfg is None): 33 | init_cfg = dict(type='Normal', layer='Linear', std=0.01) 34 | super(MultiLabelLinearClsHead, self).__init__( 35 | loss=loss, init_cfg=init_cfg) 36 | 37 | if num_classes <= 0: 38 | raise ValueError( 39 | f'num_classes={num_classes} must be a positive integer') 40 | 41 | self.in_channels = in_channels 42 | self.num_classes = num_classes 43 | 44 | self.fc = nn.Linear(self.in_channels, self.num_classes) 45 | 46 | def forward_train(self, x, gt_label): 47 | if isinstance(x, tuple): 48 | x = x[-1] 49 | gt_label = gt_label.type_as(x) 50 | cls_score = self.fc(x) 51 | losses = self.loss(cls_score, gt_label) 52 | return losses 53 | 54 | def simple_test(self, x): 55 | """Test without augmentation.""" 56 | if isinstance(x, tuple): 57 | x = x[-1] 58 | cls_score = self.fc(x) 59 | if isinstance(cls_score, list): 60 | cls_score = sum(cls_score) / float(len(cls_score)) 61 | return self.post_process(cls_score) 62 | 63 | def forward_test(self, x): 64 | if isinstance(x, tuple): 65 | x = x[-1] 66 | logits = self.fc(x) 67 | return logits 68 | 69 | def post_process(self, pred): 70 | pred = F.sigmoid(pred) 71 | pred = list(pred.detach().cpu().numpy()) 72 | return pred 73 | -------------------------------------------------------------------------------- /drugood/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) OpenMMLab. All rights reserved. 3 | """ 4 | from .accuracy import Accuracy, accuracy 5 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, cross_entropy) 6 | from .error import Error, error 7 | from .focal_loss import FocalLoss, sigmoid_focal_loss 8 | from .label_smooth_loss import LabelSmoothLoss 9 | from .mean_squared_error_loss import MeanSquaredLoss 10 | from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss, weighted_loss) 11 | 12 | __all__ = [ 13 | 'accuracy', 'Accuracy', 14 | 'error', 'Error', 15 | 'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 16 | 'weight_reduce_loss', 'LabelSmoothLoss', 'weighted_loss', 'FocalLoss', 17 | 'sigmoid_focal_loss', 'convert_to_one_hot', 'MeanSquaredLoss' 18 | ] 19 | -------------------------------------------------------------------------------- /drugood/models/losses/error.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | 
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | 
7 | 
8 | def error_numpy(pred, target, metric="mae"):
9 |     if metric == "mae":
10 |         err = np.abs(pred - target).mean()
11 |     elif metric == "mse":
12 |         err = np.square(pred - target).mean()
13 |     else:
14 |         raise ValueError(f"metric should be 'mae' or 'mse', but got {metric}")
15 |     return err
16 | 
17 | 
18 | def error_torch(pred, target, metric="mae"):
19 |     if metric == "mae":
20 |         err = torch.abs(pred - target).mean().data
21 |     elif metric == "mse":
22 |         err = torch.square(pred - target).mean().data
23 |     else:
24 |         raise ValueError(f"metric should be 'mae' or 'mse', but got {metric}")
25 |     return err
26 | 
27 | 
28 | def error(pred, target, metric="mae"):
29 |     if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor):
30 |         res = error_torch(pred, target, metric)
31 |     elif isinstance(pred, np.ndarray) and isinstance(target, np.ndarray):
32 |         res = error_numpy(pred, target, metric)
33 |     else:
34 |         raise TypeError(
35 |             f'pred and target should both be torch.Tensor or np.ndarray, '
36 |             f'but got {type(pred)} and {type(target)}.')
37 | 
38 |     return res
39 | 
40 | 
41 | class Error(nn.Module):
42 |     def __init__(self, metric="mae"):
43 |         """Module to calculate the error.
44 | 
45 |         Args:
46 |             metric (str): The criterion used to calculate the
47 |                 error, one of "mae" or "mse". Defaults to "mae".
48 |         """
49 |         super().__init__()
50 |         self.metric = metric
51 | 
52 |     def forward(self, pred, target):
53 |         """Forward function to calculate error.
54 | 
55 |         Args:
56 |             pred (torch.Tensor): Prediction of models.
57 |             target (torch.Tensor): Target for each prediction.
58 | 
59 |         Returns:
60 |             torch.Tensor: The error.
61 |         """
62 |         return error(pred, target, self.metric)
63 | 
--------------------------------------------------------------------------------
/drugood/models/losses/mean_squared_error_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
3 | # The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications").
4 | # All Tencent Modifications are Copyright (C) THL A29 Limited.
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | 
9 | from .utils import weight_reduce_loss
10 | from ..builder import LOSSES
11 | 
12 | 
13 | def mean_squared_error(pred,
14 |                        label,
15 |                        weight=None,
16 |                        reduction='mean',
17 |                        avg_factor=None,
18 |                        class_weight=None,
19 |                        **kwargs):
20 |     """Calculate the MeanSquared loss.
21 | 
22 |     Args:
23 |         pred (torch.Tensor): The prediction with shape (N, 1).
24 |         label (torch.Tensor): The gt label of the prediction (N, 1).
25 |         weight (torch.Tensor, optional): Sample-wise loss weight.
26 |         reduction (str): The method used to reduce the loss.
27 |         avg_factor (int, optional): Average factor that is used to average
28 |             the loss. Defaults to None.
29 |         class_weight (torch.Tensor, optional): The weight for each class with
30 |             shape (C), C is the number of classes. Default None.
31 |             TODO: class_weight could be used to handle long-tailed targets in regression.
32 |     Returns:
33 |         torch.Tensor: The calculated loss
34 |     """
35 |     # element-wise losses (squeeze(-1) keeps a batch of size 1 intact)
36 |     if pred.dim() == 2:
37 |         pred = pred.squeeze(-1)
38 |     if label.dim() == 2:
39 |         label = label.squeeze(-1)
40 |     loss = F.mse_loss(pred, label, reduction='none')
41 |     # apply weights and do the reduction
42 |     if weight is not None:
43 |         weight = weight.float()
44 |     loss = weight_reduce_loss(
45 |         loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
46 | 
47 |     return loss
48 | 
49 | 
50 | @LOSSES.register_module()
51 | class MeanSquaredLoss(nn.Module):
52 |     def __init__(self,
53 |                  reduction='mean',
54 |                  loss_weight=1.0):
55 |         super(MeanSquaredLoss, self).__init__()
56 |         self.reduction = reduction
57 |         self.loss_weight = loss_weight
58 |         self.cls_criterion = mean_squared_error
59 | 
60 |     def forward(self,
61 |                 cls_score,
62 |                 label,
63 |                 weight=None,
64 |                 avg_factor=None,
65 |                 reduction_override=None,
66 |                 **kwargs):
67 |         assert reduction_override in (None, 'none', 'mean', 'sum')
68 |         reduction = (
69 |             reduction_override if reduction_override else self.reduction)
70 | 
71 |         cls_score = cls_score.to(torch.float32)
72 |         label = label.to(torch.float32)
73 | 
74 |         loss_cls = self.loss_weight * self.cls_criterion(
75 |             cls_score,
76 |             label,
77 |             weight,
78 |             reduction=reduction,
79 |             avg_factor=avg_factor,
80 |             **kwargs)
81 |         return loss_cls
82 | 
--------------------------------------------------------------------------------
/drugood/models/necks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .cat import Concatenate
3 | from .gap import GlobalAveragePooling
4 | 
5 | __all__ = ['GlobalAveragePooling', "Concatenate"]
6 | 
--------------------------------------------------------------------------------
/drugood/models/necks/cat.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | 
5 | from ..builder import NECKS
6 | 
7 | 
8 | @NECKS.register_module()
9 | class Concatenate(nn.Module):
10 |     def __init__(self, dim=1):
11 |         super(Concatenate, self).__init__()
12 |         self.dim = dim
13 | 
14 |     def init_weights(self):
15 |         pass
16 | 
17 |     def forward(self, inputs):
18 |         assert isinstance(inputs, list)
19 |         return torch.cat(inputs, dim=self.dim)
20 | 
--------------------------------------------------------------------------------
/drugood/models/necks/gap.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | 
5 | from ..builder import NECKS
6 | 
7 | 
8 | @NECKS.register_module()
9 | class GlobalAveragePooling(nn.Module):
10 |     """Global Average Pooling neck.
11 | 
12 |     Note that we use `view` to remove extra channel after pooling. We do not
13 |     use `squeeze` as it will also remove the batch dimension when the tensor
14 |     has a batch dimension of size 1, which can lead to unexpected errors.
15 | 
16 |     Args:
17 |         dim (int): Dimensions of each sample channel, can be one of {1, 2, 3}.
18 |             Default: 2
19 |     """
20 | 
21 |     def __init__(self, dim=2):
22 |         super(GlobalAveragePooling, self).__init__()
23 |         assert dim in [1, 2, 3], 'GlobalAveragePooling dim only supports ' \
24 |                                  f'{1, 2, 3}, got {dim} instead.'
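        # Pick the adaptive pooling op that matches the configured sample dimensionality.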
25 | if dim == 1: 26 | self.gap = nn.AdaptiveAvgPool1d(1) 27 | elif dim == 2: 28 | self.gap = nn.AdaptiveAvgPool2d((1, 1)) 29 | else: 30 | self.gap = nn.AdaptiveAvgPool3d((1, 1, 1)) 31 | 32 | def init_weights(self): 33 | pass 34 | 35 | def forward(self, inputs): 36 | if isinstance(inputs, tuple): 37 | outs = tuple([self.gap(x) for x in inputs]) 38 | outs = tuple( 39 | [out.view(x.size(0), -1) for out, x in zip(outs, inputs)]) 40 | elif isinstance(inputs, torch.Tensor): 41 | outs = self.gap(inputs) 42 | outs = outs.view(inputs.size(0), -1) 43 | else: 44 | raise TypeError('neck inputs should be tuple or torch.tensor') 45 | return outs 46 | -------------------------------------------------------------------------------- /drugood/models/taskers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 3 | # The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). 4 | # All Tencent Modifications are Copyright (C) THL A29 Limited. 5 | from .base import BaseClassifier 6 | from .classifier import Classifier 7 | from .regressor import Regressor, MIRegressor 8 | 9 | __all__ = ['BaseClassifier', 'Classifier', 10 | 'Regressor', 'MIRegressor'] 11 | -------------------------------------------------------------------------------- /drugood/models/taskers/classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 2 | import warnings 3 | from abc import ABCMeta 4 | 5 | from mmcv.runner import BaseModule 6 | 7 | from ..builder import TASKERS, build_backbone, build_head, build_neck 8 | 9 | 10 | @TASKERS.register_module() 11 | class Classifier(BaseModule, metaclass=ABCMeta): 12 | def __init__(self, 13 | backbone, 14 | aux_backbone=None, 15 | neck=None, 16 | head=None, 17 | pretrained=None, 18 | train_cfg=None, 19 | init_cfg=None): 20 | super(Classifier, self).__init__(init_cfg) 21 | 22 | if pretrained is not None: 23 | warnings.warn('DeprecationWarning: pretrained is a deprecated \ 24 | key, please consider using init_cfg') 25 | self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) 26 | 27 | self.backbone = build_backbone(backbone) 28 | 29 | if aux_backbone is not None: 30 | self.aux_backbone = build_backbone(aux_backbone) 31 | 32 | if neck is not None: 33 | self.neck = build_neck(neck) 34 | 35 | if head is not None: 36 | self.head = build_head(head) 37 | 38 | @property 39 | def with_neck(self): 40 | return hasattr(self, 'neck') and self.neck is not None 41 | 42 | @property 43 | def with_head(self): 44 | return hasattr(self, 'head') and self.head is not None 45 | 46 | @property 47 | def with_aux_backbone(self): 48 | return hasattr(self, 'aux_backbone') and self.aux_backbone is not None 49 | 50 | def extract_feat(self, input, aux_input=None, **kwargs): 51 | feats = self.backbone(input) 52 | 53 | if self.with_aux_backbone and aux_input is not None: 54 | feats = [feats, self.aux_backbone(aux_input)] 55 | 56 | if self.with_neck: 57 | feats = self.neck(feats) 58 | 59 | return feats 60 | -------------------------------------------------------------------------------- /drugood/models/taskers/regressor.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 
2 | import warnings
3 | from abc import ABCMeta
4 | 
5 | import torch
6 | from mmcv.runner import BaseModule
7 | 
8 | from ..builder import TASKERS, build_backbone, build_head, build_neck
9 | 
10 | 
11 | @TASKERS.register_module()
12 | class Regressor(BaseModule, metaclass=ABCMeta):
13 |     def __init__(self,
14 |                  backbone,
15 |                  neck=None,
16 |                  head=None,
17 |                  pretrained=None,
18 |                  init_cfg=None):
19 |         super(Regressor, self).__init__(init_cfg)
20 | 
21 |         if pretrained is not None:
22 |             warnings.warn('DeprecationWarning: pretrained is a deprecated \
23 |                 key, please consider using init_cfg')
24 |             self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
25 | 
26 |         self.backbone = build_backbone(backbone)
27 | 
28 |         if neck is not None:
29 |             self.neck = build_neck(neck)
30 | 
31 |         if head is not None:
32 |             self.head = build_head(head)
33 | 
34 |     @property
35 |     def with_neck(self):
36 |         return hasattr(self, 'neck') and self.neck is not None
37 | 
38 |     @property
39 |     def with_head(self):
40 |         return hasattr(self, 'head') and self.head is not None
41 | 
42 |     def extract_feat(self, input):
43 |         x = self.backbone(input)
44 |         if self.with_neck:
45 |             x = self.neck(x)
46 |         return x
47 | 
48 | 
49 | @TASKERS.register_module()
50 | class MIRegressor(Regressor):
51 |     def __init__(self,
52 |                  aux_backbone,
53 |                  aux_neck=None,
54 |                  **kwargs
55 |                  ):
56 |         super(MIRegressor, self).__init__(**kwargs)
57 |         self.aux_backbone = build_backbone(aux_backbone)
58 |         if aux_neck is not None:
59 |             self.aux_neck = build_neck(aux_neck)
60 | 
61 |     @property
62 |     def with_aux_neck(self):
63 |         return hasattr(self, 'aux_neck') and self.aux_neck is not None
64 | 
65 |     def extract_feat(self, input, aux_input, **kwargs):
66 |         feats = super().extract_feat(input)
67 |         aux_feats = self.aux_backbone(aux_input)
68 |         if self.with_aux_neck:  # run the auxiliary neck before fusing features
69 |             aux_feats = self.aux_neck(aux_feats)
70 |         feats = torch.cat([feats, aux_feats], dim=1)
71 |         return feats
72 | 
--------------------------------------------------------------------------------
/drugood/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .augment.augments import Augments
3 | from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple
4 | 
5 | __all__ = [
6 |     'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple',
7 |     'Augments', 'is_tracing'
8 | ]
9 | 
--------------------------------------------------------------------------------
/drugood/models/utils/augment/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .augments import Augments
3 | from .cutmix import BatchCutMixLayer
4 | from .identity import Identity
5 | from .mixup import BatchMixupLayer
6 | 
7 | __all__ = ['Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer']
8 | 
--------------------------------------------------------------------------------
/drugood/models/utils/augment/augments.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import random
3 | 
4 | import numpy as np
5 | 
6 | from .builder import build_augment
7 | 
8 | 
9 | class Augments(object):
10 |     """Data augments.
11 | 
12 |     We implement some data augmentation methods, such as mixup, cutmix.
13 | 
14 |     Args:
15 |         augments_cfg (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict`):
16 |             Config dict of augments
17 | 
18 |     Example:
19 |         >>> augments_cfg = [
20 |         ...     dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
21 |         ...     dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3)
22 |         ... ]
23 |         >>> augments = Augments(augments_cfg)
24 |         >>> imgs = torch.randn(16, 3, 32, 32)
25 |         >>> label = torch.randint(0, 10, (16, ))
26 |         >>> imgs, label = augments(imgs, label)
27 | 
28 |     The augmentation to apply is sampled according to the configured
29 |     probabilities: in the example above, BatchCutMix is used with
30 |     probability 0.5 and BatchMixup with probability 0.3. Since Identity
31 |     is not listed in augments_cfg, an Identity augment is appended
32 |     automatically and applied with the remaining probability
33 |     1 - 0.5 - 0.3 = 0.2.
34 |     """
35 | 
36 |     def __init__(self, augments_cfg):
37 |         super(Augments, self).__init__()
38 | 
39 |         if isinstance(augments_cfg, dict):
40 |             augments_cfg = [augments_cfg]
41 | 
42 |         assert len(augments_cfg) > 0, \
43 |             'The length of augments_cfg should be positive.'
44 |         self.augments = [build_augment(cfg) for cfg in augments_cfg]
45 |         self.augment_probs = [aug.prob for aug in self.augments]
46 | 
47 |         has_identity = any([cfg['type'] == 'Identity' for cfg in augments_cfg])
48 |         if has_identity:
49 |             assert sum(self.augment_probs) == 1.0, \
50 |                 'The sum of augmentation probabilities should equal 1,' \
51 |                 ' but got {:.2f}'.format(sum(self.augment_probs))
52 |         else:
53 |             assert sum(self.augment_probs) <= 1.0, \
54 |                 'The sum of augmentation probabilities should be less than or ' \
55 |                 'equal to 1, but got {:.2f}'.format(sum(self.augment_probs))
56 |             identity_prob = 1 - sum(self.augment_probs)
57 |             if identity_prob > 0:
58 |                 num_classes = self.augments[0].num_classes
59 |                 self.augments += [
60 |                     build_augment(
61 |                         dict(
62 |                             type='Identity',
63 |                             num_classes=num_classes,
64 |                             prob=identity_prob))
65 |                 ]
66 |                 self.augment_probs += [identity_prob]
67 | 
68 |     def __call__(self, img, gt_label):
69 |         if self.augments:
70 |             random_state = np.random.RandomState(random.randint(0, 2 ** 32 - 1))
71 |             aug = random_state.choice(self.augments, p=self.augment_probs)
72 |             return aug(img, gt_label)
73 |         return img, gt_label
74 | 
--------------------------------------------------------------------------------
/drugood/models/utils/augment/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.utils import Registry, build_from_cfg
3 | 
4 | AUGMENT = Registry('augment')
5 | 
6 | 
7 | def build_augment(cfg, default_args=None):
8 |     return build_from_cfg(cfg, AUGMENT, default_args)
9 | 
--------------------------------------------------------------------------------
/drugood/models/utils/augment/identity.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn.functional as F
3 | 
4 | from .builder import AUGMENT
5 | 
6 | 
7 | @AUGMENT.register_module(name='Identity')
8 | class Identity(object):
9 |     """Change gt_label to one_hot encoding and keep img as the same.
10 | 
11 |     Args:
12 |         num_classes (int): The number of classes.
13 |         prob (float): The probability to apply Identity. It should be in range [0, 1].
14 | Default to 1.0 15 | """ 16 | 17 | def __init__(self, num_classes, prob=1.0): 18 | super(Identity, self).__init__() 19 | 20 | assert isinstance(num_classes, int) 21 | assert isinstance(prob, float) and 0.0 <= prob <= 1.0 22 | 23 | self.num_classes = num_classes 24 | self.prob = prob 25 | 26 | def one_hot(self, gt_label): 27 | return F.one_hot(gt_label, num_classes=self.num_classes) 28 | 29 | def __call__(self, img, gt_label): 30 | return img, self.one_hot(gt_label) 31 | -------------------------------------------------------------------------------- /drugood/models/utils/augment/mixup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from .builder import AUGMENT 9 | 10 | 11 | class BaseMixupLayer(object, metaclass=ABCMeta): 12 | """Base class for MixupLayer. 13 | 14 | Args: 15 | alpha (float): Parameters for Beta distribution. 16 | num_classes (int): The number of classes. 17 | prob (float): MixUp probability. It should be in range [0, 1]. 18 | Default to 1.0 19 | """ 20 | 21 | def __init__(self, alpha, num_classes, prob=1.0): 22 | super(BaseMixupLayer, self).__init__() 23 | 24 | assert isinstance(alpha, float) and alpha > 0 25 | assert isinstance(num_classes, int) 26 | assert isinstance(prob, float) and 0.0 <= prob <= 1.0 27 | 28 | self.alpha = alpha 29 | self.num_classes = num_classes 30 | self.prob = prob 31 | 32 | @abstractmethod 33 | def mixup(self, imgs, gt_label): 34 | pass 35 | 36 | 37 | @AUGMENT.register_module(name='BatchMixup') 38 | class BatchMixupLayer(BaseMixupLayer): 39 | """Mixup layer for batch mixup.""" 40 | 41 | def __init__(self, *args, **kwargs): 42 | super(BatchMixupLayer, self).__init__(*args, **kwargs) 43 | 44 | def mixup(self, img, gt_label): 45 | one_hot_gt_label = F.one_hot(gt_label, num_classes=self.num_classes) 46 | lam = np.random.beta(self.alpha, self.alpha) 47 | batch_size = img.size(0) 48 | index = torch.randperm(batch_size) 49 | 50 | mixed_img = lam * img + (1 - lam) * img[index, :] 51 | mixed_gt_label = lam * one_hot_gt_label + ( 52 | 1 - lam) * one_hot_gt_label[index, :] 53 | 54 | return mixed_img, mixed_gt_label 55 | 56 | def __call__(self, img, gt_label): 57 | return self.mixup(img, gt_label) 58 | -------------------------------------------------------------------------------- /drugood/models/utils/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import collections.abc 3 | import warnings 4 | from distutils.version import LooseVersion 5 | from itertools import repeat 6 | 7 | import torch 8 | 9 | 10 | def is_tracing() -> bool: 11 | if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'): 12 | on_trace = torch.jit.is_tracing() 13 | # In PyTorch 1.6, torch.jit.is_tracing has a bug. 14 | # Refers to https://github.com/pytorch/pytorch/issues/42448 15 | if isinstance(on_trace, bool): 16 | return on_trace 17 | else: 18 | return torch._C._is_tracing() 19 | else: 20 | warnings.warn( 21 | 'torch.jit.is_tracing is only supported after v1.6.0. ' 22 | 'Therefore is_tracing returns False automatically. 
Please '
23 |             'set on_trace manually if you are using trace.', UserWarning)
24 |         return False
25 | 
26 | 
27 | # From PyTorch internals
28 | def _ntuple(n):
29 |     def parse(x):
30 |         if isinstance(x, collections.abc.Iterable):
31 |             return x
32 |         return tuple(repeat(x, n))
33 | 
34 |     return parse
35 | 
36 | 
37 | to_1tuple = _ntuple(1)
38 | to_2tuple = _ntuple(2)
39 | to_3tuple = _ntuple(3)
40 | to_4tuple = _ntuple(4)
41 | to_ntuple = _ntuple
42 | 
--------------------------------------------------------------------------------
/drugood/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .collect_env import collect_env
3 | from .logger import get_root_logger
4 | from .smile_to_dgl import smile2graph
5 | 
6 | __all__ = ['collect_env', 'get_root_logger', "smile2graph"]
7 | 
--------------------------------------------------------------------------------
/drugood/utils/collect_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
3 | # All Tencent Modifications are Copyright (C) THL A29 Limited.
4 | 
5 | import dgl
6 | import dgllife
7 | import rdkit
8 | import torch_geometric
9 | import torch_scatter
10 | import torch_sparse
11 | from mmcv.utils import collect_env as collect_base_env
12 | from mmcv.utils import get_git_hash
13 | 
14 | import drugood
15 | 
16 | 
17 | def collect_drugood_env():
18 |     """Collect the information of the environment related to graph and drug data.
19 | 
20 |     Returns:
21 |         dict: The environment information. The following fields are contained.
22 |         - DGL: DGL version
23 |         - DGL Life: DGL Life version
24 |         - Rdkit: Rdkit version
25 |         - Torch Geometric: Torch Geometric version
26 |         - Torch Sparse: Torch Sparse version
27 |         - Torch Scatter: Torch Scatter version
28 |     """
29 |     env_info = {}
30 |     env_info['DGL'] = dgl.__version__
31 |     env_info['DGL Life'] = dgllife.__version__
32 |     env_info['Rdkit'] = rdkit.__version__
33 |     env_info['Torch Geometric'] = torch_geometric.__version__
34 |     env_info['Torch Sparse'] = torch_sparse.__version__
35 |     env_info['Torch Scatter'] = torch_scatter.__version__
36 |     return env_info
37 | 
38 | 
39 | def collect_env():
40 |     """Collect the information of the running environments."""
41 |     env_info = collect_base_env()
42 |     env_info.update(**collect_drugood_env())
43 |     env_info['DrugOOD'] = drugood.__version__ + '+' + get_git_hash()[:7]
44 |     return env_info
45 | 
46 | 
47 | if __name__ == '__main__':
48 |     for name, val in collect_env().items():
49 |         print(f'{name}: {val}')
50 | 
--------------------------------------------------------------------------------
/drugood/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | # Copyright (C) 2021 THL A29 Limited, a Tencent company.
3 | # All rights reserved. All Tencent Modifications are Copyright (C) THL A29 Limited.
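# Usage sketch: the logger name is fixed to 'drugood', so repeated calls return
# the same configured logger, e.g.
#   logger = get_root_logger(log_file='work_dir/run.log', log_level=logging.INFO)
# ('work_dir/run.log' is an illustrative path, not a project default).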
4 | import logging 5 | 6 | from mmcv.utils import get_logger 7 | 8 | 9 | def get_root_logger(log_file=None, log_level=logging.INFO): 10 | return get_logger('drugood', log_file, log_level) 11 | -------------------------------------------------------------------------------- /drugood/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved 2 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. 3 | # All rights reserved. All Tencent Modifications are Copyright (C) THL A29 Limited. 4 | __version__ = '0.0.1' 5 | 6 | 7 | def parse_version_info(version_str): 8 | """Parse a version string into a tuple. 9 | 10 | Args: 11 | version_str (str): The version string. 12 | Returns: 13 | tuple[int | str]: The version info, e.g., "1.3.0" is parsed into 14 | (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). 15 | """ 16 | version_info = [] 17 | for x in version_str.split('.'): 18 | if x.isdigit(): 19 | version_info.append(int(x)) 20 | elif x.find('rc') != -1: 21 | patch_version = x.split('rc') 22 | version_info.append(int(patch_version[0])) 23 | version_info.append(f'rc{patch_version[1]}') 24 | return tuple(version_info) 25 | 26 | 27 | version_info = parse_version_info(__version__) 28 | 29 | __all__ = ['__version__', 'version_info', 'parse_version_info'] 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/optional.txt 2 | -r requirements/runtime.txt 3 | -r requirements/tests.txt 4 | -------------------------------------------------------------------------------- /requirements/docs.txt: -------------------------------------------------------------------------------- 1 | docutils==0.16.0 2 | recommonmark 3 | sphinx==4.0.2 4 | sphinx_markdown_tables 5 | sphinx_rtd_theme==0.5.2 6 | -------------------------------------------------------------------------------- /requirements/mminstall.txt: -------------------------------------------------------------------------------- 1 | mmcv-full>=1.3.8,<=1.5.0 2 | -------------------------------------------------------------------------------- /requirements/optional.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/DrugOOD/61a70e4ad1fb227e4a264ed5ba87d4c78fdb4ae7/requirements/optional.txt -------------------------------------------------------------------------------- /requirements/readthedocs.txt: -------------------------------------------------------------------------------- 1 | mmcv>=1.3.8 2 | torch 3 | torchvision 4 | -------------------------------------------------------------------------------- /requirements/runtime.txt: -------------------------------------------------------------------------------- 1 | mmcv 2 | Cython 3 | matplotlib 4 | numpy 5 | wilds 6 | dgllife 7 | PyTDC 8 | torch_geometric 9 | torch_sparse 10 | transformers 11 | rdkit-pypi 12 | dgl-cu110 -------------------------------------------------------------------------------- /requirements/tests.txt: -------------------------------------------------------------------------------- 1 | codecov 2 | flake8 3 | interrogate 4 | isort==4.3.21 5 | pytest 6 | xdoctest >= 0.10.0 7 | yapf 8 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 
2 | universal = 1
3 | 
4 | [aliases]
5 | test = pytest
6 | 
7 | [yapf]
8 | based_on_style = pep8
9 | blank_line_before_nested_class_or_def = true
10 | split_before_expression_after_opening_paren = true
11 | 
12 | [isort]
13 | line_length = 79
14 | multi_line_output = 0
15 | known_standard_library = pkg_resources,setuptools
16 | known_first_party = mmcls
17 | known_third_party = PIL,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts
18 | no_lines_before = STDLIB,LOCALFOLDER
19 | default_section = THIRDPARTY
20 | 
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tencent-ailab/DrugOOD/61a70e4ad1fb227e4a264ed5ba87d4c78fdb4ae7/tools/__init__.py
--------------------------------------------------------------------------------
/tools/analysis_tools/eval_metric.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | 
4 | import mmcv
5 | from mmcv import Config, DictAction
6 | 
7 | from drugood.datasets import build_dataset
8 | 
9 | 
10 | def parse_args():
11 |     parser = argparse.ArgumentParser(description='Evaluate metric of the '
12 |                                      'results saved in pkl format')
13 |     parser.add_argument('config', help='Config of the model')
14 |     parser.add_argument('pkl_results', help='Results in pickle format')
15 |     parser.add_argument(
16 |         '--metrics',
17 |         type=str,
18 |         nargs='+',
19 |         help='Evaluation metrics, which depends on the dataset, e.g., '
20 |         '"accuracy", "precision", "recall" and "support".')
21 |     parser.add_argument(
22 |         '--cfg-options',
23 |         nargs='+',
24 |         action=DictAction,
25 |         help='override some settings in the used config, the key-value pair '
26 |         'in xxx=yyy format will be merged into config file. If the value to '
27 |         'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
28 |         'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
29 |         'Note that the quotation marks are necessary and that no white space '
30 |         'is allowed.')
31 |     parser.add_argument(
32 |         '--eval-options',
33 |         nargs='+',
34 |         action=DictAction,
35 |         help='custom options for evaluation, the key-value pair in xxx=yyy '
36 |         'format will be kwargs for dataset.evaluate() function')
37 |     args = parser.parse_args()
38 |     return args
39 | 
40 | 
41 | def main():
42 |     args = parse_args()
43 | 
44 |     cfg = Config.fromfile(args.config)
45 |     assert args.metrics, (
46 |         'Please specify at least one metric with the argument "--metrics".')
47 | 
48 |     if args.cfg_options is not None:
49 |         cfg.merge_from_dict(args.cfg_options)
50 |     # import modules from string list.
51 |     if cfg.get('custom_imports', None):
52 |         from mmcv.utils import import_modules_from_strings
53 |         import_modules_from_strings(**cfg['custom_imports'])
54 |     cfg.data.test.test_mode = True
55 | 
56 |     dataset = build_dataset(cfg.data.test)
57 |     outputs = mmcv.load(args.pkl_results)
58 |     pred_score = outputs['class_scores']
59 | 
60 |     kwargs = {} if args.eval_options is None else args.eval_options
61 |     eval_kwargs = cfg.get('evaluation', {}).copy()
62 |     # hard-code way to remove EvalHook args
63 |     for key in [
64 |             'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 'rule'
65 |     ]:
66 |         eval_kwargs.pop(key, None)
67 |     eval_kwargs.update(dict(metric=args.metrics, **kwargs))
68 |     print(dataset.evaluate(pred_score, **eval_kwargs))
69 | 
70 | 
71 | if __name__ == '__main__':
72 |     main()
73 | 
--------------------------------------------------------------------------------
/tools/analysis_tools/get_flops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | 
4 | from mmcv import Config
5 | from mmcv.cnn.utils import get_model_complexity_info
6 | 
7 | from drugood.models import build_classifier
8 | 
9 | 
10 | def parse_args():
11 |     parser = argparse.ArgumentParser(description='Get model flops and params')
12 |     parser.add_argument('config', help='config file path')
13 |     parser.add_argument(
14 |         '--shape',
15 |         type=int,
16 |         nargs='+',
17 |         default=[224, 224],
18 |         help='input image size')
19 |     args = parser.parse_args()
20 |     return args
21 | 
22 | 
23 | def main():
24 |     args = parse_args()
25 | 
26 |     if len(args.shape) == 1:
27 |         input_shape = (3, args.shape[0], args.shape[0])
28 |     elif len(args.shape) == 2:
29 |         input_shape = (3,) + tuple(args.shape)
30 |     else:
31 |         raise ValueError('invalid input shape')
32 | 
33 |     cfg = Config.fromfile(args.config)
34 |     model = build_classifier(cfg.model)
35 |     model.eval()
36 | 
37 |     if hasattr(model, 'extract_feat'):
38 |         model.forward = model.extract_feat
39 |     else:
40 |         raise NotImplementedError(
41 |             'FLOPs counter is currently not supported with {}'.
42 |             format(model.__class__.__name__))
43 | 
44 |     flops, params = get_model_complexity_info(model, input_shape)
45 |     split_line = '=' * 30
46 |     print(f'{split_line}\nInput shape: {input_shape}\n'
47 |           f'Flops: {flops}\nParams: {params}\n{split_line}')
48 |     print('!!!Please be cautious if you use the results in papers. 
'
49 |           'You may need to check if all ops are supported and verify that the '
50 |           'flops computation is correct.')
51 | 
52 | 
53 | if __name__ == '__main__':
54 |     main()
55 | 
--------------------------------------------------------------------------------
/tools/analysis_tools/parse_logs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os.path as osp
4 | import pathlib
5 | from glob import glob
6 | 
7 | import numpy as np
8 | from prettytable import PrettyTable
9 | 
10 | 
11 | def read_log(log_dir, best_metric="val:accuracy"):
12 |     # find the most recent log in the directory
13 |     logs = glob(f"{log_dir}/*.log")
14 |     logs.sort(key=osp.getmtime)
15 |     target_log = logs[-1]
16 | 
17 |     with open(target_log) as f:
18 |         lines = f.readlines()
19 |     lines.reverse()
20 |     for idx, l in enumerate(lines):
21 |         if f'Best {best_metric}' in l:
22 |             return l, lines[idx - 1]
23 | 
24 | 
25 | def log_metrics(info):
26 |     group_table = PrettyTable()
27 |     group_table.field_names = ["subset", *info[next(iter(info))].keys()]
28 | 
29 |     if info.get("seed") is None:
30 |         group_table.title = 'exp result'
31 |     else:
32 |         group_table.title = f'seed {info.get("seed")}'
33 |         info.pop("seed")
34 | 
35 |     for subset, metrics in info.items():
36 |         metric_values = [float(v) for v in metrics.values()]
37 |         group_table.add_row([subset, *metric_values])
38 | 
39 |     print(group_table.get_string())
40 | 
41 | 
42 | def parse_args():
43 |     parser = argparse.ArgumentParser(description='Summarize the best metrics from experiment logs')
44 |     parser.add_argument('--work_dir', help='work dir of exp',
45 |                         default="/apdcephfs/share_1364275/yuanfengji/project/ood/work_dirs/erm/20210926_camelyon17_erm")
46 |     parser.add_argument('--best_metric', default="val:accuracy")
47 |     parser.add_argument("--ignore_metrics", help="metrics to ignore")
48 |     args = parser.parse_args()
49 |     return args
50 | 
51 | 
52 | def main():
53 |     args = parse_args()
54 |     assert args.best_metric, (
55 |         'Please specify the best metric with the argument "--best_metric".')
56 | 
57 |     folder_list = glob(f"{args.work_dir}/[0-9]/")
58 |     print(f"Exp dir : {args.work_dir}")
59 | 
60 |     if not folder_list:
61 |         folder_list.append(args.work_dir)
62 |     folder_list.sort()
63 | 
64 |     infos = []
65 |     metrics = dict()
66 |     seeds = []
67 | 
68 |     for folder in folder_list:
69 |         _, others = read_log(folder, args.best_metric)
70 |         others = others.split(',')
71 |         del others[0]  # drop the log prefix before the first comma
72 | 
73 |         info = {}
74 |         for item in others:
75 |             if "\t" in item:
76 |                 item = item.split("\t")[1]
77 |             if '\n' in item:
78 |                 item = item.split("\n")[0]
79 |             d, m, v = item.replace(" ", "").split(":")
80 | 
81 |             if d not in info:
82 |                 info[d] = {m: v}
83 |             else:
84 |                 info[d].update({m: v})
85 | 
86 |         path = pathlib.PurePath(folder).name
87 | 
88 |         try:
89 |             info["seed"] = int(path)
90 |             seeds.append(int(path))
91 |         except ValueError as e:
92 |             logging.exception(e)
93 | 
94 |         log_metrics(info)
95 |         infos.append(info)
96 | 
97 |     print(f"Total {len(seeds)} seeds: {seeds}")
98 | 
99 |     for info in infos:
100 |         for d, ms in info.items():
101 |             for m, v in ms.items():
102 |                 if f'{d}:{m}' not in metrics:
103 |                     metrics[f'{d}:{m}'] = [float(v)]
104 |                 else:
105 |                     metrics[f'{d}:{m}'].append(float(v))
106 | 
107 |     for metric, metric_value in metrics.items():
108 |         print(f'{metric} mean: {np.asarray(metric_value).mean():.2f} '
109 |               f'std: {np.asarray(metric_value).std():.2f} '
110 |               f'var: {np.asarray(metric_value).var():.2f}')
111 | 
112 | 
113 | if __name__ == '__main__':
114 |     main()
115 | 
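# Example invocation (paths are illustrative):
#   python tools/analysis_tools/parse_logs.py \
#       --work_dir work_dirs/erm/lbap_core_ec50_assay \
#       --best_metric "val:accuracy"
# With per-seed subfolders named 0-9, one table is printed per seed, followed
# by the mean/std/var of every "subset:metric" pair across seeds.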
--------------------------------------------------------------------------------
/tools/curate.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
2 | 
3 | import argparse
4 | 
5 | from mmcv import Config, print_log
6 | 
7 | from drugood.apis import set_random_seed
8 | from drugood.curators import GenericCurator
9 | 
10 | 
11 | def parse_args():
12 |     parser = argparse.ArgumentParser(description='Generate Dataset')
13 |     parser.add_argument('config', help='curator config file path')
14 |     parser.add_argument('--seed', type=int, default=12345, help='random seed (please do not change it)')
15 |     parser.add_argument('--deterministic', action='store_true',
16 |                         help='whether to set deterministic options for the CUDNN backend.')
17 |     args = parser.parse_args()
18 |     return args
19 | 
20 | 
21 | def main():
22 |     args = parse_args()
23 |     cfg = Config.fromfile(args.config)
24 |     print_log(f'Curator Config:\n{cfg.pretty_text}\n' + '-' * 60)
25 | 
26 |     # set random seed
27 |     if args.seed is not None:
28 |         print_log(f'Set random seed to {args.seed}, deterministic: {args.deterministic}')
29 |         set_random_seed(args.seed, deterministic=args.deterministic)
30 | 
31 |     curator = GenericCurator(cfg)
32 |     # Processing flow: load -> filter noise -> handle uncertainty ->
33 |     # generate labels -> split -> save -> report statistics.
34 |     data = curator.data_loading()
35 |     data = curator.noise_filtering(data)
36 |     data = curator.uncertainty_processing(data)
37 |     data = curator.classification_label_generating(data)
38 |     data = curator.data_splitting(data)
39 |     curator.data_saving(data)
40 |     curator.statistics_reporting()
41 | 
42 | 
43 | if __name__ == '__main__':
44 |     main()
45 | 
--------------------------------------------------------------------------------
/tools/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | CONFIG=$1
4 | CHECKPOINT=$2
5 | GPUS=$3
6 | PORT=${PORT:-29500}
7 | 
8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
10 |     $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
11 | 
--------------------------------------------------------------------------------
/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | CONFIG=$1
4 | GPUS=$2
5 | PORT=${PORT:-29500}
6 | 
7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
9 |     $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
10 | 
--------------------------------------------------------------------------------
/tools/misc/print_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
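# Example (the config path exists in this repo; the --cfg-options override is
# illustrative):
#   python tools/misc/print_config.py \
#       configs/algorithms/erm/lbap_core_ec50_assay_erm.py \
#       --cfg-options data.samples_per_gpu=64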
2 | import argparse 3 | import warnings 4 | 5 | from mmcv import Config, DictAction 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser(description='Print the whole config') 10 | parser.add_argument('config', help='config file path') 11 | parser.add_argument( 12 | '--options', 13 | nargs='+', 14 | action=DictAction, 15 | help='override some settings in the used config, the key-value pair ' 16 | 'in xxx=yyy format will be merged into config file (deprecate), ' 17 | 'change to --cfg-options instead.') 18 | parser.add_argument( 19 | '--cfg-options', 20 | nargs='+', 21 | action=DictAction, 22 | help='override some settings in the used config, the key-value pair ' 23 | 'in xxx=yyy format will be merged into config file. If the value to ' 24 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 25 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 26 | 'Note that the quotation marks are necessary and that no white space ' 27 | 'is allowed.') 28 | args = parser.parse_args() 29 | 30 | if args.options and args.cfg_options: 31 | raise ValueError( 32 | '--options and --cfg-options cannot be both ' 33 | 'specified, --options is deprecated in favor of --cfg-options') 34 | if args.options: 35 | warnings.warn('--options is deprecated in favor of --cfg-options') 36 | args.cfg_options = args.options 37 | 38 | return args 39 | 40 | 41 | def main(): 42 | args = parse_args() 43 | 44 | cfg = Config.fromfile(args.config) 45 | if args.cfg_options is not None: 46 | cfg.merge_from_dict(args.cfg_options) 47 | # import modules from string list. 48 | if cfg.get('custom_imports', None): 49 | from mmcv.utils import import_modules_from_strings 50 | import_modules_from_strings(**cfg['custom_imports']) 51 | print(f'Config:\n{cfg.pretty_text}') 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CHECKPOINT=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | WORK_DIR=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | SRUN_ARGS=${SRUN_ARGS:-""} 13 | PY_ARGS=${@:5} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} 25 | 
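# Example (partition and job name are illustrative):
#   GPUS=4 GPUS_PER_NODE=4 bash tools/slurm_train.sh my_partition drugood_erm \
#       configs/algorithms/erm/lbap_core_ec50_assay_erm.py work_dirs/erm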
--------------------------------------------------------------------------------