├── .gitignore ├── LICENSE ├── README.md ├── data └── .keep ├── download-data.sh ├── experiments └── .keep ├── requirements.txt └── src ├── .envrc ├── dimenet ├── __init__.py ├── const.py ├── functional.py ├── loader.py ├── logging.py ├── loss.py ├── modules │ ├── __init__.py │ ├── bessel_basis_layer.py │ ├── dimenet.py │ ├── embedding_block.py │ ├── envelope.py │ ├── interaction_block.py │ ├── output_block.py │ └── spherical_basis_layer.py └── params.py ├── mylib ├── __init__.py ├── gcp │ ├── __init__.py │ └── util.py ├── lgb │ ├── __init__.py │ ├── callbacks │ │ ├── __init__.py │ │ └── model_extraction.py │ ├── metrics.py │ ├── null_imp.py │ └── util.py ├── numpy │ ├── __init__.py │ └── functional.py ├── pandas │ ├── __init__.py │ ├── cache.py │ ├── corr.py │ └── util.py ├── params.py ├── sklearn │ ├── __init__.py │ ├── fe │ │ ├── __init__.py │ │ ├── pair_count_encoder.py │ │ └── target_encoder.py │ └── split.py ├── torch │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ └── dataset.py │ ├── fe │ │ ├── __init__.py │ │ └── bert_emb.py │ ├── nn │ │ ├── __init__.py │ │ ├── functional.py │ │ ├── init.py │ │ ├── mish_init.py │ │ └── modules │ │ │ ├── __init__.py │ │ │ ├── dense.py │ │ │ ├── gauss_rank_transform.py │ │ │ ├── mlp.py │ │ │ ├── pair_norm.py │ │ │ └── se_layer.py │ ├── optim │ │ ├── SM3.py │ │ ├── __init__.py │ │ └── sched.py │ └── tools │ │ ├── __init__.py │ │ ├── ema │ │ ├── __init__.py │ │ └── utils.py │ │ ├── lr_finder.py │ │ └── swa │ │ ├── __init__.py │ │ └── utils.py └── utils │ ├── __init__.py │ ├── plt.py │ └── text.py ├── params └── 001.yaml ├── run_create_db.py └── run_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | .idea/
134 | input/
135 | .DS_Store
136 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 akirasosa
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DimeNet PyTorch
2 |
3 | This repository is a PyTorch implementation of [DimeNet](https://arxiv.org/abs/2003.03123), ported from the original [TensorFlow repo](https://github.com/klicperajo/dimenet).
4 |
5 | ## Getting Started
6 |
7 | ```
8 | # Download processed QM9 data.
9 | ./download-data.sh
10 |
11 | # Train model to predict mu.
12 | cd src
13 | python run_train.py params/001.yaml
14 | ```
15 |
16 | ## Results
17 |
18 | The model is trained for 800 epochs.
19 |
20 | |Target|Unit|MAE|
21 | |---|---|---|
22 | | mu | Debye | 0.0285 |
23 | | U0 | meV | TODO |
24 |
25 |
26 | ## Differences from original
27 |
28 | * Use RAdam as the optimizer (see the sketch below).
29 | * Use Mish as the activation.
30 | * The number of layers and n_hidden in OutputBlock might be different.
31 | * The loss function might be different.
32 | * Data splitting might be different.
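
The sketch below shows how the two swapped-in pieces above, RAdam from the `torch-optimizer` package and Mish from `timm`, are typically combined with this port of DimeNet. It is not taken from `run_train.py` (which is outside this listing), so the exact wiring is an assumption; the hyperparameter values mirror the defaults in `src/dimenet/params.py`.

```
# Hedged sketch: the optimizer/activation wiring is assumed, not copied from run_train.py.
import torch_optimizer as optim          # torch-optimizer provides RAdam (see requirements.txt)
from timm.models.layers import mish      # same Mish import used in src/dimenet/modules/dimenet.py

from dimenet.modules.dimenet import DimeNet

model = DimeNet(emb_size=128, num_blocks=6, num_targets=1, activation=mish)
optimizer = optim.RAdam(model.parameters(), lr=3e-4, weight_decay=1e-4)
```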
33 | 34 | ## Cite 35 | 36 | ``` 37 | @inproceedings{klicpera_dimenet_2020, 38 | title = {Directional Message Passing for Molecular Graphs}, 39 | author = {Klicpera, Johannes and Gro{\ss}, Janek and G{\"u}nnemann, Stephan}, 40 | booktitle={International Conference on Learning Representations (ICLR)}, 41 | year = {2020} 42 | } 43 | ``` 44 | -------------------------------------------------------------------------------- /data/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/data/.keep -------------------------------------------------------------------------------- /download-data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | mkdir -p data/pytorch-dimenet 4 | wget "https://www.dropbox.com/s/fifvs2gpdnocxxr/qm9.parquet?dl=1" -O data/pytorch-dimenet/qm9.parquet 5 | -------------------------------------------------------------------------------- /experiments/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/experiments/.keep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | cachetools==4.1.1 3 | certifi==2020.6.20 4 | chardet==3.0.4 5 | cycler==0.10.0 6 | dacite==1.5.1 7 | future==0.18.2 8 | google-auth==1.20.0 9 | google-auth-oauthlib==0.4.1 10 | grpcio==1.30.0 11 | idna==2.10 12 | joblib==0.16.0 13 | kiwisolver==1.2.0 14 | Markdown==3.2.2 15 | matplotlib==3.3.0 16 | mpmath==1.1.0 17 | numpy==1.19.1 18 | oauthlib==3.1.0 19 | omegaconf==2.0.0 20 | pandas==1.1.0 21 | Pillow==7.2.0 22 | protobuf==3.12.4 23 | pyarrow==1.0.0 24 | pyasn1==0.4.8 25 | pyasn1-modules==0.2.8 26 | pyparsing==2.4.7 27 | python-dateutil==2.8.1 28 | pytorch-lightning==0.8.5 29 | pytorch-ranger==0.1.1 30 | pytz==2020.1 31 | PyYAML==5.3.1 32 | requests==2.24.0 33 | requests-oauthlib==1.3.0 34 | rsa==4.6 35 | scikit-learn==0.23.1 36 | scipy==1.5.2 37 | six==1.15.0 38 | sympy==1.6.1 39 | tensorboard==2.3.0 40 | tensorboard-plugin-wit==1.7.0 41 | threadpoolctl==2.1.0 42 | timm==0.1.30 43 | torch==1.6.0 44 | torch-optimizer==0.0.1a14 45 | torch-scatter==2.0.5 46 | torchvision==0.7.0 47 | tqdm==4.48.1 48 | typing-extensions==3.7.4.2 49 | urllib3==1.25.10 50 | Werkzeug==1.0.1 -------------------------------------------------------------------------------- /src/.envrc: -------------------------------------------------------------------------------- 1 | source $CONDA_HOME/etc/profile.d/conda.sh 2 | conda activate dimenet 3 | -------------------------------------------------------------------------------- /src/dimenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/dimenet/__init__.py -------------------------------------------------------------------------------- /src/dimenet/const.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | DATA_DIR = Path('../data') 4 | EXP_DIR = Path('../experiments') 5 | 6 | DATA_QM9_DIR = DATA_DIR / 'qm9' # https://www.kaggle.com/zaharch/quantum-machine-9-aka-qm9 7 | DATA_CSC_DIR = DATA_DIR / 
'champs-scalar-coupling' # https://www.kaggle.com/c/champs-scalar-coupling/data 8 | DATA_PROCESSED_DIR = DATA_DIR / 'pytorch-dimenet' 9 | 10 | STRUCTURES_CSV = DATA_CSC_DIR / 'structures.csv' 11 | QM9_DB = DATA_PROCESSED_DIR / 'qm9.parquet' 12 | 13 | ATOM_MAP = { 14 | 'H': 1, 15 | 'C': 6, 16 | 'N': 7, 17 | 'O': 8, 18 | 'F': 9, 19 | } 20 | -------------------------------------------------------------------------------- /src/dimenet/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def calculate_interatomic_distances(R, idx_i, idx_j): 6 | Ri = R[idx_i] 7 | Rj = R[idx_j] 8 | # ReLU prevents negative numbers in sqrt 9 | Dij = torch.sqrt(F.relu(torch.sum((Ri - Rj) ** 2, -1))) 10 | return Dij 11 | 12 | 13 | def calculate_neighbor_angles(R, id3_i, id3_j, id3_k): 14 | """Calculate angles for neighboring atom triplets""" 15 | Ri = R[id3_i] 16 | Rj = R[id3_j] 17 | Rk = R[id3_k] 18 | R1 = Rj - Ri 19 | R2 = Rk - Ri 20 | x = torch.sum(R1 * R2, dim=-1) 21 | y = torch.cross(R1, R2) 22 | y = torch.norm(y, dim=-1) 23 | angle = torch.atan2(y, x) 24 | return angle 25 | -------------------------------------------------------------------------------- /src/dimenet/loader.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Dict, Union, Callable, Optional 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import scipy.sparse as sp 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader 9 | 10 | 11 | @dataclasses.dataclass 12 | class AtomsBatch: 13 | batch_seg: torch.Tensor 14 | 15 | R: torch.Tensor 16 | Z: torch.Tensor 17 | 18 | idnb_i: torch.Tensor 19 | idnb_j: torch.Tensor 20 | id3dnb_i: torch.Tensor 21 | id3dnb_j: torch.Tensor 22 | id3dnb_k: torch.Tensor 23 | id_expand_kj: torch.Tensor 24 | id_reduce_ji: torch.Tensor 25 | 26 | rc_A: Optional[torch.Tensor] = None 27 | rc_B: Optional[torch.Tensor] = None 28 | rc_C: Optional[torch.Tensor] = None 29 | mu: Optional[torch.Tensor] = None 30 | alpha: Optional[torch.Tensor] = None 31 | homo: Optional[torch.Tensor] = None 32 | lumo: Optional[torch.Tensor] = None 33 | gap: Optional[torch.Tensor] = None 34 | r2: Optional[torch.Tensor] = None 35 | zpve: Optional[torch.Tensor] = None 36 | U0: Optional[torch.Tensor] = None 37 | U: Optional[torch.Tensor] = None 38 | H: Optional[torch.Tensor] = None 39 | G: Optional[torch.Tensor] = None 40 | Cv: Optional[torch.Tensor] = None 41 | mulliken: Optional[torch.Tensor] = None 42 | 43 | def __getitem__(self, item: str): 44 | return getattr(self, item) 45 | 46 | @staticmethod 47 | def from_dict(params: Dict, device: Union[str, torch.device]): 48 | return AtomsBatch(**{ 49 | k: v.to(device) 50 | for k, v in params.items() 51 | }) 52 | 53 | 54 | def to_tensor(v: np.ndarray): 55 | if v.dtype in [np.float64, np.float32]: 56 | return torch.from_numpy(v).float() 57 | return torch.from_numpy(v).long() 58 | 59 | 60 | def _concat(to_stack): 61 | """ function to stack (or concatentate) depending on dimensions """ 62 | if np.asarray(to_stack[0]).ndim >= 2: 63 | return np.concatenate(to_stack) 64 | else: 65 | return np.hstack(to_stack) 66 | 67 | 68 | def _bmat_fast(mats): 69 | new_data = np.concatenate([mat.data for mat in mats]) 70 | 71 | ind_offset = np.zeros(1 + len(mats)) 72 | ind_offset[1:] = np.cumsum([mat.shape[0] for mat in mats]) 73 | new_indices = np.concatenate( 74 | [mats[i].indices + ind_offset[i] for i in range(len(mats))]) 75 | 
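    # The remaining lines of _bmat_fast stitch the CSR index pointers together:
    # each matrix's indptr is shifted by the cumulative nnz of the preceding
    # matrices, and indptr[i >= 1:] drops the leading zero for every matrix
    # after the first.  Together with the column offsets above, this yields the
    # block-diagonal adjacency matrix of the whole batch (equivalent to
    # scipy.sparse.block_diag, but without the generic overhead).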
76 | indptr_offset = np.zeros(1 + len(mats)) 77 | indptr_offset[1:] = np.cumsum([mat.nnz for mat in mats]) 78 | new_indptr = np.concatenate( 79 | [mats[i].indptr[i >= 1:] + indptr_offset[i] for i in range(len(mats))]) 80 | return sp.csr_matrix((new_data, new_indices, new_indptr)) 81 | 82 | 83 | def _restore_shape(data): 84 | for d in data: 85 | d['R'] = d['R'].reshape(-1, 3) 86 | return data 87 | 88 | 89 | def get_loader( 90 | dataset: Dataset, 91 | cutoff: float = 5., 92 | post_fn: Callable = to_tensor, 93 | **kwargs, 94 | ) -> DataLoader: 95 | collate_fn = AtomsCollate(cutoff=cutoff, post_fn=post_fn) 96 | return DataLoader(dataset, collate_fn=collate_fn, **kwargs) 97 | 98 | 99 | @dataclasses.dataclass 100 | class AtomsCollate: 101 | post_fn: Callable 102 | cutoff: float = 5. 103 | 104 | def __call__(self, examples): 105 | examples = _restore_shape(examples) 106 | 107 | data = { 108 | k: _concat([examples[n][k] for n in range(len(examples))]) 109 | for k in examples[0].keys() 110 | } 111 | 112 | adj_matrices = [] 113 | for i, e in enumerate(examples): 114 | R = e['R'] 115 | D = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1) 116 | adj_matrices.append(sp.csr_matrix(D <= self.cutoff)) 117 | adj_matrices[-1] -= sp.eye(len(e['Z']), dtype=np.bool) 118 | 119 | adj_matrix = _bmat_fast(adj_matrices) 120 | atomids_to_edgeid = sp.csr_matrix( 121 | (np.arange(adj_matrix.nnz), adj_matrix.indices, adj_matrix.indptr), 122 | shape=adj_matrix.shape) 123 | edgeid_to_target, edgeid_to_source = adj_matrix.nonzero() 124 | 125 | # Target (i) and source (j) nodes of edges 126 | data['idnb_i'] = edgeid_to_target 127 | data['idnb_j'] = edgeid_to_source 128 | 129 | # Indices of triplets k->j->i 130 | ntriplets = adj_matrix[edgeid_to_source].sum(1).A1 131 | id3ynb_i = np.repeat(edgeid_to_target, ntriplets) 132 | id3ynb_j = np.repeat(edgeid_to_source, ntriplets) 133 | id3ynb_k = adj_matrix[edgeid_to_source].nonzero()[1] 134 | 135 | # Indices of triplets that are not i->j->i 136 | id3_y_to_d, = (id3ynb_i != id3ynb_k).nonzero() 137 | data['id3dnb_i'] = id3ynb_i[id3_y_to_d] 138 | data['id3dnb_j'] = id3ynb_j[id3_y_to_d] 139 | data['id3dnb_k'] = id3ynb_k[id3_y_to_d] 140 | 141 | # Edge indices for interactions 142 | # j->i => k->j 143 | data['id_expand_kj'] = atomids_to_edgeid[edgeid_to_source, :].data[id3_y_to_d] 144 | # j->i => k->j => j->i 145 | data['id_reduce_ji'] = atomids_to_edgeid[edgeid_to_source, :].tocoo().row[id3_y_to_d] 146 | 147 | N = [len(e['Z']) for e in examples] 148 | data['batch_seg'] = np.repeat(np.arange(len(examples)), N) 149 | 150 | return { 151 | k: self.post_fn(v) 152 | for k, v in data.items() 153 | if k != 'name' 154 | } 155 | 156 | 157 | # %% 158 | if __name__ == '__main__': 159 | # %% 160 | from mylib.torch.data.dataset import PandasDataset 161 | from dimenet.const import QM9_DB 162 | 163 | # %% 164 | df = pd.read_parquet(QM9_DB, columns=[ 165 | 'R', 166 | 'Z', 167 | 'U0', 168 | ]) 169 | 170 | # %% 171 | dataset = PandasDataset(df) 172 | loader = get_loader(dataset, batch_size=2, shuffle=False, cutoff=5.) 
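    # Smoke test: collate two molecules, wrap the resulting dict in an
    # AtomsBatch on the CPU, and print a couple of fields to check the shapes.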
173 | 174 | for batch in loader: 175 | batch = AtomsBatch.from_dict(batch, device='cpu') 176 | print(batch.R, batch['U0']) 177 | break 178 | -------------------------------------------------------------------------------- /src/dimenet/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from logging.config import dictConfig 3 | 4 | 5 | def configure_logging(): 6 | dictConfig({ 7 | 'version': 1, 8 | 'formatters': { 9 | 'customFormat': { 10 | 'format': '%(asctime)s - %(levelname)s - %(filename)s - %(name)s - %(funcName)s - %(message)s', 11 | }, 12 | }, 13 | 'handlers': { 14 | 'customFileHandler': { 15 | 'class': 'logging.FileHandler', 16 | 'filename': '../lightning.log', 17 | 'formatter': 'customFormat', 18 | 'level': logging.DEBUG, 19 | }, 20 | }, 21 | 'loggers': { 22 | 'lightning': { 23 | 'handlers': ['customFileHandler'], 24 | 'level': logging.INFO, 25 | 'propagate': 0 26 | }, 27 | }, 28 | }) 29 | -------------------------------------------------------------------------------- /src/dimenet/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mae_loss(y_pred, y_true): 5 | err = torch.abs(y_true - y_pred) 6 | mae = err.mean() 7 | return mae 8 | -------------------------------------------------------------------------------- /src/dimenet/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/dimenet/modules/__init__.py -------------------------------------------------------------------------------- /src/dimenet/modules/bessel_basis_layer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from dimenet.modules.envelope import Envelope 6 | 7 | 8 | class BesselBasisLayer(nn.Module): 9 | def __init__(self, num_radial, cutoff, envelope_exponent=5): 10 | super(BesselBasisLayer, self).__init__() 11 | self.num_radial = num_radial 12 | self.cutoff = cutoff 13 | self.envelope = Envelope(envelope_exponent) 14 | freq_init = np.pi * torch.arange(1, num_radial + 1) 15 | self.frequencies = nn.Parameter(freq_init) 16 | 17 | def forward(self, inputs): 18 | d_scaled = inputs / self.cutoff 19 | d_scaled = torch.unsqueeze(d_scaled, -1) 20 | d_cutoff = self.envelope(d_scaled) 21 | return d_cutoff * torch.sin(self.frequencies * d_scaled) 22 | -------------------------------------------------------------------------------- /src/dimenet/modules/dimenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from timm.models.layers import mish 3 | from torch_scatter import scatter_add 4 | 5 | from dimenet.functional import calculate_interatomic_distances, calculate_neighbor_angles 6 | from dimenet.loader import get_loader, AtomsBatch 7 | from dimenet.modules.bessel_basis_layer import BesselBasisLayer 8 | from dimenet.modules.embedding_block import EmbeddingBlock 9 | from dimenet.modules.interaction_block import InteractionBlock 10 | from dimenet.modules.output_block import OutputBlock 11 | from dimenet.modules.spherical_basis_layer import SphericalBasisLayer 12 | from mylib.torch.data.dataset import PandasDataset 13 | from mylib.torch.nn.mish_init import init_weights 14 | 15 | 16 | class DimeNet(nn.Module): 17 | def __init__( 18 | self, 19 | emb_size=128, 20 | num_blocks=6, 
21 | num_bilinear=8, 22 | num_spherical=7, 23 | num_radial=6, 24 | cutoff=5.0, 25 | envelope_exponent=5, 26 | num_before_skip=1, 27 | num_after_skip=2, 28 | num_dense_output=3, 29 | num_targets=12, 30 | activation=mish, 31 | ): 32 | super(DimeNet, self).__init__() 33 | self.num_blocks = num_blocks 34 | 35 | self.rbf_layer = BesselBasisLayer( 36 | num_radial, 37 | cutoff=cutoff, 38 | envelope_exponent=envelope_exponent, 39 | ) 40 | self.sbf_layer = SphericalBasisLayer( 41 | num_spherical, 42 | num_radial, 43 | cutoff=cutoff, 44 | envelope_exponent=envelope_exponent, 45 | ) 46 | self.emb_block = EmbeddingBlock( 47 | emb_size, 48 | num_radial=num_radial, 49 | activation=activation, 50 | ) 51 | self.output_blocks = nn.ModuleList([ 52 | OutputBlock( 53 | emb_size, 54 | num_radial=num_radial, 55 | n_layers=num_dense_output, 56 | n_out=num_targets, 57 | activation=activation, 58 | ) 59 | for _ in range(num_blocks + 1) 60 | ]) 61 | self.int_blocks = nn.ModuleList([ 62 | InteractionBlock( 63 | emb_size, 64 | num_radial=num_radial, 65 | num_spherical=num_spherical, 66 | num_bilinear=num_bilinear, 67 | num_before_skip=num_before_skip, 68 | num_after_skip=num_after_skip, 69 | activation=activation, 70 | ) 71 | for _ in range(num_blocks) 72 | ]) 73 | 74 | init_weights(self) 75 | 76 | def forward(self, inputs: AtomsBatch): 77 | Z = inputs.Z 78 | R = inputs.R 79 | idnb_i = inputs.idnb_i 80 | idnb_j = inputs.idnb_j 81 | id3dnb_i = inputs.id3dnb_i 82 | id3dnb_j = inputs.id3dnb_j 83 | id3dnb_k = inputs.id3dnb_k 84 | id_expand_kj = inputs.id_expand_kj 85 | id_reduce_ji = inputs.id_reduce_ji 86 | batch_seg = inputs.batch_seg 87 | 88 | # Calculate distances 89 | Dij = calculate_interatomic_distances(R, idnb_i, idnb_j) 90 | rbf = self.rbf_layer(Dij) 91 | 92 | # Calculate angles 93 | A_ijk = calculate_neighbor_angles(R, id3dnb_i, id3dnb_j, id3dnb_k) 94 | sbf = self.sbf_layer((Dij, A_ijk, id_expand_kj)) 95 | 96 | # Embedding block 97 | x = self.emb_block((Z, rbf, idnb_i, idnb_j)) 98 | P = self.output_blocks[0]((x, rbf, idnb_i)) 99 | 100 | # Interaction blocks 101 | for i in range(self.num_blocks): 102 | x = self.int_blocks[i]((x, rbf, sbf, id_expand_kj, id_reduce_ji)) 103 | P += self.output_blocks[i + 1]([x, rbf, idnb_i]) 104 | 105 | P = scatter_add(P, batch_seg, dim=0) 106 | # P = torch.zeros((n_batch, P.size(1))) \ 107 | # .type_as(P) \ 108 | # .scatter_add(0, batch_seg.unsqueeze(-1).expand_as(P), P) 109 | 110 | return P 111 | 112 | 113 | # %% 114 | if __name__ == '__main__': 115 | # %% 116 | import pandas as pd 117 | from dimenet.const import QM9_DB 118 | 119 | # %% 120 | df = pd.read_parquet(QM9_DB, columns=[ 121 | 'R', 122 | 'Z', 123 | 'U0', 124 | ]) 125 | dataset = PandasDataset(df) 126 | 127 | # %% 128 | loader = get_loader(dataset, batch_size=2, shuffle=False) 129 | model = DimeNet( 130 | 128, 131 | num_blocks=6, 132 | num_bilinear=8, 133 | num_spherical=7, 134 | num_radial=6, 135 | num_targets=3, 136 | ) 137 | 138 | for batch in loader: 139 | batch = AtomsBatch.from_dict(batch, device='cpu') 140 | outputs = model(batch) 141 | print(outputs.shape) 142 | print(batch.R.shape) 143 | break 144 | -------------------------------------------------------------------------------- /src/dimenet/modules/embedding_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from mylib.torch.nn.modules.dense import Dense 5 | 6 | 7 | class EmbeddingBlock(nn.Module): 8 | def __init__(self, emb_size, num_radial, activation=None): 9 | 
super(EmbeddingBlock, self).__init__() 10 | self.embedding = nn.Embedding(100, emb_size, padding_idx=0) 11 | self.dense_rbf = Dense(num_radial, emb_size, activation=activation) 12 | self.dense = Dense(emb_size * 3, emb_size, activation=activation) 13 | 14 | def forward(self, inputs): 15 | Z, rbf, idnb_i, idnb_j = inputs 16 | 17 | rbf = self.dense_rbf(rbf) 18 | x = self.embedding(Z) 19 | 20 | x1 = x[idnb_i] 21 | x2 = x[idnb_j] 22 | 23 | x = torch.cat((x1, x2, rbf), dim=-1) 24 | x = self.dense(x) 25 | 26 | return x 27 | -------------------------------------------------------------------------------- /src/dimenet/modules/envelope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Envelope(nn.Module): 6 | def __init__(self, exponent): 7 | super(Envelope, self).__init__() 8 | self.exponent = exponent 9 | self.p = exponent + 1 10 | self.a = -(self.p + 1) * (self.p + 2) / 2 11 | self.b = self.p * (self.p + 2) 12 | self.c = -self.p * (self.p + 1) / 2 13 | 14 | def forward(self, inputs): 15 | env_val = 1 / inputs \ 16 | + self.a * inputs ** (self.p - 1) \ 17 | + self.b * inputs ** self.p \ 18 | + self.c * inputs ** (self.p + 1) 19 | 20 | return torch.where(inputs < 1, env_val, torch.zeros_like(inputs)) 21 | -------------------------------------------------------------------------------- /src/dimenet/modules/interaction_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_scatter import scatter_add 4 | 5 | from mylib.torch.nn.modules.dense import Dense 6 | 7 | 8 | class ResidualLayer(nn.Module): 9 | def __init__(self, units, **kwargs): 10 | super(ResidualLayer, self).__init__() 11 | self.dense_1 = Dense(units, units, **kwargs) 12 | self.dense_2 = Dense(units, units, **kwargs) 13 | 14 | def forward(self, inputs): 15 | x = inputs + self.dense_2(self.dense_1(inputs)) 16 | return x 17 | 18 | 19 | class InteractionBlock(nn.Module): 20 | def __init__(self, emb_size, num_radial, num_spherical, num_bilinear, num_before_skip, num_after_skip, 21 | activation=None): 22 | super(InteractionBlock, self).__init__() 23 | self.emb_size = emb_size 24 | self.num_bilinear = num_bilinear 25 | 26 | self.dense_rbf = Dense(num_radial, emb_size, bias=False) 27 | self.dense_sbf = Dense(num_radial * num_spherical, num_bilinear, bias=False) 28 | 29 | self.dense_ji = Dense(emb_size, emb_size, activation=activation, bias=True) 30 | self.dense_kj = Dense(emb_size, emb_size, activation=activation, bias=True) 31 | 32 | bilin_initializer = torch.empty((self.emb_size, self.num_bilinear, self.emb_size)) \ 33 | .normal_(mean=0, std=2 / emb_size) 34 | self.W_bilin = nn.Parameter(bilin_initializer) 35 | 36 | self.layers_before_skip = nn.ModuleList([ 37 | ResidualLayer(emb_size, activation=activation, bias=True) 38 | for _ in range(num_before_skip) 39 | ]) 40 | 41 | self.final_before_skip = Dense(emb_size, emb_size, activation=activation, bias=True) 42 | 43 | self.layers_after_skip = nn.ModuleList([ 44 | ResidualLayer(emb_size, activation=activation, bias=True) 45 | for _ in range(num_after_skip) 46 | ]) 47 | 48 | def forward(self, inputs): 49 | x, rbf, sbf, id_expand_kj, id_reduce_ji = inputs 50 | # n_interactions = len(torch.unique(id_reduce_ji, sorted=False)) 51 | 52 | # Initial transformation 53 | x_ji = self.dense_ji(x) 54 | x_kj = self.dense_kj(x) 55 | 56 | # Transform via Bessel basis 57 | g = self.dense_rbf(rbf) 58 | x_kj = x_kj * g 59 | 60 | 
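        # The block below is the directional message-passing step: dense_sbf
        # projects the spherical basis down to num_bilinear channels, x_kj is
        # gathered per triplet k->j->i via id_expand_kj, the einsum combines
        # both with the learned bilinear tensor W_bilin, and scatter_add then
        # sums the per-triplet messages back onto their j->i edges through
        # id_reduce_ji.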
# Transform via spherical basis 61 | sbf = self.dense_sbf(sbf) 62 | x_kj = x_kj[id_expand_kj] 63 | 64 | # Apply bilinear layer to interactions and basis function activation 65 | x_kj = torch.einsum("wj,wl,ijl->wi", sbf, x_kj, self.W_bilin) 66 | 67 | x_kj = scatter_add(x_kj, id_reduce_ji, dim=0) # sum over messages 68 | # x_kj = torch.zeros((n_interactions, x_kj.size(1))) \ 69 | # .type_as(x_kj) \ 70 | # .scatter_add(0, id_reduce_ji.unsqueeze(-1).expand_as(x_kj), x_kj) 71 | 72 | # Transformations before skip connection 73 | x2 = x_ji + x_kj 74 | for layer in self.layers_before_skip: 75 | x2 = layer(x2) 76 | x2 = self.final_before_skip(x2) 77 | 78 | # Skip connection 79 | x = x + x2 80 | 81 | # Transformations after skip connection 82 | for layer in self.layers_after_skip: 83 | x = layer(x) 84 | return x 85 | 86 | 87 | # %% 88 | if __name__ == '__main__': 89 | import numpy as np 90 | 91 | # %% 92 | x = np.random.random((7, 64)).astype(np.float32) 93 | rbf = np.random.random((7, 6)).astype(np.float32) 94 | sbf = np.random.random((10, 42)).astype(np.float32) 95 | id_expand_kj = np.tile(np.arange(0, 7), 2)[:10] 96 | id_reduce_ji = np.tile(np.arange(0, 7), 2)[:10] 97 | 98 | # %% 99 | b = InteractionBlock(64, 6, 7, 8, 1, 2) 100 | b([ 101 | torch.from_numpy(x), 102 | torch.from_numpy(rbf), 103 | torch.from_numpy(sbf), 104 | torch.from_numpy(id_expand_kj), 105 | torch.from_numpy(id_reduce_ji), 106 | ]) 107 | -------------------------------------------------------------------------------- /src/dimenet/modules/output_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch_scatter import scatter_add 3 | 4 | from mylib.torch.nn.modules.dense import Dense 5 | from mylib.torch.nn.modules.mlp import MLP 6 | 7 | 8 | class OutputBlock(nn.Module): 9 | def __init__(self, emb_size, num_radial, n_layers, n_out=12, activation=None): 10 | super(OutputBlock, self).__init__() 11 | self.dense_rbf = Dense(num_radial, emb_size, bias=False) 12 | self.mlp = MLP(emb_size, n_out, n_hidden=emb_size, n_layers=n_layers, activation=activation) 13 | 14 | def forward(self, inputs): 15 | x, rbf, idnb_i = inputs 16 | # n_atoms = len(torch.unique(idnb_i, sorted=False)) 17 | 18 | g = self.dense_rbf(rbf) 19 | x = g * x 20 | x = scatter_add(x, idnb_i, dim=0) 21 | # x = torch.zeros((n_atoms, x.size(1))) \ 22 | # .type_as(x) \ 23 | # .scatter_add(0, idnb_i.unsqueeze(-1).expand_as(x), x) 24 | x = self.mlp(x) 25 | return x 26 | -------------------------------------------------------------------------------- /src/dimenet/modules/spherical_basis_layer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sympy as sym 3 | import torch 4 | import torch.nn as nn 5 | from scipy import special as sp 6 | from scipy.optimize import brentq 7 | 8 | from dimenet.modules.envelope import Envelope 9 | 10 | 11 | def Jn(r, n): 12 | """ 13 | numerical spherical bessel functions of order n 14 | """ 15 | return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r) 16 | 17 | 18 | def Jn_zeros(n, k): 19 | """ 20 | Compute the first k zeros of the spherical bessel functions up to order n (excluded) 21 | """ 22 | zerosj = np.zeros((n, k), dtype="float32") 23 | zerosj[0] = np.arange(1, k + 1) * np.pi 24 | points = np.arange(1, k + n) * np.pi 25 | racines = np.zeros(k + n - 1, dtype="float32") 26 | for i in range(1, n): 27 | for j in range(k + n - 1 - i): 28 | foo = brentq(Jn, points[j], points[j + 1], (i,)) 29 | racines[j] = foo 30 
| points = racines 31 | zerosj[i][:k] = racines[:k] 32 | 33 | return zerosj 34 | 35 | 36 | def spherical_bessel_formulas(n): 37 | """ 38 | Computes the sympy formulas for the spherical bessel functions up to order n (excluded) 39 | """ 40 | x = sym.symbols('x') 41 | 42 | f = [sym.sin(x) / x] 43 | a = sym.sin(x) / x 44 | for i in range(1, n): 45 | b = sym.diff(a, x) / x 46 | f += [sym.simplify(b * (-x) ** i)] 47 | a = sym.simplify(b) 48 | return f 49 | 50 | 51 | def bessel_basis(n, k): 52 | """ 53 | Compute the sympy formulas for the normalized and rescaled spherical bessel functions up to 54 | order n (excluded) and maximum frequency k (excluded). 55 | """ 56 | 57 | zeros = Jn_zeros(n, k) 58 | normalizer = [] 59 | for order in range(n): 60 | normalizer_tmp = [] 61 | for i in range(k): 62 | normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1) ** 2] 63 | normalizer_tmp = 1 / np.array(normalizer_tmp) ** 0.5 64 | normalizer += [normalizer_tmp] 65 | 66 | f = spherical_bessel_formulas(n) 67 | x = sym.symbols('x') 68 | bess_basis = [] 69 | for order in range(n): 70 | bess_basis_tmp = [] 71 | for i in range(k): 72 | bess_basis_tmp += [sym.simplify(normalizer[order] 73 | [i] * f[order].subs(x, zeros[order, i] * x))] 74 | bess_basis += [bess_basis_tmp] 75 | return bess_basis 76 | 77 | 78 | def sph_harm_prefactor(l, m): 79 | """ 80 | Computes the constant pre-factor for the spherical harmonic of degree l and order m 81 | input: 82 | l: int, l>=0 83 | m: int, -l<=m<=l 84 | """ 85 | return ((2 * l + 1) * np.math.factorial(l - abs(m)) / (4 * np.pi * np.math.factorial(l + abs(m)))) ** 0.5 86 | 87 | 88 | def associated_legendre_polynomials(l, zero_m_only=True): 89 | """ 90 | Computes sympy formulas of the associated legendre polynomials up to order l (excluded). 91 | """ 92 | z = sym.symbols('z') 93 | P_l_m = [[0] * (j + 1) for j in range(l)] 94 | 95 | P_l_m[0][0] = 1 96 | if l > 0: 97 | P_l_m[1][0] = z 98 | 99 | for j in range(2, l): 100 | P_l_m[j][0] = sym.simplify( 101 | ((2 * j - 1) * z * P_l_m[j - 1][0] - (j - 1) * P_l_m[j - 2][0]) / j) 102 | if not zero_m_only: 103 | for i in range(1, l): 104 | P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1]) 105 | if i + 1 < l: 106 | P_l_m[i + 1][i] = sym.simplify((2 * i + 1) * z * P_l_m[i][i]) 107 | for j in range(i + 2, l): 108 | P_l_m[j][i] = sym.simplify( 109 | ((2 * j - 1) * z * P_l_m[j - 1][i] - (i + j - 1) * P_l_m[j - 2][i]) / (j - i)) 110 | 111 | return P_l_m 112 | 113 | 114 | def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True): 115 | """ 116 | Computes formula strings of the the real part of the spherical harmonics up to order l (excluded). 117 | Variables are either cartesian coordinates x,y,z on the unit sphere or spherical coordinates phi and theta. 
118 | """ 119 | if not zero_m_only: 120 | S_m = [0] 121 | C_m = [1] 122 | for i in range(1, l): 123 | x = sym.symbols('x') 124 | y = sym.symbols('y') 125 | S_m += [x * S_m[i - 1] + y * C_m[i - 1]] 126 | C_m += [x * C_m[i - 1] - y * S_m[i - 1]] 127 | 128 | P_l_m = associated_legendre_polynomials(l, zero_m_only) 129 | if spherical_coordinates: 130 | theta = sym.symbols('theta') 131 | z = sym.symbols('z') 132 | for i in range(len(P_l_m)): 133 | for j in range(len(P_l_m[i])): 134 | if type(P_l_m[i][j]) != int: 135 | P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta)) 136 | if not zero_m_only: 137 | phi = sym.symbols('phi') 138 | for i in range(len(S_m)): 139 | S_m[i] = S_m[i].subs(x, sym.sin( 140 | theta) * sym.cos(phi)).subs(y, sym.sin(theta) * sym.sin(phi)) 141 | for i in range(len(C_m)): 142 | C_m[i] = C_m[i].subs(x, sym.sin( 143 | theta) * sym.cos(phi)).subs(y, sym.sin(theta) * sym.sin(phi)) 144 | 145 | Y_func_l_m = [['0'] * (2 * j + 1) for j in range(l)] 146 | for i in range(l): 147 | Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0]) 148 | 149 | if not zero_m_only: 150 | for i in range(1, l): 151 | for j in range(1, i + 1): 152 | Y_func_l_m[i][j] = sym.simplify( 153 | 2 ** 0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]) 154 | for i in range(1, l): 155 | for j in range(1, i + 1): 156 | Y_func_l_m[i][-j] = sym.simplify( 157 | 2 ** 0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]) 158 | 159 | return Y_func_l_m 160 | 161 | 162 | class SphericalBasisLayer(nn.Module): 163 | def __init__(self, num_spherical, num_radial, cutoff, envelope_exponent=5): 164 | super(SphericalBasisLayer, self).__init__() 165 | 166 | assert num_radial <= 64 167 | 168 | self.num_radial = num_radial 169 | self.num_spherical = num_spherical 170 | self.cutoff = cutoff 171 | self.envelope = Envelope(envelope_exponent) 172 | 173 | # retrieve formulas 174 | self.bessel_formulas = bessel_basis(num_spherical, num_radial) 175 | self.sph_harm_formulas = real_sph_harm(num_spherical) 176 | self.sph_funcs = [] 177 | self.bessel_funcs = [] 178 | 179 | x = sym.symbols('x') 180 | theta = sym.symbols('theta') 181 | modules = {'sin': torch.sin, 'cos': torch.cos} 182 | for i in range(num_spherical): 183 | if i == 0: 184 | first_sph = sym.lambdify([theta], self.sph_harm_formulas[i][0], modules)(0) 185 | self.sph_funcs.append(lambda tensor: torch.zeros_like(tensor) + first_sph) 186 | else: 187 | self.sph_funcs.append(sym.lambdify([theta], self.sph_harm_formulas[i][0], modules)) 188 | for j in range(num_radial): 189 | self.bessel_funcs.append(sym.lambdify([x], self.bessel_formulas[i][j], modules)) 190 | 191 | def forward(self, inputs): 192 | d, Angles, id_expand_kj = inputs 193 | 194 | d_scaled = d / self.cutoff 195 | rbf = [f(d_scaled) for f in self.bessel_funcs] 196 | rbf = torch.stack(rbf, dim=1) 197 | 198 | d_cutoff = self.envelope(d_scaled) 199 | rbf_env = d_cutoff[:, None] * rbf 200 | rbf_env = rbf_env[id_expand_kj.long()] 201 | 202 | cbf = [f(Angles) for f in self.sph_funcs] 203 | cbf = torch.stack(cbf, dim=1) 204 | cbf = cbf.repeat_interleave(self.num_radial, dim=1) 205 | 206 | return rbf_env * cbf 207 | -------------------------------------------------------------------------------- /src/dimenet/params.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Optional, List 3 | 4 | from dimenet.const import EXP_DIR, QM9_DB 5 | from mylib.params import ParamsMixIn 6 | 7 | 8 | @dataclasses.dataclass(frozen=True) 9 | class 
TrainerParams(ParamsMixIn): 10 | num_tpu_cores: Optional[int] = None 11 | gpus: Optional[List[int]] = None 12 | epochs: int = 100 13 | use_16bit: bool = False 14 | resume_from_checkpoint: Optional[str] = None 15 | save_dir: str = str(EXP_DIR) 16 | 17 | 18 | @dataclasses.dataclass(frozen=True) 19 | class ModuleParams(ParamsMixIn): 20 | target: str = 'mu' 21 | 22 | lr: float = 3e-4 23 | weight_decay: float = 1e-4 24 | 25 | batch_size: int = 32 26 | 27 | optim: str = 'radam' 28 | 29 | ema_decay: Optional[float] = None 30 | ema_eval_freq: int = 1 31 | 32 | fold: int = 0 33 | n_splits: Optional[int] = 4 34 | 35 | seed: int = 0 36 | 37 | db_path: str = str(QM9_DB) 38 | 39 | 40 | @dataclasses.dataclass(frozen=True) 41 | class Params(ParamsMixIn): 42 | module_params: ModuleParams 43 | trainer_params: TrainerParams 44 | note: str = '' 45 | 46 | @property 47 | def m(self): 48 | return self.module_params 49 | 50 | @property 51 | def t(self): 52 | return self.trainer_params 53 | 54 | 55 | # %% 56 | if __name__ == '__main__': 57 | # %% 58 | p = Params.load('params/pre_train/002.yaml') 59 | print(p) 60 | -------------------------------------------------------------------------------- /src/mylib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/__init__.py -------------------------------------------------------------------------------- /src/mylib/gcp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/gcp/__init__.py -------------------------------------------------------------------------------- /src/mylib/gcp/util.py: -------------------------------------------------------------------------------- 1 | from google.cloud import storage 2 | 3 | 4 | def upload_blob(project: str, bucket_name: str, source_file_name: str, destination_blob_name: str): 5 | storage_client = storage.Client(project=project) 6 | bucket = storage_client.bucket(bucket_name) 7 | blob = bucket.blob(destination_blob_name) 8 | 9 | blob.upload_from_filename(source_file_name) 10 | 11 | print( 12 | "File {} uploaded to {}.".format( 13 | source_file_name, destination_blob_name 14 | ) 15 | ) 16 | -------------------------------------------------------------------------------- /src/mylib/lgb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/lgb/__init__.py -------------------------------------------------------------------------------- /src/mylib/lgb/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/lgb/callbacks/__init__.py -------------------------------------------------------------------------------- /src/mylib/lgb/callbacks/model_extraction.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from lightgbm import Booster 4 | 5 | 6 | class ModelExtractionCallback(object): 7 | def __init__(self): 8 | self._model = None 9 | 10 | def __call__(self, env): 11 | self._model = env.model 12 | 13 | def _assert_called_cb(self): 14 | if self._model is None: 15 | raise 
RuntimeError('callback has not called yet') 16 | 17 | @property 18 | def boosters_proxy(self): 19 | self._assert_called_cb() 20 | return self._model 21 | 22 | @property 23 | def raw_boosters(self) -> List[Booster]: 24 | self._assert_called_cb() 25 | return self._model.boosters 26 | 27 | @property 28 | def best_iteration(self): 29 | self._assert_called_cb() 30 | return self._model.best_iteration 31 | -------------------------------------------------------------------------------- /src/mylib/lgb/metrics.py: -------------------------------------------------------------------------------- 1 | import lightgbm as lgb 2 | import numpy as np 3 | from sklearn.metrics import mean_squared_log_error 4 | 5 | 6 | def lgb_rmsle_score(preds: np.ndarray, dval: lgb.Dataset): 7 | label = dval.get_label() 8 | y_true = np.exp(label) 9 | y_pred = np.exp(preds) 10 | 11 | return 'rmsle', np.sqrt(mean_squared_log_error(y_true, y_pred)), False -------------------------------------------------------------------------------- /src/mylib/lgb/null_imp.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import cpu_count 2 | from typing import Optional, Dict 3 | 4 | import lightgbm as lgb 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm.auto import tqdm 8 | 9 | 10 | class NullImpSelection: 11 | def __init__(self, df: pd.DataFrame, params: Optional[Dict] = None, n_repeat: int = 80): 12 | self.df = df 13 | params = params if params is not None else {} 14 | self.params = { 15 | 'objective': 'regression', 16 | 'boosting_type': 'rf', 17 | 'subsample': 0.623, 18 | 'colsample_bytree': 0.7, 19 | 'num_leaves': 127, 20 | 'max_depth': 8, 21 | 'seed': 123, 22 | 'bagging_freq': 1, 23 | 'n_jobs': cpu_count(), 24 | 'verbose': -1, 25 | **params, 26 | } 27 | self.n_repeat = n_repeat 28 | self.real_imp_df: Optional[pd.DataFrame] = None 29 | self.null_imp_df: Optional[pd.DataFrame] = None 30 | 31 | def prepare_imp(self): 32 | self._make_real_imp() 33 | self._make_null_imp() 34 | 35 | def _make_real_imp(self): 36 | self.real_imp_df = self._get_feature_importances(shuffle=False) 37 | 38 | def _make_null_imp(self): 39 | null_imp_df = pd.DataFrame() 40 | for i in tqdm(range(self.n_repeat)): 41 | imp_df = self._get_feature_importances(shuffle=True) 42 | imp_df['run'] = i + 1 43 | null_imp_df = pd.concat([null_imp_df, imp_df], axis=0) 44 | self.null_imp_df = null_imp_df 45 | 46 | def get_feature_scores(self) -> pd.DataFrame: 47 | assert self.real_imp_df is not None, 'Run prepare_imp at first.' 
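        # Each feature is scored by comparing its importance on the real target
        # with the distribution of importances obtained from shuffled targets
        # ("null importances"): log of the actual importance over the 75th
        # percentile of the null importances, with small constants guarding
        # against log(0) and division by zero.  Scores well above zero suggest
        # the feature carries signal beyond what a shuffled target produces.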
48 | 49 | feature_scores = [] 50 | real_imp = self.real_imp_df 51 | null_imp = self.null_imp_df 52 | 53 | for _f in real_imp['feature'].unique(): 54 | f_null_imps_gain = null_imp.loc[null_imp['feature'] == _f, 'importance_gain'].values 55 | f_act_imps_gain = real_imp.loc[real_imp['feature'] == _f, 'importance_gain'].mean() 56 | gain_score = np.log( 57 | 1e-10 + f_act_imps_gain / (1 + np.percentile(f_null_imps_gain, 75))) # Avoid divide by zero 58 | f_null_imps_split = null_imp.loc[null_imp['feature'] == _f, 'importance_split'].values 59 | f_act_imps_split = real_imp.loc[real_imp['feature'] == _f, 'importance_split'].mean() 60 | split_score = np.log( 61 | 1e-10 + f_act_imps_split / (1 + np.percentile(f_null_imps_split, 75))) # Avoid divide by zero 62 | feature_scores.append((_f, split_score, gain_score)) 63 | 64 | return pd.DataFrame(feature_scores, columns=['feature', 'split_score', 'gain_score']) 65 | 66 | def get_correlation_scores(self) -> pd.DataFrame: 67 | assert self.real_imp_df is not None, 'Run prepare_imp at first.' 68 | 69 | correlation_scores = [] 70 | real_imp = self.real_imp_df 71 | null_imp = self.null_imp_df 72 | 73 | for _f in real_imp['feature'].unique(): 74 | f_null_imps = null_imp.loc[null_imp['feature'] == _f, 'importance_gain'].values 75 | f_act_imps = real_imp.loc[real_imp['feature'] == _f, 'importance_gain'].values 76 | gain_score = 100 * (f_null_imps < np.percentile(f_act_imps, 25)).sum() / f_null_imps.size 77 | f_null_imps = null_imp.loc[null_imp['feature'] == _f, 'importance_split'].values 78 | f_act_imps = real_imp.loc[real_imp['feature'] == _f, 'importance_split'].values 79 | split_score = 100 * (f_null_imps < np.percentile(f_act_imps, 25)).sum() / f_null_imps.size 80 | correlation_scores.append((_f, split_score, gain_score)) 81 | 82 | return pd.DataFrame(correlation_scores, columns=['feature', 'split_score', 'gain_score']) 83 | 84 | def _get_feature_importances(self, shuffle: bool) -> pd.DataFrame: 85 | X = self.df.drop(['target'], axis=1) 86 | y = self.df['target'].copy() 87 | if shuffle: 88 | y = self.df['target'].copy().sample(frac=1.0) 89 | 90 | dtrain = lgb.Dataset(X, label=y) 91 | model = lgb.train( 92 | params=self.params, 93 | train_set=dtrain, 94 | num_boost_round=400, 95 | ) 96 | 97 | x_cols = X.columns 98 | 99 | imp_df = pd.DataFrame() 100 | imp_df["feature"] = list(x_cols) 101 | imp_df["importance_gain"] = model.feature_importance(importance_type='gain') 102 | imp_df["importance_split"] = model.feature_importance(importance_type='split') 103 | 104 | return imp_df 105 | -------------------------------------------------------------------------------- /src/mylib/lgb/util.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pandas as pd 4 | from lightgbm import Booster 5 | 6 | 7 | def make_imp_df(boosters: List[Booster]) -> pd.DataFrame: 8 | df = pd.concat([ 9 | pd.DataFrame({'name': b.feature_name(), 'importance': b.feature_importance()}) 10 | for b in boosters 11 | ]) 12 | return df.groupby('name').mean() \ 13 | .sort_values('importance') \ 14 | .reset_index(level='name') \ 15 | .set_index('name') 16 | -------------------------------------------------------------------------------- /src/mylib/numpy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/numpy/__init__.py 
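
A hypothetical usage sketch for `NullImpSelection` from `src/mylib/lgb/null_imp.py`: the DataFrame below is made up; only the required `target` column and the `prepare_imp`/`get_feature_scores` call order come from the code above.

```
import numpy as np
import pandas as pd

from mylib.lgb.null_imp import NullImpSelection

# Toy frame: every column except 'target' is treated as a feature.
df = pd.DataFrame({
    'feat_a': np.random.rand(1000),
    'feat_b': np.random.rand(1000),
    'target': np.random.rand(1000),
})

selector = NullImpSelection(df, n_repeat=10)  # default is 80 repeats; 10 keeps the sketch fast
selector.prepare_imp()                        # fits one real-target model and n_repeat shuffled-target models
scores = selector.get_feature_scores()        # columns: feature, split_score, gain_score
print(scores.sort_values('gain_score', ascending=False))
```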
-------------------------------------------------------------------------------- /src/mylib/numpy/functional.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def rand_rotation_matrix(deflection=1.0, rand=None): 5 | if rand is None: 6 | rand = np.random.uniform(size=(3,)) 7 | 8 | theta, phi, z = rand 9 | 10 | theta = theta * 2.0 * deflection * np.pi 11 | phi = phi * 2.0 * np.pi 12 | z = z * 2.0 * deflection 13 | 14 | r = np.sqrt(z) 15 | V = ( 16 | np.sin(phi) * r, 17 | np.cos(phi) * r, 18 | np.sqrt(2.0 - z) 19 | ) 20 | 21 | st = np.sin(theta) 22 | ct = np.cos(theta) 23 | 24 | R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1))) 25 | M = (np.outer(V, V) - np.eye(3)).dot(R) 26 | 27 | return M 28 | -------------------------------------------------------------------------------- /src/mylib/pandas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/pandas/__init__.py -------------------------------------------------------------------------------- /src/mylib/pandas/cache.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Union, Callable, Any 3 | 4 | import pandas as pd 5 | 6 | 7 | class PandasCache: 8 | def __init__(self, cache_path: Union[Path, str]): 9 | self.cache_path = Path(cache_path) 10 | self.cache_path.mkdir(parents=True, exist_ok=True) 11 | 12 | def __call__(self, fn: Callable[[Any], pd.DataFrame]): 13 | def inner(*args, **kwargs): 14 | cache_key = f'{fn.__name__}|{args}|{kwargs}'.replace('/', "\\") 15 | cache = self.cache_path / cache_key 16 | if cache.exists(): 17 | return pd.read_parquet(cache) 18 | 19 | cache_val = fn(*args, **kwargs) 20 | cache_val.to_parquet(self.cache_path / cache_key) 21 | 22 | return cache_val 23 | 24 | return inner 25 | 26 | def clear(self, name: str = ''): 27 | files = self.cache_path.glob(f'{name}*') 28 | for f in files: 29 | f.unlink() 30 | 31 | 32 | if __name__ == '__main__': 33 | c = PandasCache('/tmp/cache') 34 | 35 | 36 | @c 37 | def foo(x, y=1): 38 | print('run foo') 39 | return pd.DataFrame([ 40 | {'a': 1}, 41 | {'a': 2}, 42 | ]) 43 | 44 | 45 | foo(5, y=2) 46 | -------------------------------------------------------------------------------- /src/mylib/pandas/corr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from nancorrmp.nancorrmp import NaNCorrMp 4 | 5 | 6 | def calc_corr(df: pd.DataFrame) -> pd.DataFrame: 7 | corr = NaNCorrMp.calculate(df.select_dtypes('number')) 8 | corr = corr.abs().unstack().sort_values(ascending=False).reset_index() 9 | corr.columns = ['col1', 'col2', 'val'] 10 | df_corr = corr[corr.col1 > corr.col2] 11 | return df_corr.reset_index() 12 | 13 | 14 | def find_high_corr(df_corr: pd.DataFrame, col: str, n: int = 10) -> pd.DataFrame: 15 | df_tmp = df_corr[(df_corr.col1 == col) | (df_corr.col2 == col)].head(n).copy() 16 | df_tmp.loc[df_corr.col1 == col, 'col1'] = np.nan 17 | df_tmp.loc[df_corr.col2 == col, 'col2'] = np.nan 18 | df_tmp['col'] = df_tmp.col1.combine_first(df_tmp.col2) 19 | return df_tmp[['col', 'val']] 20 | -------------------------------------------------------------------------------- /src/mylib/pandas/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
pandas as pd 3 | 4 | 5 | def cast_64(df: pd.DataFrame) -> pd.DataFrame: 6 | for c in df.columns: 7 | if df[c].dtype == 'float64': 8 | df[c] = df[c].astype(np.float32) 9 | if df[c].dtype == 'int64': 10 | df[c] = df[c].astype(np.int32) 11 | return df 12 | -------------------------------------------------------------------------------- /src/mylib/params.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | from typing import TypeVar, Type, Optional, Union, IO 4 | 5 | from dacite import from_dict 6 | from omegaconf import OmegaConf, DictConfig 7 | 8 | T = TypeVar('T') 9 | 10 | 11 | class ParamsMixIn: 12 | @classmethod 13 | def load(cls: Type[T], file: Optional[str] = None) -> T: 14 | if file is None: 15 | parser = ArgumentParser() 16 | parser.add_argument('file', type=str) 17 | file = parser.parse_args().file 18 | data = OmegaConf.to_container(OmegaConf.load(file)) 19 | 20 | return from_dict(data_class=cls, data=data) 21 | 22 | def pretty(self) -> str: 23 | return self.dict_config().pretty() 24 | 25 | def dict_config(self) -> DictConfig: 26 | return OmegaConf.structured(self) 27 | 28 | def save(self, f: Union[str, Path, IO[str]]): 29 | OmegaConf.save(self.dict_config(), f) 30 | -------------------------------------------------------------------------------- /src/mylib/sklearn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/sklearn/__init__.py -------------------------------------------------------------------------------- /src/mylib/sklearn/fe/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/sklearn/fe/__init__.py -------------------------------------------------------------------------------- /src/mylib/sklearn/fe/pair_count_encoder.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from sklearn.base import TransformerMixin 3 | from sklearn.decomposition import TruncatedSVD 4 | import numpy as np 5 | 6 | 7 | # https://www.kaggle.com/matleonard/categorical-encodings 8 | class PairCountEncoder(TransformerMixin): 9 | def __init__(self, n_components=3, seed=123): 10 | self.svd = TruncatedSVD(n_components=n_components, random_state=seed) 11 | self.svd_encoding = None 12 | 13 | def fit(self, X, y=None): 14 | df = pd.concat(( 15 | pd.DataFrame(X.values, columns=['main', 'sub']), 16 | pd.DataFrame(np.ones(len(X)), columns=['y']) 17 | ), axis=1) 18 | pair_counts = df.groupby(['main', 'sub'])['y'].count() 19 | mat = pair_counts.unstack(fill_value=0) 20 | self.svd_encoding = pd.DataFrame(self.svd.fit_transform(mat), index=mat.index) 21 | return self 22 | 23 | def transform(self, X, y=None): 24 | return self.svd_encoding.reindex(X.values[:, 0]).values 25 | -------------------------------------------------------------------------------- /src/mylib/sklearn/fe/target_encoder.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from typing import List, Optional, Iterable, Union 3 | 4 | import category_encoders as ce 5 | import numpy as np 6 | import pandas as pd 7 | from category_encoders.utils import convert_input, convert_input_vector 8 | from sklearn import model_selection 9 | from 
sklearn.base import BaseEstimator, clone, TransformerMixin 10 | from sklearn.model_selection import BaseCrossValidator, StratifiedKFold, KFold 11 | from sklearn.utils.multiclass import type_of_target 12 | 13 | 14 | def check_cv(cv: Union[int, Iterable, BaseCrossValidator] = 5, 15 | y: Optional[Union[pd.Series, np.ndarray]] = None, 16 | stratified: bool = False, 17 | random_state: int = 0): 18 | if cv is None: 19 | cv = 5 20 | if isinstance(cv, numbers.Integral): 21 | if stratified and (y is not None) and (type_of_target(y) in ('binary', 'multiclass')): 22 | return StratifiedKFold(cv, shuffle=True, random_state=random_state) 23 | else: 24 | return KFold(cv, shuffle=True, random_state=random_state) 25 | 26 | return model_selection.check_cv(cv, y, stratified) 27 | 28 | 29 | class KFoldEncoderWrapper(BaseEstimator, TransformerMixin): 30 | """KFold Wrapper for sklearn like interface 31 | 32 | This class wraps sklearn's TransformerMixIn (object that has fit/transform/fit_transform methods), 33 | and call it as K-fold manner. 34 | 35 | Args: 36 | base_transformer: 37 | Transformer object to be wrapped. 38 | cv: 39 | int, cross-validation generator or an iterable which determines the cross-validation splitting strategy. 40 | 41 | - None, to use the default ``KFold(5, random_state=0, shuffle=True)``, 42 | - integer, to specify the number of folds in a ``(Stratified)KFold``, 43 | - CV splitter (the instance of ``BaseCrossValidator``), 44 | - An iterable yielding (train, test) splits as arrays of indices. 45 | groups: 46 | Group labels for the samples. Only used in conjunction with a “Group” cv instance (e.g., ``GroupKFold``). 47 | return_same_type: 48 | If True, `transform` and `fit_transform` return the same type as X. 49 | If False, these APIs always return a numpy array, similar to sklearn's API. 50 | """ 51 | 52 | def __init__(self, base_transformer: BaseEstimator, 53 | cv: Optional[Union[int, Iterable, BaseCrossValidator]] = None, return_same_type: bool = True, 54 | groups: Optional[pd.Series] = None): 55 | self.cv = cv 56 | self.base_transformer = base_transformer 57 | 58 | self.n_splits = None 59 | self.transformers = None 60 | self.return_same_type = return_same_type 61 | self.groups = groups 62 | 63 | def _pre_train(self, y): 64 | self.cv = check_cv(self.cv, y) 65 | self.n_splits = self.cv.get_n_splits() 66 | self.transformers = [clone(self.base_transformer) for _ in range(self.n_splits + 1)] 67 | 68 | def _fit_train(self, X: pd.DataFrame, y: Optional[pd.Series], **fit_params) -> pd.DataFrame: 69 | if y is None: 70 | X_ = self.transformers[-1].transform(X) 71 | return self._post_transform(X_) 72 | 73 | X_ = X.copy() 74 | 75 | for i, (train_index, test_index) in enumerate(self.cv.split(X_, y, self.groups)): 76 | self.transformers[i].fit(X.iloc[train_index], y.iloc[train_index], **fit_params) 77 | X_.iloc[test_index, :] = self.transformers[i].transform(X.iloc[test_index]) 78 | self.transformers[-1].fit(X, y, **fit_params) 79 | 80 | return X_ 81 | 82 | def _post_fit(self, X: pd.DataFrame, y: pd.Series) -> pd.DataFrame: 83 | return X 84 | 85 | def _post_transform(self, X: pd.DataFrame) -> pd.DataFrame: 86 | return X 87 | 88 | def fit(self, X: pd.DataFrame, y: pd.Series): 89 | """ 90 | Fit models for each fold. 91 | 92 | Args: 93 | X: 94 | Data 95 | y: 96 | Target 97 | Returns: 98 | returns the transformer object. 
99 | """ 100 | self._post_fit(self.fit_transform(X, y), y) 101 | return self 102 | 103 | def transform(self, X: Union[pd.DataFrame, np.ndarray]) -> Union[pd.DataFrame, np.ndarray]: 104 | """ 105 | Transform X 106 | 107 | Args: 108 | X: Data 109 | 110 | Returns: 111 | Transformed version of X. It will be pd.DataFrame If X is `pd.DataFrame` and return_same_type is True. 112 | """ 113 | is_pandas = isinstance(X, pd.DataFrame) 114 | X_ = self._fit_train(X, None) 115 | X_ = self._post_transform(X_) 116 | return X_ if self.return_same_type and is_pandas else X_.values 117 | 118 | def fit_transform(self, X: Union[pd.DataFrame, np.ndarray], y: pd.Series = None, **fit_params) \ 119 | -> Union[pd.DataFrame, np.ndarray]: 120 | """ 121 | Fit models for each fold, then transform X 122 | 123 | Args: 124 | X: 125 | Data 126 | y: 127 | Target 128 | fit_params: 129 | Additional parameters passed to models 130 | 131 | Returns: 132 | Transformed version of X. It will be pd.DataFrame If X is `pd.DataFrame` and return_same_type is True. 133 | """ 134 | assert len(X) == len(y) 135 | self._pre_train(y) 136 | 137 | is_pandas = isinstance(X, pd.DataFrame) 138 | X = convert_input(X) 139 | y = convert_input_vector(y, X.index) 140 | 141 | if y.isnull().sum() > 0: 142 | # y == null is regarded as test data 143 | X_ = X.copy() 144 | X_.loc[~y.isnull(), :] = self._fit_train(X[~y.isnull()], y[~y.isnull()], **fit_params) 145 | X_.loc[y.isnull(), :] = self._fit_train(X[y.isnull()], None, **fit_params) 146 | else: 147 | X_ = self._fit_train(X, y, **fit_params) 148 | 149 | X_ = self._post_transform(self._post_fit(X_, y)) 150 | 151 | return X_ if self.return_same_type and is_pandas else X_.values 152 | 153 | 154 | class TargetEncoder(KFoldEncoderWrapper): 155 | """Target Encoder 156 | 157 | KFold version of category_encoders.TargetEncoder in 158 | https://contrib.scikit-learn.org/categorical-encoding/targetencoder.html. 159 | 160 | Args: 161 | cv: 162 | int, cross-validation generator or an iterable which determines the cross-validation splitting strategy. 163 | 164 | - None, to use the default ``KFold(5, random_state=0, shuffle=True)``, 165 | - integer, to specify the number of folds in a ``(Stratified)KFold``, 166 | - CV splitter (the instance of ``BaseCrossValidator``), 167 | - An iterable yielding (train, test) splits as arrays of indices. 168 | groups: 169 | Group labels for the samples. Only used in conjunction with a “Group” cv instance (e.g., ``GroupKFold``). 170 | cols: 171 | A list of columns to encode, if None, all string columns will be encoded. 172 | drop_invariant: 173 | Boolean for whether or not to drop columns with 0 variance. 174 | handle_missing: 175 | Options are ‘error’, ‘return_nan’ and ‘value’, defaults to ‘value’, which returns the target mean. 176 | handle_unknown: 177 | Options are ‘error’, ‘return_nan’ and ‘value’, defaults to ‘value’, which returns the target mean. 178 | min_samples_leaf: 179 | Minimum samples to take category average into account. 180 | smoothing: 181 | Smoothing effect to balance categorical average vs prior. Higher value means stronger regularization. 182 | The value must be strictly bigger than 0. 183 | return_same_type: 184 | If True, ``transform`` and ``fit_transform`` return the same type as X. 185 | If False, these APIs always return a numpy array, similar to sklearn's API. 
186 | """ 187 | 188 | def __init__(self, cv: Optional[Union[Iterable, BaseCrossValidator]] = None, 189 | groups: Optional[pd.Series] = None, 190 | cols: List[str] = None, 191 | drop_invariant: bool = False, handle_missing: str = 'value', handle_unknown: str = 'value', 192 | min_samples_leaf: int = 1, smoothing: float = 1.0, return_same_type: bool = True): 193 | e = ce.TargetEncoder(cols=cols, drop_invariant=drop_invariant, return_df=True, 194 | handle_missing=handle_missing, 195 | handle_unknown=handle_unknown, 196 | min_samples_leaf=min_samples_leaf, smoothing=smoothing) 197 | 198 | super().__init__(e, cv, return_same_type, groups) 199 | 200 | def _post_transform(self, X: pd.DataFrame) -> pd.DataFrame: 201 | cols = self.transformers[0].cols 202 | for c in cols: 203 | X[c] = X[c].astype(float) 204 | return X 205 | -------------------------------------------------------------------------------- /src/mylib/sklearn/split.py: -------------------------------------------------------------------------------- 1 | from sklearn.model_selection import StratifiedKFold 2 | from sklearn.preprocessing import KBinsDiscretizer 3 | 4 | 5 | class KBinsStratifiedKFold(StratifiedKFold): 6 | def __init__( 7 | self, 8 | n_splits=5, 9 | *, 10 | shuffle=False, 11 | random_state=None, 12 | n_bins: int = 5, 13 | strategy: str = 'quantile', 14 | ): 15 | super().__init__(n_splits, shuffle=shuffle, random_state=random_state) 16 | self.kbd = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy=strategy) 17 | 18 | def _iter_test_indices(self, X=None, y=None, groups=None): 19 | super()._iter_test_indices() 20 | 21 | def split(self, X, y, groups=None): 22 | y_binned = self.kbd.fit_transform(y) 23 | return super().split(X, y_binned, groups) 24 | 25 | 26 | # %% 27 | if __name__ == '__main__': 28 | # %% 29 | import numpy as np 30 | 31 | skf = KBinsStratifiedKFold(n_splits=5, shuffle=True, random_state=123, n_bins=10) 32 | y = np.random.random(100).reshape(-1, 1) 33 | # %% 34 | list(skf.split(y, y)) 35 | -------------------------------------------------------------------------------- /src/mylib/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/torch/__init__.py -------------------------------------------------------------------------------- /src/mylib/torch/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/torch/data/__init__.py -------------------------------------------------------------------------------- /src/mylib/torch/data/dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class PandasDataset(Dataset): 6 | def __init__(self, df: pd.DataFrame): 7 | self.df = df 8 | 9 | def __getitem__(self, idx): 10 | row = self.df.iloc[idx] 11 | return row.to_dict() 12 | 13 | def __len__(self): 14 | return len(self.df) 15 | -------------------------------------------------------------------------------- /src/mylib/torch/fe/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/torch/fe/__init__.py 
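The `KFoldEncoderWrapper` / `TargetEncoder` pair in `src/mylib/sklearn/fe/target_encoder.py` above produces out-of-fold encodings in `fit_transform` and keeps one extra encoder fitted on all training rows for later `transform` calls. A minimal usage sketch follows; the toy column names and values are invented for illustration and are not part of the repository.

```python
import pandas as pd

from mylib.sklearn.fe.target_encoder import TargetEncoder

# Hypothetical toy data, only to show the call pattern.
df = pd.DataFrame({
    'city': ['tokyo', 'osaka', 'tokyo', 'nagoya', 'osaka', 'tokyo'],
    'num_feat': [1.0, 2.0, 0.5, 3.0, 1.5, 2.5],
})
y = pd.Series([0, 1, 0, 1, 1, 0])

te = TargetEncoder(cv=3, cols=['city'])

# Training rows get out-of-fold encodings (one encoder per fold)...
df_train_enc = te.fit_transform(df, y)

# ...while new rows are encoded by the extra encoder fitted on all of df.
df_new_enc = te.transform(df)
```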
-------------------------------------------------------------------------------- /src/mylib/torch/fe/bert_emb.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable, Optional 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from tqdm.auto import tqdm 8 | from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModel 9 | 10 | 11 | def make_bert_emb( 12 | df: pd.DataFrame, 13 | bert_name: str, 14 | out_path: Path, 15 | tokenize_fn: Callable[[PreTrainedTokenizer, pd.Series], torch.Tensor], 16 | emb_type: str = 'avg_max', 17 | n_shuffle: Optional[int] = None # It's for tags... 18 | ): 19 | tokenizer = AutoTokenizer.from_pretrained(bert_name) 20 | model = AutoModel.from_pretrained(bert_name).eval().cuda() 21 | 22 | save_as_npy = out_path.is_dir() 23 | 24 | results = pd.DataFrame() 25 | for idx, row in tqdm(df.iterrows(), total=len(df)): 26 | token = tokenize_fn(tokenizer, row).unsqueeze(0).cuda() 27 | with torch.no_grad(): 28 | 29 | if n_shuffle is not None: 30 | token = token.squeeze() 31 | cls_token = token[0].reshape(-1) 32 | sep_token = token[-1].reshape(-1) 33 | token = torch.stack([ 34 | torch.cat(( 35 | cls_token, 36 | token[1:-1][torch.randperm(len(token) - 2)], 37 | sep_token, 38 | )) 39 | for _ in range(n_shuffle) 40 | ], dim=0) 41 | 42 | outputs = model(token) 43 | if emb_type == 'cls_token': 44 | emb = outputs[0][:, 0].squeeze().cpu().numpy() 45 | elif emb_type == 'avg_max': 46 | emb = torch.cat(( 47 | outputs[0].mean(dim=1), 48 | outputs[0].max(dim=1)[0], 49 | ), dim=-1).mean(0).cpu().numpy() 50 | else: 51 | raise Exception('unsupported emb_type') 52 | if save_as_npy: 53 | np.save(str(out_path / f'{idx}.npy'), emb) 54 | else: 55 | results = pd.concat(( 56 | results, 57 | pd.DataFrame( 58 | data=emb.reshape(1, -1), 59 | index=[idx], 60 | ), 61 | )) 62 | if not save_as_npy: 63 | results.columns = [str(n) for n in range(results.shape[1])] 64 | results.to_parquet(str(out_path)) 65 | 66 | 67 | def tokenize(tokenizer: PreTrainedTokenizer, row: pd.Series, col: str): 68 | text = row[col] if row[col] is not None else ' ' 69 | 70 | try: 71 | tokens = tokenizer.encode( 72 | text, 73 | add_special_tokens=True, 74 | # max_length=tokenizer.model_max_length, 75 | max_length=512, 76 | ) 77 | except Exception as e: 78 | print(row.name) 79 | raise e 80 | 81 | return torch.tensor(tokens) 82 | -------------------------------------------------------------------------------- /src/mylib/torch/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/torch/nn/__init__.py -------------------------------------------------------------------------------- /src/mylib/torch/nn/functional.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from torch_scatter import scatter_sum 5 | from torch_scatter.utils import broadcast 6 | 7 | 8 | @torch.jit.script 9 | def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 10 | out: Optional[torch.Tensor] = None, 11 | dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: 12 | out = scatter_sum(src, index, dim, out, dim_size) 13 | dim_size = out.size(dim) 14 | 15 | index_dim = dim 16 | if index_dim < 0: 17 | index_dim = index_dim + src.dim() 18 | if index.dim() <= index_dim: 19 | 
index_dim = index.dim() - 1 20 | 21 | ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) 22 | count = scatter_sum(ones, index, index_dim, None, dim_size) 23 | count_ret = count.clone() 24 | count.clamp_(1) 25 | count = broadcast(count, out, dim) 26 | out.div_(count) 27 | return out, count_ret 28 | 29 | 30 | def onehot(indexes, N=None, ignore_index=None): 31 | """ 32 | Creates a one-representation of indexes with N possible entries 33 | if N is not specified, it will suit the maximum index appearing. 34 | indexes is a long-tensor of indexes 35 | ignore_index will be zero in onehot representation 36 | """ 37 | if N is None: 38 | N = indexes.max() + 1 39 | sz = list(indexes.size()) 40 | output = indexes.new().byte().resize_(*sz, N).zero_() 41 | output.scatter_(-1, indexes.unsqueeze(-1), 1) 42 | if ignore_index is not None and ignore_index >= 0: 43 | output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0) 44 | return output 45 | -------------------------------------------------------------------------------- /src/mylib/torch/nn/init.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from torch.nn.init import constant_ 4 | 5 | zeros_initializer = partial(constant_, val=0.0) 6 | -------------------------------------------------------------------------------- /src/mylib/torch/nn/mish_init.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | def init_weights(m, variance=1.0): 7 | def _calculate_fan_in_and_fan_out(tensor): 8 | dimensions = tensor.dim() 9 | if dimensions < 2: 10 | return 1, 1 11 | 12 | if dimensions == 2: # Linear 13 | fan_in = tensor.size(1) 14 | fan_out = tensor.size(0) 15 | else: 16 | num_input_fmaps = tensor.size(1) 17 | num_output_fmaps = tensor.size(0) 18 | receptive_field_size = 1 19 | if tensor.dim() > 2: 20 | receptive_field_size = tensor[0][0].numel() 21 | fan_in = num_input_fmaps * receptive_field_size 22 | fan_out = num_output_fmaps * receptive_field_size 23 | 24 | return fan_in, fan_out 25 | 26 | def _initialize_weights(tensor, variance, filters=1): 27 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 28 | gain = variance / math.sqrt(fan_in * filters) 29 | with torch.no_grad(): 30 | torch.nn.init.normal_(tensor) 31 | return tensor.data * gain 32 | 33 | def _initialize_bias(tensor, variance): 34 | with torch.no_grad(): 35 | torch.nn.init.normal_(tensor) 36 | return tensor.data * variance 37 | 38 | if m is None: 39 | return 40 | if hasattr(m, 'weight') and m.weight is not None: 41 | # We want to avoid initializing Batch Normalization 42 | if hasattr(m, 'running_mean'): 43 | return 44 | 45 | # If we have channels we probably are a Convolutional Layer 46 | filters = 1 47 | if hasattr(m, 'in_channels'): 48 | filters = m.in_channels 49 | 50 | m.weight.data = _initialize_weights(m.weight, variance=variance, filters=filters) 51 | if hasattr(m, 'bias') and m.bias is not None: 52 | m.bias.data = _initialize_bias(m.bias, variance=variance) 53 | -------------------------------------------------------------------------------- /src/mylib/torch/nn/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/torch/nn/modules/__init__.py -------------------------------------------------------------------------------- 
/src/mylib/torch/nn/modules/dense.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.init import xavier_uniform_ 3 | 4 | from mylib.torch.nn.init import zeros_initializer 5 | 6 | 7 | class Dense(nn.Linear): 8 | r"""Fully connected linear layer with activation function. 9 | 10 | .. math:: 11 | y = activation(xW^T + b) 12 | 13 | Args: 14 | in_features (int): number of input feature :math:`x`. 15 | out_features (int): number of output features :math:`y`. 16 | bias (bool, optional): if False, the layer will not adapt bias :math:`b`. 17 | activation (callable, optional): if None, no activation function is used. 18 | weight_init (callable, optional): weight initializer from current weight. 19 | bias_init (callable, optional): bias initializer from current bias. 20 | 21 | """ 22 | 23 | def __init__( 24 | self, 25 | in_features, 26 | out_features, 27 | bias=True, 28 | activation=None, 29 | weight_init=xavier_uniform_, 30 | # weight_init=xavier_normal_, 31 | bias_init=zeros_initializer, 32 | ): 33 | self.weight_init = weight_init 34 | self.bias_init = bias_init 35 | self.activation = activation 36 | # initialize linear layer y = xW^T + b 37 | super(Dense, self).__init__(in_features, out_features, bias) 38 | 39 | def reset_parameters(self): 40 | """Reinitialize models weight and bias values.""" 41 | self.weight_init(self.weight) 42 | if self.bias is not None: 43 | self.bias_init(self.bias) 44 | 45 | def forward(self, inputs): 46 | """Compute layer output. 47 | 48 | Args: 49 | inputs (dict of torch.Tensor): batch of input values. 50 | 51 | Returns: 52 | torch.Tensor: layer output. 53 | 54 | """ 55 | # compute linear layer y = xW^T + b 56 | y = super(Dense, self).forward(inputs) 57 | # add activation function 58 | if self.activation: 59 | y = self.activation(y) 60 | return y 61 | -------------------------------------------------------------------------------- /src/mylib/torch/nn/modules/gauss_rank_transform.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class GaussRankTransform(nn.Module): 8 | def __init__(self, data: torch.Tensor, eps=1e-6): 9 | super(GaussRankTransform, self).__init__() 10 | tformed = self._erfinv(data, eps) 11 | data, sort_idx = data.sort() 12 | self.register_buffer('src', data) 13 | self.register_buffer('dst', tformed[sort_idx]) 14 | 15 | @staticmethod 16 | def _erfinv(data: torch.Tensor, eps): 17 | rank = data.argsort().argsort().float() 18 | 19 | rank_scaled = (rank / rank.max() - 0.5) * 2 20 | rank_scaled = rank_scaled.clamp(-1 + eps, 1 - eps) 21 | 22 | tformed = rank_scaled.erfinv() 23 | 24 | return tformed 25 | 26 | def forward(self, x): 27 | return self._transform(x, self.dst, self.src) 28 | 29 | def invert(self, x): 30 | return self._transform(x, self.src, self.dst) 31 | 32 | def _transform(self, x, src, dst): 33 | pos = src.argsort()[x.argsort().argsort()] 34 | 35 | N = len(self.src) 36 | pos[pos >= N] = N - 1 37 | pos[pos - 1 <= 0] = 0 38 | 39 | x1 = dst[pos] 40 | x2 = dst[pos - 1] 41 | y1 = src[pos] 42 | y2 = src[pos - 1] 43 | 44 | relative = (x - x2) / (x1 - x2) 45 | 46 | return (1 - relative) * y2 + relative * y1 47 | 48 | 49 | # %% 50 | if __name__ == '__main__': 51 | # %% 52 | x = torch.from_numpy(np.random.uniform(low=0, high=1, size=2000)) 53 | 54 | grt = GaussRankTransform(x) 55 | x_tformed = grt.forward(x) 56 | x_inv = 
grt.invert(x_tformed) 57 | 58 | # %% 59 | print(x) 60 | print(x_inv) 61 | 62 | print(grt.dst) 63 | print(torch.sort(x_tformed)[0]) 64 | 65 | bins = 100 66 | plt.hist(x, bins=bins) 67 | plt.show() 68 | 69 | plt.hist(x_inv, bins=bins) 70 | plt.show() 71 | 72 | plt.hist(grt.src, bins=bins) 73 | plt.show() 74 | 75 | plt.hist(x_tformed, bins=bins) 76 | plt.show() 77 | 78 | plt.hist(grt.dst, bins=bins) 79 | plt.show() 80 | -------------------------------------------------------------------------------- /src/mylib/torch/nn/modules/mlp.py: -------------------------------------------------------------------------------- 1 | from timm.models.layers import mish 2 | from torch import nn 3 | 4 | from mylib.torch.nn.modules.dense import Dense 5 | 6 | 7 | class MLP(nn.Module): 8 | """Multiple layer fully connected perceptron neural network. 9 | 10 | Args: 11 | n_in (int): number of input nodes. 12 | n_out (int): number of output nodes. 13 | n_hidden (list of int or int, optional): number hidden layer nodes. 14 | If an integer, same number of node is used for all hidden layers resulting 15 | in a rectangular network. 16 | If None, the number of neurons is divided by two after each layer starting 17 | n_in resulting in a pyramidal network. 18 | n_layers (int, optional): number of layers. 19 | activation (callable, optional): activation function. All hidden layers would 20 | the same activation function except the output layer that does not apply 21 | any activation function. 22 | 23 | """ 24 | 25 | def __init__( 26 | self, n_in, n_out, n_hidden=None, n_layers=2, activation=mish 27 | ): 28 | super(MLP, self).__init__() 29 | # get list of number of nodes in input, hidden & output layers 30 | if n_hidden is None: 31 | c_neurons = n_in 32 | self.n_neurons = [] 33 | for i in range(n_layers): 34 | self.n_neurons.append(c_neurons) 35 | c_neurons = c_neurons // 2 36 | self.n_neurons.append(n_out) 37 | else: 38 | # get list of number of nodes hidden layers 39 | if type(n_hidden) is int: 40 | n_hidden = [n_hidden] * (n_layers - 1) 41 | self.n_neurons = [n_in] + n_hidden + [n_out] 42 | 43 | # assign a Dense layer (with activation function) to each hidden layer 44 | layers = [ 45 | Dense( 46 | self.n_neurons[i], self.n_neurons[i + 1], 47 | activation=activation, 48 | ) 49 | for i in range(n_layers - 1) 50 | ] 51 | # assign a Dense layer (without activation function) to the output layer 52 | layers.append( 53 | Dense( 54 | self.n_neurons[-2], self.n_neurons[-1], 55 | activation=None, 56 | ) 57 | ) 58 | # put all layers together to make the network 59 | self.out_net = nn.Sequential(*layers) 60 | 61 | def forward(self, inputs): 62 | """Compute neural network output. 63 | 64 | Args: 65 | inputs (torch.Tensor): network input. 66 | 67 | Returns: 68 | torch.Tensor: network output. 69 | 70 | """ 71 | return self.out_net(inputs) 72 | -------------------------------------------------------------------------------- /src/mylib/torch/nn/modules/pair_norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class PairNorm(nn.Module): 5 | def __init__(self, mode='PN', scale=1): 6 | """ 7 | mode: 8 | 'None' : No normalization 9 | 'PN' : Original version 10 | 'PN-SI' : Scale-Individually version 11 | 'PN-SCS' : Scale-and-Center-Simultaneously version 12 | 13 | ('SCS'-mode is not in the paper but we found it works well in practice, 14 | especially for GCN and GAT.) 15 | PairNorm is typically used after each graph convolution operation. 
16 | """ 17 | assert mode in ['None', 'PN', 'PN-SI', 'PN-SCS'] 18 | super(PairNorm, self).__init__() 19 | self.mode = mode 20 | self.scale = scale 21 | 22 | # Scale can be set based on origina data, and also the current feature lengths. 23 | # We leave the experiments to future. A good pool we used for choosing scale: 24 | # [0.1, 1, 10, 50, 100] 25 | 26 | def forward(self, x): 27 | if self.mode == 'None': 28 | return x 29 | 30 | col_mean = x.mean(dim=0) 31 | if self.mode == 'PN': 32 | x = x - col_mean 33 | rownorm_mean = (1e-6 + x.pow(2).sum(dim=1).mean()).sqrt() 34 | x = self.scale * x / rownorm_mean 35 | 36 | if self.mode == 'PN-SI': 37 | x = x - col_mean 38 | rownorm_individual = (1e-6 + x.pow(2).sum(dim=1, keepdim=True)).sqrt() 39 | x = self.scale * x / rownorm_individual 40 | 41 | if self.mode == 'PN-SCS': 42 | rownorm_individual = (1e-6 + x.pow(2).sum(dim=1, keepdim=True)).sqrt() 43 | x = self.scale * x / rownorm_individual - col_mean 44 | 45 | return x 46 | -------------------------------------------------------------------------------- /src/mylib/torch/nn/modules/se_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from timm.models.layers import Mish 3 | from torch import nn 4 | 5 | from mylib.torch.nn.modules.dense import Dense 6 | 7 | 8 | class SELayer(nn.Module): 9 | def __init__(self, channel, reduction=16): 10 | super(SELayer, self).__init__() 11 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 12 | self.fc = nn.Sequential( 13 | # nn.Linear(channel, channel // reduction, bias=False), 14 | # nn.ReLU(inplace=True), 15 | # nn.Linear(channel // reduction, channel, bias=False), 16 | Dense(channel, channel // reduction, bias=False), 17 | Mish(inplace=True), 18 | Dense(channel // reduction, channel, bias=False), 19 | nn.Sigmoid() 20 | ) 21 | 22 | def forward(self, x): 23 | b, c, _, _ = x.size() 24 | y = self.avg_pool(x).view(b, c) 25 | y = self.fc(y).view(b, c, 1, 1) 26 | return x * y.expand_as(x) 27 | 28 | 29 | if __name__ == '__main__': 30 | inputs = torch.randn((3, 12, 768, 1)) 31 | m = SELayer(12) 32 | # %% 33 | m(inputs).shape 34 | -------------------------------------------------------------------------------- /src/mylib/torch/optim/SM3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | 5 | class SM3(Optimizer): 6 | """Implements SM3 algorithm. 7 | It has been proposed in `Memory-Efficient Adaptive Optimization`_. 8 | Arguments: 9 | params (iterable): iterable of parameters to optimize or dicts defining 10 | parameter groups 11 | lr (float, optional): coefficient that scale delta before it is applied 12 | to the parameters (default: 0.1) 13 | momentum (float, optional): coefficient used to scale prior updates 14 | before adding. This drastically increases memory usage if 15 | `momentum > 0.0`. This is ignored if the parameter's gradient 16 | is sparse. (default: 0.0) 17 | beta (float, optional): coefficient used for exponential moving 18 | averages (default: 0.0) 19 | eps (float, optional): Term added to square-root in denominator to 20 | improve numerical stability (default: 1e-30) 21 | .. 
_Memory-Efficient Adaptive Optimization: 22 | https://arxiv.org/abs/1901.11150 23 | """ 24 | 25 | def __init__(self, params, lr=0.1, momentum=0.0, beta=0.0, eps=1e-30): 26 | if not 0.0 <= lr: 27 | raise ValueError("Invalid learning rate: {0}".format(lr)) 28 | if not 0.0 <= momentum < 1.0: 29 | raise ValueError("Invalid momentum: {0}".format(momentum)) 30 | if not 0.0 <= beta < 1.0: 31 | raise ValueError("Invalid beta: {0}".format(beta)) 32 | if not 0.0 <= eps: 33 | raise ValueError("Invalid eps: {0}".format(eps)) 34 | 35 | defaults = {'lr': lr, 'momentum': momentum, 'beta': beta, 'eps': eps} 36 | super(SM3, self).__init__(params, defaults) 37 | 38 | @torch.no_grad() 39 | def step(self, closure=None): 40 | """Performs single optimization step. 41 | Arguments: 42 | closure (callable, optional): A closure that reevaluates the model 43 | and returns the loss. 44 | """ 45 | loss = None 46 | if closure is not None: 47 | with torch.enable_grad(): 48 | loss = closure() 49 | 50 | for group in self.param_groups: 51 | momentum = group['momentum'] 52 | beta = group['beta'] 53 | eps = group['eps'] 54 | for p in group['params']: 55 | if p is None: 56 | continue 57 | grad = p.grad 58 | 59 | state = self.state[p] 60 | shape = grad.shape 61 | rank = len(shape) 62 | 63 | # State initialization 64 | if len(state) == 0: 65 | state['step'] = 0 66 | state['momentum_buffer'] = 0. 67 | _add_initial_accumulators(state, grad) 68 | 69 | if grad.is_sparse: 70 | # the update is non-linear so indices must be unique 71 | grad.coalesce() 72 | grad_indices = grad._indices() 73 | grad_values = grad._values() 74 | 75 | # Transform update_values into sparse tensor 76 | def make_sparse(values): 77 | constructor = grad.new 78 | if grad_indices.dim() == 0 or values.dim() == 0: 79 | return constructor().resize_as_(grad) 80 | return constructor(grad_indices, values, grad.size()) 81 | 82 | acc = state[_key(0)] 83 | update_values = _compute_sparse_update(beta, acc, grad_values, grad_indices) 84 | 85 | self._update_sparse_accumulator(beta, acc, make_sparse(update_values)) 86 | 87 | # Add small amount for numerical stability 88 | update_values.add_(eps).rsqrt_().mul_(grad_values) 89 | 90 | update = make_sparse(update_values) 91 | else: 92 | # Get previous accumulators mu_{t-1} 93 | if rank > 1: 94 | acc_list = [state[_key(i)] for i in range(rank)] 95 | else: 96 | acc_list = [state[_key(0)]] 97 | 98 | # Get update from accumulators and gradients 99 | update = _compute_update(beta, acc_list, grad) 100 | 101 | # Update accumulators. 102 | self._update_accumulator(beta, acc_list, update) 103 | 104 | # Add small amount for numerical stability 105 | update.add_(eps).rsqrt_().mul_(grad) 106 | 107 | if momentum > 0.: 108 | m = state['momentum_buffer'] 109 | update.mul_(1. 
- momentum).add_(momentum, m) 110 | state['momentum_buffer'] = update.detach() 111 | 112 | p.sub_(group['lr'], update) 113 | state['step'] += 1 114 | return loss 115 | 116 | def _update_accumulator(self, beta, acc_list, update): 117 | for i, acc in enumerate(acc_list): 118 | nu_max = _max_reduce_except_dim(update, i) 119 | if beta > 0.: 120 | torch.max(acc, nu_max, out=acc) 121 | else: 122 | # No need to compare - nu_max is bigger because of grad ** 2 123 | acc.copy_(nu_max) 124 | 125 | def _update_sparse_accumulator(self, beta, acc, update): 126 | nu_max = _max_reduce_except_dim(update.to_dense(), 0).squeeze() 127 | if beta > 0.: 128 | torch.max(acc, nu_max, out=acc) 129 | else: 130 | # No need to compare - nu_max is bigger because of grad ** 2 131 | acc.copy_(nu_max) 132 | 133 | 134 | def _compute_sparse_update(beta, acc, grad_values, grad_indices): 135 | # In the sparse case, a single accumulator is used. 136 | update_values = torch.gather(acc, 0, grad_indices[0]) 137 | if beta > 0.: 138 | update_values.mul_(beta) 139 | update_values.addcmul_(1. - beta, grad_values, grad_values) 140 | return update_values 141 | 142 | 143 | def _compute_update(beta, acc_list, grad): 144 | rank = len(acc_list) 145 | update = acc_list[0].clone() 146 | for i in range(1, rank): 147 | # We rely on broadcasting to get the proper end shape. 148 | # Note that torch.min is currently (as of 1.23.2020) not commutative - 149 | # the order matters for NaN values. 150 | update = torch.min(update, acc_list[i]) 151 | if beta > 0.: 152 | update.mul_(beta) 153 | update.addcmul_(1. - beta, grad, grad) 154 | 155 | return update 156 | 157 | 158 | def _key(i): 159 | # Returns key used for accessing accumulators 160 | return 'accumulator_' + str(i) 161 | 162 | 163 | def _add_initial_accumulators(state, grad): 164 | # Creates initial accumulators. For a dense tensor of shape (n1, n2, n3), 165 | # then our initial accumulators are of shape (n1, 1, 1), (1, n2, 1) and 166 | # (1, 1, n3). For a sparse tensor of shape (n, *), we use a single 167 | # accumulator of shape (n,). 168 | shape = grad.shape 169 | rank = len(shape) 170 | defaults = {'device': grad.device, 'dtype': grad.dtype} 171 | acc = {} 172 | 173 | if grad.is_sparse: 174 | acc[_key(0)] = torch.zeros(shape[0], **defaults) 175 | elif rank == 0: 176 | # The scalar case is handled separately 177 | acc[_key(0)] = torch.zeros(shape, **defaults) 178 | else: 179 | for i in range(rank): 180 | acc_shape = [1] * i + [shape[i]] + [1] * (rank - 1 - i) 181 | acc[_key(i)] = torch.zeros(acc_shape, **defaults) 182 | 183 | state.update(acc) 184 | 185 | 186 | def _max_reduce_except_dim(tensor, dim): 187 | # Computes max along all dimensions except the given dim. 188 | # If tensor is a scalar, it returns tensor. 
189 | rank = len(tensor.shape) 190 | result = tensor 191 | if rank > 0: 192 | assert dim < rank 193 | for d in range(rank): 194 | if d != dim: 195 | result = result.max(dim=d, keepdim=True).values 196 | return result 197 | -------------------------------------------------------------------------------- /src/mylib/torch/optim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/torch/optim/__init__.py -------------------------------------------------------------------------------- /src/mylib/torch/optim/sched.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import numpy as np 4 | 5 | 6 | @dataclasses.dataclass 7 | class _FlatCos: 8 | total_steps: int 9 | flat_rate: float 10 | cos_rate: float 11 | 12 | def __call__(self, step: int): 13 | flat_steps = int(self.flat_rate / (self.flat_rate + self.cos_rate) * total_steps) 14 | cos_steps = total_steps - flat_steps 15 | if step <= flat_steps: 16 | return 1 17 | f = np.cos((step - flat_steps) / cos_steps * np.pi) * 0.5 + 0.5 18 | return np.clip(f, 0, 1) 19 | 20 | 21 | def flat_cos(total_steps: int, flat_rate: float = 1., cos_rate: float = 0.72): 22 | return _FlatCos(total_steps, flat_rate, cos_rate) 23 | 24 | 25 | @dataclasses.dataclass 26 | class _Linear: 27 | total_steps: int 28 | start: float 29 | stop: float 30 | flat_rate_pre: float 31 | flat_rate_post: float 32 | 33 | def __post_init__(self): 34 | steps_pre = int(self.flat_rate_pre * total_steps) 35 | steps_post = int(self.flat_rate_post * total_steps) 36 | linear_steps = total_steps - steps_pre - steps_post 37 | self.schedule = np.concatenate(( 38 | np.ones(steps_pre) * self.start, 39 | np.linspace(self.start, self.stop, linear_steps), 40 | np.ones(steps_post) * self.stop, 41 | )) 42 | 43 | def __call__(self, step: int): 44 | return self.schedule[step] 45 | 46 | 47 | def linear(total_steps: int, start: float = 0., stop: float = 1., flat_rate_pre: float = 0., flat_rate_post: float = 0.): 48 | return _Linear(total_steps, start, stop, flat_rate_pre, flat_rate_post) 49 | 50 | 51 | # %% 52 | if __name__ == '__main__': 53 | # %% 54 | import matplotlib.pyplot as plt 55 | 56 | # %% 57 | total_steps = 100 58 | sched = linear(total_steps, flat_rate_pre=0.1, flat_rate_post=0.2) 59 | # sched = flat_cos(total_steps) 60 | values = [sched(n) for n in range(total_steps)] 61 | 62 | plt.plot(range(total_steps), values) 63 | plt.show() 64 | -------------------------------------------------------------------------------- /src/mylib/torch/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/torch/tools/__init__.py -------------------------------------------------------------------------------- /src/mylib/torch/tools/ema/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/torch/tools/ema/__init__.py -------------------------------------------------------------------------------- /src/mylib/torch/tools/ema/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def update_ema(ema_model: nn.Module, model: 
nn.Module, decay: float): 6 | with torch.no_grad(): 7 | msd = model.state_dict() 8 | for k, ema_v in ema_model.state_dict().items(): 9 | model_v = msd[k].detach() 10 | ema_v.copy_(ema_v * decay + (1. - decay) * model_v) 11 | -------------------------------------------------------------------------------- /src/mylib/torch/tools/lr_finder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import matplotlib.pyplot as plt 5 | import torch 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | from tqdm.autonotebook import tqdm 8 | 9 | from csc.loader import AtomsBatch 10 | 11 | try: 12 | from apex import amp 13 | 14 | IS_AMP_AVAILABLE = True 15 | except ImportError: 16 | import logging 17 | 18 | logging.basicConfig() 19 | logger = logging.getLogger(__name__) 20 | logger.warning( 21 | "To enable mixed precision training, please install `apex`. " 22 | "Or you can re-install this package by the following command:\n" 23 | ' pip install models-lr-finder -v --global-option="amp"' 24 | ) 25 | IS_AMP_AVAILABLE = False 26 | del logging 27 | 28 | 29 | class LRFinder(object): 30 | """Learning rate range test. 31 | 32 | The learning rate range test increases the learning rate in a pre-training run 33 | between two boundaries in a linear or exponential manner. It provides valuable 34 | information on how well the network can be trained over a range of learning rates 35 | and what is the optimal learning rate. 36 | 37 | Arguments: 38 | model (torch.nn.Module): wrapped models. 39 | optimizer (models.optim.Optimizer): wrapped optimizer where the defined learning 40 | is assumed to be the lower boundary of the range test. 41 | criterion (torch.nn.Module): wrapped loss function. 42 | device (str or models.device, optional): a string ("cpu" or "cuda") with an 43 | optional ordinal for the device type (e.g. "cuda:X", where is the ordinal). 44 | Alternatively, can be an object representing the device on which the 45 | computation will take place. Default: None, uses the same device as `models`. 46 | memory_cache (boolean, optional): if this flag is set to True, `state_dict` of 47 | models and optimizer will be cached in memory. Otherwise, they will be saved 48 | to files under the `cache_dir`. 49 | cache_dir (string, optional): path for storing temporary files. If no path is 50 | specified, system-wide temporary directory is used. Notice that this 51 | parameter will be ignored if `memory_cache` is True. 
52 | 53 | Example: 54 | >>> lr_finder = LRFinder(net, optimizer, criterion, device="cuda") 55 | >>> lr_finder.range_test(dataloader, end_lr=100, num_iter=100) 56 | >>> lr_finder.plot() # to inspect the loss-learning rate graph 57 | >>> lr_finder.reset() # to reset the models and optimizer to their initial state 58 | 59 | Reference: 60 | Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 61 | fastai/lr_find: https://github.com/fastai/fastai 62 | """ 63 | 64 | def __init__( 65 | self, 66 | model, 67 | optimizer, 68 | criterion, 69 | device=None, 70 | memory_cache=True, 71 | cache_dir=None, 72 | ): 73 | # Check if the optimizer is already attached to a scheduler 74 | self.optimizer = optimizer 75 | self._check_for_scheduler() 76 | 77 | self.model = model 78 | self.criterion = criterion 79 | self.history = {"lr": [], "loss": []} 80 | self.best_loss = None 81 | self.memory_cache = memory_cache 82 | self.cache_dir = cache_dir 83 | 84 | # Save the original state of the models and optimizer so they can be restored if 85 | # needed 86 | self.model_device = next(self.model.parameters()).device 87 | self.state_cacher = StateCacher(memory_cache, cache_dir=cache_dir) 88 | self.state_cacher.store("models", self.model.state_dict()) 89 | self.state_cacher.store("optimizer", self.optimizer.state_dict()) 90 | 91 | # If device is None, use the same as the models 92 | if device: 93 | self.device = device 94 | else: 95 | self.device = self.model_device 96 | 97 | def reset(self): 98 | """Restores the models and optimizer to their initial states.""" 99 | 100 | self.model.load_state_dict(self.state_cacher.retrieve("models")) 101 | self.optimizer.load_state_dict(self.state_cacher.retrieve("optimizer")) 102 | self.model.to(self.model_device) 103 | 104 | def range_test( 105 | self, 106 | train_loader, 107 | val_loader=None, 108 | start_lr=None, 109 | end_lr=10, 110 | num_iter=100, 111 | step_mode="exp", 112 | smooth_f=0.05, 113 | diverge_th=5, 114 | accumulation_steps=1, 115 | ): 116 | """Performs the learning rate range test. 117 | 118 | Arguments: 119 | train_loader (models.utils.data.DataLoader): the training set data laoder. 120 | val_loader (models.utils.data.DataLoader, optional): if `None` the range test 121 | will only use the training loss. When given a data loader, the models is 122 | evaluated after each iteration on that dataset and the evaluation loss 123 | is used. Note that in this mode the test takes significantly longer but 124 | generally produces more precise results. Default: None. 125 | start_lr (float, optional): the starting learning rate for the range test. 126 | Default: None (uses the learning rate from the optimizer). 127 | end_lr (float, optional): the maximum learning rate to test. Default: 10. 128 | num_iter (int, optional): the number of iterations over which the test 129 | occurs. Default: 100. 130 | step_mode (str, optional): one of the available learning rate policies, 131 | linear or exponential ("linear", "exp"). Default: "exp". 132 | smooth_f (float, optional): the loss smoothing factor within the [0, 1[ 133 | interval. Disabled if set to 0, otherwise the loss is smoothed using 134 | exponential smoothing. Default: 0.05. 135 | diverge_th (int, optional): the test is stopped when the loss surpasses the 136 | threshold: diverge_th * best_loss. Default: 5. 137 | accumulation_steps (int, optional): steps for gradient accumulation. If it 138 | is 1, gradients are not accumulated. Default: 1. 
139 | 140 | Example (fastai approach): 141 | >>> lr_finder = LRFinder(net, optimizer, criterion, device="cuda") 142 | >>> lr_finder.range_test(dataloader, end_lr=100, num_iter=100) 143 | 144 | Example (Leslie Smith's approach): 145 | >>> lr_finder = LRFinder(net, optimizer, criterion, device="cuda") 146 | >>> lr_finder.range_test(trainloader, val_loader=val_loader, end_lr=1, num_iter=100, step_mode="linear") 147 | 148 | Gradient accumulation is supported; example: 149 | >>> train_data = ... # prepared dataset 150 | >>> desired_bs, real_bs = 32, 4 # batch size 151 | >>> accumulation_steps = desired_bs // real_bs # required steps for accumulation 152 | >>> dataloader = models.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True) 153 | >>> acc_lr_finder = LRFinder(net, optimizer, criterion, device="cuda") 154 | >>> acc_lr_finder.range_test(dataloader, end_lr=10, num_iter=100, accumulation_steps=accumulation_steps) 155 | 156 | Reference: 157 | [Training Neural Nets on Larger Batches: Practical Tips for 1-GPU, Multi-GPU & Distributed setups]( 158 | https://medium.com/huggingface/ec88c3e51255) 159 | [thomwolf/gradient_accumulation](https://gist.github.com/thomwolf/ac7a7da6b1888c2eeac8ac8b9b05d3d3) 160 | """ 161 | 162 | # Reset test results 163 | self.history = {"lr": [], "loss": []} 164 | self.best_loss = None 165 | 166 | # Move the models to the proper device 167 | self.model.to(self.device) 168 | 169 | # Check if the optimizer is already attached to a scheduler 170 | self._check_for_scheduler() 171 | 172 | # Set the starting learning rate 173 | if start_lr: 174 | self._set_learning_rate(start_lr) 175 | 176 | # Initialize the proper learning rate policy 177 | if step_mode.lower() == "exp": 178 | lr_schedule = ExponentialLR(self.optimizer, end_lr, num_iter) 179 | elif step_mode.lower() == "linear": 180 | lr_schedule = LinearLR(self.optimizer, end_lr, num_iter) 181 | else: 182 | raise ValueError("expected one of (exp, linear), got {}".format(step_mode)) 183 | 184 | if smooth_f < 0 or smooth_f >= 1: 185 | raise ValueError("smooth_f is outside the range [0, 1[") 186 | 187 | # Create an iterator to get data batch by batch 188 | iter_wrapper = DataLoaderIterWrapper(train_loader) 189 | for iteration in tqdm(range(num_iter)): 190 | # Train on batch and retrieve loss 191 | loss = self._train_batch(iter_wrapper, accumulation_steps) 192 | if val_loader: 193 | loss = self._validate(val_loader) 194 | 195 | # Update the learning rate 196 | lr_schedule.step() 197 | self.history["lr"].append(lr_schedule.get_lr()[0]) 198 | 199 | # Track the best loss and smooth it if smooth_f is specified 200 | if iteration == 0: 201 | self.best_loss = loss 202 | else: 203 | if smooth_f > 0: 204 | loss = smooth_f * loss + (1 - smooth_f) * self.history["loss"][-1] 205 | if loss < self.best_loss: 206 | self.best_loss = loss 207 | 208 | # Check if the loss has diverged; if it has, stop the test 209 | self.history["loss"].append(loss) 210 | if loss > diverge_th * self.best_loss: 211 | print("Stopping early, the loss has diverged") 212 | break 213 | 214 | print("Learning rate search finished. 
See the graph with {finder_name}.plot()") 215 | 216 | def _set_learning_rate(self, new_lrs): 217 | if not isinstance(new_lrs, list): 218 | new_lrs = [new_lrs] * len(self.optimizer.param_groups) 219 | if len(new_lrs) != len(self.optimizer.param_groups): 220 | raise ValueError( 221 | "Length of `new_lrs` is not equal to the number of parameter groups " 222 | + "in the given optimizer" 223 | ) 224 | 225 | for param_group, new_lr in zip(self.optimizer.param_groups, new_lrs): 226 | param_group["lr"] = new_lr 227 | 228 | def _check_for_scheduler(self): 229 | for param_group in self.optimizer.param_groups: 230 | if "initial_lr" in param_group: 231 | raise RuntimeError("Optimizer already has a scheduler attached to it") 232 | 233 | def _train_batch(self, iter_wrapper, accumulation_steps): 234 | self.model.train() 235 | total_loss = None # for late initialization 236 | 237 | self.optimizer.zero_grad() 238 | for i in range(accumulation_steps): 239 | # inputs, labels = iter_wrapper.get_batch() 240 | # inputs, labels = self._move_to_device(inputs, labels) 241 | inputs = iter_wrapper.get_batch() 242 | inputs = AtomsBatch.from_dict(inputs, device='cuda') 243 | 244 | # Forward pass 245 | # outputs: MolOut = self.models(inputs) 246 | # loss = self.criterion(outputs.y_pred, outputs.y_true) 247 | y_pred = self.model(inputs) 248 | loss = self.criterion(y_pred, inputs) 249 | 250 | # Loss should be averaged in each step 251 | loss /= accumulation_steps 252 | 253 | # Backward pass 254 | if IS_AMP_AVAILABLE and hasattr(self.optimizer, "_amp_stash"): 255 | # For minor performance optimization, see also: 256 | # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations 257 | delay_unscale = ((i + 1) % accumulation_steps) != 0 258 | 259 | with amp.scale_loss( 260 | loss, self.optimizer, delay_unscale=delay_unscale 261 | ) as scaled_loss: 262 | scaled_loss.backward() 263 | else: 264 | loss.backward() 265 | 266 | if total_loss is None: 267 | total_loss = loss 268 | else: 269 | total_loss += loss 270 | 271 | self.optimizer.step() 272 | 273 | return total_loss.item() 274 | 275 | def _move_to_device(self, inputs, labels): 276 | def move(obj, device): 277 | if isinstance(obj, tuple): 278 | return tuple(move(o, device) for o in obj) 279 | elif torch.is_tensor(obj): 280 | return obj.to(device) 281 | else: 282 | return obj 283 | 284 | inputs = move(inputs, self.device) 285 | labels = move(labels, self.device) 286 | return inputs, labels 287 | 288 | def _validate(self, dataloader): 289 | # Set models to evaluation mode and disable gradient computation 290 | running_loss = 0 291 | total_pairs = 0 292 | self.model.eval() 293 | with torch.no_grad(): 294 | for inputs in dataloader: 295 | # Move data to the correct device 296 | # inputs, labels = self._move_to_device(inputs, labels) 297 | inputs = AtomsBatch.from_dict(inputs, device='cuda') 298 | 299 | # Forward pass and loss computation 300 | # outputs: MolOut = self.models(inputs) 301 | # loss = self.criterion(outputs.y_pred, outputs.y_true) 302 | y_pred = self.model(inputs) 303 | loss = self.criterion(y_pred, inputs) 304 | running_loss += loss.item() * len(y_pred) 305 | total_pairs += len(y_pred) 306 | 307 | return running_loss / total_pairs 308 | 309 | def plot(self, skip_start=10, skip_end=5, log_lr=True, show_lr=None): 310 | """Plots the learning rate range test. 311 | 312 | Arguments: 313 | skip_start (int, optional): number of batches to trim from the start. 314 | Default: 10. 
315 | skip_end (int, optional): number of batches to trim from the start. 316 | Default: 5. 317 | log_lr (bool, optional): True to plot the learning rate in a logarithmic 318 | scale; otherwise, plotted in a linear scale. Default: True. 319 | show_lr (float, optional): is set, will add vertical line to visualize 320 | specified learning rate; Default: None. 321 | """ 322 | 323 | if skip_start < 0: 324 | raise ValueError("skip_start cannot be negative") 325 | if skip_end < 0: 326 | raise ValueError("skip_end cannot be negative") 327 | if show_lr is not None and not isinstance(show_lr, float): 328 | raise ValueError("show_lr must be float") 329 | 330 | # Get the data to plot from the history dictionary. Also, handle skip_end=0 331 | # properly so the behaviour is the expected 332 | lrs = self.history["lr"] 333 | losses = self.history["loss"] 334 | if skip_end == 0: 335 | lrs = lrs[skip_start:] 336 | losses = losses[skip_start:] 337 | else: 338 | lrs = lrs[skip_start:-skip_end] 339 | losses = losses[skip_start:-skip_end] 340 | 341 | # Plot loss as a function of the learning rate 342 | plt.plot(lrs, losses) 343 | if log_lr: 344 | plt.xscale("log") 345 | plt.xlabel("Learning rate") 346 | plt.ylabel("Loss") 347 | 348 | if show_lr is not None: 349 | plt.axvline(x=show_lr, color="red") 350 | plt.show() 351 | 352 | 353 | class LinearLR(_LRScheduler): 354 | """Linearly increases the learning rate between two boundaries over a number of 355 | iterations. 356 | 357 | Arguments: 358 | optimizer (models.optim.Optimizer): wrapped optimizer. 359 | end_lr (float): the final learning rate. 360 | num_iter (int): the number of iterations over which the test occurs. 361 | last_epoch (int, optional): the index of last epoch. Default: -1. 362 | """ 363 | 364 | def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1): 365 | self.end_lr = end_lr 366 | self.num_iter = num_iter 367 | super(LinearLR, self).__init__(optimizer, last_epoch) 368 | 369 | def get_lr(self): 370 | curr_iter = self.last_epoch + 1 371 | r = curr_iter / self.num_iter 372 | return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] 373 | 374 | 375 | class ExponentialLR(_LRScheduler): 376 | """Exponentially increases the learning rate between two boundaries over a number of 377 | iterations. 378 | 379 | Arguments: 380 | optimizer (models.optim.Optimizer): wrapped optimizer. 381 | end_lr (float): the final learning rate. 382 | num_iter (int): the number of iterations over which the test occurs. 383 | last_epoch (int, optional): the index of last epoch. Default: -1. 
384 | """ 385 | 386 | def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1): 387 | self.end_lr = end_lr 388 | self.num_iter = num_iter 389 | super(ExponentialLR, self).__init__(optimizer, last_epoch) 390 | 391 | def get_lr(self): 392 | curr_iter = self.last_epoch + 1 393 | r = curr_iter / self.num_iter 394 | return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] 395 | 396 | 397 | class StateCacher(object): 398 | def __init__(self, in_memory, cache_dir=None): 399 | self.in_memory = in_memory 400 | self.cache_dir = cache_dir 401 | 402 | if self.cache_dir is None: 403 | import tempfile 404 | 405 | self.cache_dir = tempfile.gettempdir() 406 | else: 407 | if not os.path.isdir(self.cache_dir): 408 | raise ValueError("Given `cache_dir` is not a valid directory.") 409 | 410 | self.cached = {} 411 | 412 | def store(self, key, state_dict): 413 | if self.in_memory: 414 | self.cached.update({key: copy.deepcopy(state_dict)}) 415 | else: 416 | fn = os.path.join(self.cache_dir, "state_{}_{}.pt".format(key, id(self))) 417 | self.cached.update({key: fn}) 418 | torch.save(state_dict, fn) 419 | 420 | def retrieve(self, key): 421 | if key not in self.cached: 422 | raise KeyError("Target {} was not cached.".format(key)) 423 | 424 | if self.in_memory: 425 | return self.cached.get(key) 426 | else: 427 | fn = self.cached.get(key) 428 | if not os.path.exists(fn): 429 | raise RuntimeError( 430 | "Failed to load state in {}. File doesn't exist anymore.".format(fn) 431 | ) 432 | state_dict = torch.load(fn, map_location=lambda storage, location: storage) 433 | return state_dict 434 | 435 | def __del__(self): 436 | """Check whether there are unused cached files existing in `cache_dir` before 437 | this instance being destroyed.""" 438 | 439 | if self.in_memory: 440 | return 441 | 442 | for k in self.cached: 443 | if os.path.exists(self.cached[k]): 444 | os.remove(self.cached[k]) 445 | 446 | 447 | class DataLoaderIterWrapper(object): 448 | """A wrapper for iterating `models.utils.data.DataLoader` with the ability to reset 449 | itself while `StopIteration` is raised.""" 450 | 451 | def __init__(self, data_loader, auto_reset=True): 452 | self.data_loader = data_loader 453 | self.auto_reset = auto_reset 454 | self._iterator = iter(data_loader) 455 | 456 | def __next__(self): 457 | # Get a new set of inputs and labels 458 | try: 459 | inputs = next(self._iterator) 460 | except StopIteration: 461 | if not self.auto_reset: 462 | raise 463 | self._iterator = iter(self.data_loader) 464 | inputs = next(self._iterator) 465 | 466 | return inputs 467 | 468 | # make it compatible with python 2 469 | next = __next__ 470 | 471 | def get_batch(self): 472 | return next(self) 473 | -------------------------------------------------------------------------------- /src/mylib/torch/tools/swa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/torch/tools/swa/__init__.py -------------------------------------------------------------------------------- /src/mylib/torch/tools/swa/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def update_swa(net1, net2, alpha=1): 5 | with torch.no_grad(): # TODO need this? 
6 | for param1, param2 in zip(net1.parameters(), net2.parameters()): 7 | param1.data *= (1.0 - alpha) 8 | param1.data += param2.data * alpha 9 | 10 | 11 | def _check_bn(module, flag): 12 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 13 | flag[0] = True 14 | 15 | 16 | def check_bn(model): 17 | flag = [False] 18 | model.apply(lambda module: _check_bn(module, flag)) 19 | return flag[0] 20 | 21 | 22 | def reset_bn(module): 23 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 24 | module.running_mean = torch.zeros_like(module.running_mean) 25 | module.running_var = torch.ones_like(module.running_var) 26 | 27 | 28 | def _get_momenta(module, momenta): 29 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 30 | momenta[module] = module.momentum 31 | 32 | 33 | def _set_momenta(module, momenta): 34 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 35 | module.momentum = momenta[module] 36 | 37 | 38 | def bn_update(loader, model): 39 | """ 40 | BatchNorm buffers update (if any). 41 | Performs 1 epochs to estimate buffers average using train dataset. 42 | 43 | :param loader: train dataset loader for buffers average estimation. 44 | :param model: models being update 45 | :return: None 46 | """ 47 | if not check_bn(model): 48 | return 49 | model.train() 50 | momenta = {} 51 | model.apply(reset_bn) 52 | model.apply(lambda module: _get_momenta(module, momenta)) 53 | n = 0 54 | for input, _ in loader: 55 | input = input.cuda() 56 | input_var = torch.autograd.Variable(input) 57 | b = input_var.data.size(0) 58 | 59 | momentum = b / (n + b) 60 | for module in momenta.keys(): 61 | module.momentum = momentum 62 | 63 | model(input_var) 64 | n += b 65 | 66 | model.apply(lambda module: _set_momenta(module, momenta)) 67 | -------------------------------------------------------------------------------- /src/mylib/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akirasosa/pytorch-dimenet/491154d7dea8770ffd853de895f6a2b5176787f2/src/mylib/utils/__init__.py -------------------------------------------------------------------------------- /src/mylib/utils/plt.py: -------------------------------------------------------------------------------- 1 | def rotate_ticks_label(plt, rotation=45): 2 | plt.xticks( 3 | rotation=rotation, 4 | horizontalalignment='right', 5 | fontweight='light', 6 | fontsize='x-large' 7 | ) 8 | -------------------------------------------------------------------------------- /src/mylib/utils/text.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | html_tags = ['
<p>', '</p>', '<table>', '</table>', '<tr>', '</tr>', '<ul>', '<ol>', '<dl>', '</ul>', '</ol>',
4 | '</dl>', '<li>', '<dd>', '<dt>', '</li>', '</dd>', '</dt>', '<h1>', '</h1>',
5 | '<br>', '<br/>', '<strong>', '</strong>', '<span>', '</span>', '<blockquote>', '</blockquote>',
6 | '<pre>', '</pre>', '<div>', '</div>', '<h2>', '</h2>', '<h3>', '</h3>', '<h4>', '</h4>',
7 | '<h5>', '</h5>', '<h6>', '</h6>', '<small>', '</small>', '<sub>', '</sub>', '<sup>', '</sup>']
8 |
9 | empty_expressions = ['&lt;', '&gt;', '&amp;', '&nbsp;',
10 | '&emsp;', '&ndash;', '&mdash;', '&ensp;',
11 | '&quot;', '&#39;']
12 |
13 | other = ['span', 'style', 'href', 'input']
14 |
15 |
16 | def pre_preprocess(x):
17 | return str(x).lower()
18 |
19 |
20 | def rm_spaces(text):
21 | spaces = ['\u200b', '\u200e', '\u202a', '\u2009', '\u2028', '\u202c', '\ufeff', '\uf0d8', '\u2061', '\u3000',
22 | '\x10', '\x7f', '\x9d', '\xad',
23 | '\x97', '\x9c', '\x8b', '\x81', '\x80', '\x8c', '\x85', '\x92', '\x88', '\x8d', '\x80', '\x8e', '\x9a',
24 | '\x94', '\xa0',
25 | '\x8f', '\x82', '\x8a', '\x93', '\x90', '\x83', '\x96', '\x9b', '\x9e', '\x99', '\x87', '\x84', '\x9f',
26 | ]
27 | for space in spaces:
28 | text = text.replace(space, ' ')
29 | return text
30 |
31 |
32 | def remove_urls(x):
33 | x = re.sub(r'(https?://[a-zA-Z0-9.-]*)', r'', x)
34 |
35 | # original
36 | x = re.sub(r'(quote=\w+\s?\w+;?\w+)', r'', x)
37 | return x
38 |
39 |
40 | def clean_html_tags(x, stop_words=[]):
41 | for r in html_tags:
42 | x = x.replace(r, '')
43 | for r in empty_expressions:
44 | x = x.replace(r, ' ')
45 | for r in stop_words:
46 | x = x.replace(r, '')
47 | return x
48 |
49 |
50 | def replace_num(text):
51 | text = re.sub('[0-9]{5,}', '', text)
52 | text = re.sub('[0-9]{4}', '', text)
53 | text = re.sub('[0-9]{3}', '', text)
54 | text = re.sub('[0-9]{2}', '', text)
55 | return text
56 |
57 |
58 | def get_url_num(x):
59 | pattern = r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+"
60 | urls = re.findall(pattern, x)
61 | return len(urls)
62 |
63 |
64 | def clean_puncts(x):
65 | puncts = [',', '.', '"', ':', ')', '(', '-', '!', '?', '|', ';', "'", '$', '&', '/', '[', ']', '>', '%', '=', '#',
66 | '*',
67 | '+', '\\', '•', '~', '@', '£',
68 | '·', '_', '{', '}', '©', '^', '®', '`', '<', '→', '°', '€', '™', '›', '♥', '←', '×', '§', '″', '′', 'Â',
69 | '█',
70 | '½', 'à', '…', '\n', '\xa0', '\t',
71 | '“', '★', '”', '–', '●', 'â', '►', '−', '¢', '²', '¬', '░', '¶', '↑', '±', '¿', '▾', '═', '¦', '║', '―',
72 | '¥',
73 | '▓', '—', '‹', '─', '\u3000', '\u202f',
74 | '▒', ':', '¼', '⊕', '▼', '▪', '†', '■', '’', '▀', '¨', '▄', '♫', '☆', 'é', '¯', '♦', '¤', '▲', 'è', '¸',
75 | '¾',
76 | 'Ã', '⋅', '‘', '∞', '«',
77 | '∙', ')', '↓', '、', '│', '(', '»', ',', '♪', '╩', '╚', '³', '・', '╦', '╣', '╔', '╗', '▬', '❤', 'ï', 'Ø',
78 | '¹',
79 | '≤', '‡', '√', ]
80 | for punct in puncts:
81 | x = x.replace(punct, f' {punct} ')
82 | return x
83 |
84 |
85 | # zenkaku = '0,1,2,3,4,5,6,7,8,9,(,),*,「,」,[,],【,】,<,>,?,・,#,@,$,%,='.split(',')
86 | # hankaku = '0,1,2,3,4,5,6,7,8,9,q,a,z,w,s,x,c,d,e,r,f,v,b,g,t,y,h,n,m,j,u,i,k,l,o,p'.split(',')
87 |
88 | def clean_text_jp(x):
89 | x = x.replace('。', '')
90 | x = x.replace('、', '')
91 | x = x.replace('\n', '') # remove line breaks
92 | x = x.replace('\t', '') # remove tabs
93 | x = x.replace('\r', '')
94 | x = x.replace('・', ' ')
95 | x = re.sub(re.compile(r'[!-\/:-@[-`{-~]'), ' ', x)
96 | x = re.sub(r'\[math\]', ' LaTex math ', x) # replace LaTeX math markers
97 | x = re.sub(r'\[\/math\]', ' LaTex math ', x) # replace LaTeX math markers
98 | x = re.sub(r'\\', ' LaTex ', x) # replace LaTeX commands
99 | # for r in zenkaku+hankaku:
100 | # x = x.replace(str(r), '')
101 | x = re.sub(' +', ' ', x)
102 | return x
103 |
104 |
105 | def preprocess(data):
106 | data = data.progress_apply(lambda x: pre_preprocess(x))
107 | data = data.progress_apply(lambda x: rm_spaces(x))
108 | data = data.progress_apply(lambda x: remove_urls(x))
109 | data = data.progress_apply(lambda x: clean_puncts(x))
110 | data = data.progress_apply(lambda x: replace_num(x))
111 | data = data.progress_apply(lambda x: clean_html_tags(x, stop_words=other))
112 | data = data.progress_apply(lambda x: clean_text_jp(x))
113 | return data
114 |
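Note that the `preprocess` pipeline above relies on `Series.progress_apply`, which is only available after tqdm's pandas integration has been registered. A minimal usage sketch (the `text` column name and example strings are hypothetical, not from the repository):

```
import pandas as pd
from tqdm.auto import tqdm

from mylib.utils.text import preprocess

tqdm.pandas()  # registers progress_apply on pandas objects

df = pd.DataFrame({'text': ['<p>Hello &amp; welcome!</p>',
                            'Visit https://example.com now']})
df['text_clean'] = preprocess(df['text'])
print(df['text_clean'].tolist())
```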
--------------------------------------------------------------------------------
/src/params/001.yaml:
--------------------------------------------------------------------------------
1 | note: 'Target is mu.'
2 |
3 | module_params:
4 | target: mu
5 | optim: radam
6 | lr: 1e-4
7 | weight_decay: 1e-4
8 | batch_size: 32
9 | ema_decay: 0.9999
10 | ema_eval_freq: 1
11 | fold: 0
12 | n_splits: 4
13 | seed: 1
14 |
15 | trainer_params:
16 | epochs: 800
17 | gpus: [0]
18 | num_tpu_cores: null
19 | use_16bit: false
20 |
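The `module_params` block is unpacked into `ModuleParams` and `trainer_params` drives `pl.Trainer` in `run_train.py` below. To train another target, the simplest route is probably to copy this file and change `target` to another column stored in the QM9 database (e.g. `U0`). A hypothetical `params/002.yaml` might differ only in:

```
note: 'Target is U0.'

module_params:
  target: U0
```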
--------------------------------------------------------------------------------
/src/run_create_db.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm.auto import tqdm
6 |
7 | from dimenet.const import DATA_QM9_DIR, STRUCTURES_CSV, ATOM_MAP, QM9_DB
8 |
9 |
10 | def processQM9_file(filename):
11 | path = DATA_QM9_DIR / filename
12 |
13 | stats = pd.read_csv(path, sep=' |\t', engine='python', skiprows=1, nrows=1, header=None)
14 | stats = stats.loc[:, 2:]
15 | stats.columns = ['rc_A', 'rc_B', 'rc_C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G',
16 | 'Cv']
17 | stats = stats.astype(np.float32)
18 |
19 | mm = pd.read_csv(path, sep='\t', engine='python', skiprows=2, skipfooter=3, names=range(5))[4]
20 | if mm.dtype == 'O':
21 | mm = mm.str.replace('*^', 'e', regex=False).astype(float)
22 |
23 | return {
24 | **stats.iloc[0].to_dict(),
25 | 'mulliken': mm.values.astype(np.float32),
26 | }
27 |
28 |
29 | def encode_atom(atom: str) -> int:
30 | return ATOM_MAP[atom]
31 |
32 |
33 | def create_db(out_path: Path):
34 | df = pd.read_csv(STRUCTURES_CSV)
35 | mol_grouped = df.groupby('molecule_name')
36 |
37 | def process(name):
38 | mol = mol_grouped.get_group(name)
39 | R = mol[['x', 'y', 'z']].values
40 | Z = mol['atom'].apply(encode_atom).values
41 |
42 | qm9_orig = processQM9_file(f'{name}.xyz')
43 |
44 | return {
45 | 'name': name,
46 | 'R': R.reshape(-1).astype(np.float32),
47 | 'Z': Z.astype(np.int32),
48 | **qm9_orig,
49 | }
50 |
51 | results = [
52 | process(name)
53 | for name in tqdm(mol_grouped.groups.keys())
54 | ]
55 | pd.DataFrame(results).to_parquet(str(out_path))
56 |
57 |
58 | if __name__ == '__main__':
59 | print(f'Create {QM9_DB}')
60 | create_db(QM9_DB)
61 |
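The resulting parquet stores each molecule's positions `R` as a flattened float32 array and the atoms `Z` as integer codes from `ATOM_MAP`, so `R` has to be reshaped back to `(n_atoms, 3)` when read. A small sketch of inspecting the database, assuming the same constants from `dimenet.const`:

```
import pandas as pd

from dimenet.const import QM9_DB

df = pd.read_parquet(QM9_DB)
row = df.iloc[0]
R = row['R'].reshape(-1, 3)  # (n_atoms, 3) Cartesian coordinates
Z = row['Z']                 # integer atom codes from ATOM_MAP
print(row['name'], R.shape, Z.shape, row['mu'])
```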
--------------------------------------------------------------------------------
/src/run_train.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import dataclasses
3 | from functools import cached_property
4 | from logging import getLogger, FileHandler
5 | from os import cpu_count
6 |
7 | from pathlib import Path
8 | from time import time
9 | from typing import Callable, List, Dict, Optional
10 |
11 | import pandas as pd
12 | import pytorch_lightning as pl
13 | import torch
14 | import torch.nn as nn
15 | from omegaconf import DictConfig
16 | from pytorch_lightning import seed_everything
17 | from pytorch_lightning.callbacks import ModelCheckpoint
18 | from pytorch_lightning.loggers import TensorBoardLogger
19 | from pytorch_ranger import Ranger
20 | from sklearn.model_selection import KFold
21 | from torch.optim.optimizer import Optimizer
22 | from torch.utils.data import Dataset
23 | from torch_optimizer import RAdam
24 |
25 | from dimenet.loader import get_loader, AtomsBatch
26 | from dimenet.logging import configure_logging
27 | from dimenet.loss import mae_loss
28 | from dimenet.modules.dimenet import DimeNet
29 | from dimenet.params import ModuleParams, Params
30 | from mylib.torch.data.dataset import PandasDataset
31 | from mylib.torch.tools.ema.utils import update_ema
32 |
33 |
34 | @dataclasses.dataclass(frozen=True)
35 | class Metrics:
36 | lr: float
37 | loss: float
38 | lmae: float
39 |
40 |
41 | class PLModule(pl.LightningModule):
42 | def __init__(self, hparams: DictConfig):
43 | super().__init__()
44 | self.hparams = hparams
45 | self.model = DimeNet(num_targets=1)
46 |
47 | self.train_dataset: Optional[Dataset] = None
48 | self.val_dataset: Optional[Dataset] = None
49 | self.best: float = float('inf')
50 | self.ema_model: Optional[nn.Module] = None
51 |
52 | def setup(self, stage: str):
53 | df = pd.read_parquet(self.hp.db_path)
54 |
55 | folds = KFold(
56 | n_splits=self.hp.n_splits,
57 | random_state=self.hp.seed,
58 | shuffle=True,
59 | )
60 | train_idx, val_idx = list(folds.split(df))[self.hp.fold]
61 |
62 | self.train_dataset = PandasDataset(df.iloc[train_idx])
63 | self.val_dataset = PandasDataset(df.iloc[val_idx])
64 |
65 | def on_train_start(self) -> None:
66 | super(PLModule, self).on_train_start()
67 | # Init ema model.
68 | if self.hp.ema_decay is not None:
69 | self.ema_model = copy.deepcopy(self.model)
70 | for p in self.ema_model.parameters():
71 | p.requires_grad_(False)
72 |
73 | def optimizer_step(
74 | self,
75 | epoch: int,
76 | batch_idx: int,
77 | optimizer: Optimizer,
78 | optimizer_idx: int,
79 | second_order_closure: Optional[Callable] = None,
80 | on_tpu: bool = False,
81 | using_native_amp: bool = False,
82 | using_lbfgs: bool = False,
83 | ) -> None:
84 | super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, second_order_closure)
85 | if self.ema_model is not None:
86 | update_ema(self.ema_model, self.model, self.hp.ema_decay)
87 |
88 | def forward(self, x):
89 | return self.model.forward(x)
90 |
91 | def training_step(self, batch, batch_idx):
92 | result = self.step(batch, prefix='train')
93 | return {
94 | 'loss': result['train_loss'],
95 | **result,
96 | }
97 |
98 | def validation_step(self, batch, batch_idx):
99 | result = self.step(batch, prefix='val')
100 |
101 | if self.eval_ema:
102 | result_ema = self.step(batch, prefix='ema', model=self.ema_model)
103 | else:
104 | result_ema = {}
105 |
106 | return {
107 | **result,
108 | **result_ema,
109 | }
110 |
111 | def step(self, batch, prefix: str, model=None) -> Dict:
112 | batch = AtomsBatch(**batch)
113 | y_true = batch[self.hp.target].unsqueeze(-1)
114 |
115 | if model is None:
116 | y_pred = self.forward(batch)
117 | else:
118 | y_pred = model(batch)
119 |
120 | assert y_pred.shape == y_true.shape, f'{y_pred.shape}, {y_true.shape}'
121 |
122 | mae = mae_loss(y_pred, y_true)
123 | lmae = torch.log(mae)
124 | size = len(y_true)
125 |
126 | return {
127 | f'{prefix}_loss': lmae,
128 | f'{prefix}_mae': mae,
129 | f'{prefix}_size': size,
130 | }
131 |
132 | def training_epoch_end(self, outputs):
133 | metrics = self.__collect_metrics(outputs, 'train')
134 | self.__log(metrics, 'train')
135 |
136 | return {}
137 |
138 | def validation_epoch_end(self, outputs):
139 | metrics = self.__collect_metrics(outputs, 'val')
140 | self.__log(metrics, 'val')
141 |
142 | if self.eval_ema:
143 | metrics_ema = self.__collect_metrics(outputs, 'ema')
144 | self.__log(metrics_ema, 'ema')
145 | else:
146 | metrics_ema = None
147 |
148 | if metrics.loss < self.best:
149 | self.best = metrics.loss
150 |
151 | return {
152 | 'progress_bar': {
153 | 'val_loss': metrics.loss,
154 | 'best': self.best,
155 | },
156 | 'val_loss': metrics.loss,
157 | 'ema_loss': metrics_ema.loss if metrics_ema is not None else None,
158 | }
159 |
160 | def __collect_metrics(self, outputs: List[Dict], prefix: str) -> Metrics:
161 | loss, mae = 0, 0
162 | total_size = 0
163 |
164 | for o in outputs:
165 | size = o[f'{prefix}_size']
166 | total_size += size
167 | loss += o[f'{prefix}_loss'] * size
168 | mae += o[f'{prefix}_mae'] * size
169 | loss = loss / total_size
170 | mae = mae / total_size
171 |
172 | # noinspection PyTypeChecker
173 | return Metrics(
174 | lr=self.trainer.optimizers[0].param_groups[0]['lr'],
175 | loss=loss,
176 | lmae=torch.log(mae),
177 | )
178 |
179 | def __log(self, metrics: Metrics, prefix: str):
180 | if self.global_step > 0:
181 | self.logger.experiment.add_scalar('lr', metrics.lr, self.current_epoch)
182 | for k, v in dataclasses.asdict(metrics).items():
183 | if k == 'lr':
184 | continue
185 | self.logger.experiment.add_scalars(k, {
186 | prefix: v,
187 | }, self.current_epoch)
188 |
189 | def train_dataloader(self):
190 | return get_loader(
191 | self.train_dataset,
192 | batch_size=self.hp.batch_size,
193 | shuffle=True,
194 | num_workers=cpu_count(),
195 | pin_memory=True,
196 | cutoff=5.,
197 | )
198 |
199 | def val_dataloader(self):
200 | return get_loader(
201 | self.val_dataset,
202 | batch_size=self.hp.batch_size,
203 | shuffle=False,
204 | num_workers=cpu_count(),
205 | pin_memory=True,
206 | cutoff=5.,
207 | )
208 |
209 | def configure_optimizers(self):
210 | if self.hp.optim == 'ranger':
211 | optim = Ranger
212 | elif self.hp.optim == 'radam':
213 | optim = RAdam
214 | else:
215 | raise Exception(f'Not supported optim: {self.hp.optim}')
216 | opt = optim(
217 | self.model.parameters(),
218 | lr=self.hp.lr,
219 | weight_decay=self.hp.weight_decay,
220 | )
221 | return [opt]
222 |
223 | @property
224 | def eval_ema(self) -> bool:
225 | if self.ema_model is None:
226 | return False
227 | f = self.hp.ema_eval_freq
228 | return self.current_epoch % f == f - 1
229 |
230 | @cached_property
231 | def hp(self) -> ModuleParams:
232 | return ModuleParams(**self.hparams)
233 |
234 |
235 | def train(params: Params):
236 | seed_everything(params.m.seed)
237 |
238 | tb_logger = TensorBoardLogger(
239 | params.t.save_dir,
240 | name=f'dimenet_{params.m.target}',
241 | version=str(int(time())),
242 | )
243 |
244 | log_dir = Path(tb_logger.log_dir)
245 | log_dir.mkdir(parents=True, exist_ok=True)
246 |
247 | logger = getLogger('lightning')
248 | logger.addHandler(FileHandler(log_dir / 'train.log'))
249 | logger.info(params.pretty())
250 |
251 | trainer = pl.Trainer(
252 | max_epochs=params.t.epochs,
253 | gpus=params.t.gpus,
254 | tpu_cores=params.t.num_tpu_cores,
255 | logger=tb_logger,
256 | precision=16 if params.t.use_16bit else 32,
257 | amp_level='O1' if params.t.use_16bit else None,
258 | resume_from_checkpoint=params.t.resume_from_checkpoint,
259 | weights_save_path=str(params.t.save_dir),
260 | checkpoint_callback=ModelCheckpoint(
261 | monitor='ema_loss',
262 | save_last=True,
263 | verbose=True,
264 | ),
265 | )
266 | net = PLModule(params.m.dict_config())
267 |
268 | trainer.fit(net)
269 |
270 |
271 | if __name__ == '__main__':
272 | configure_logging()
273 | params = Params.load()
274 | train(params)
275 |
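`update_ema` (imported from `mylib.torch.tools.ema.utils`, not shown in this section) is called after every optimizer step with `ema_decay` from the params file, and the EMA copy is evaluated every `ema_eval_freq` epochs as `ema_loss`, which the checkpoint callback monitors. A minimal sketch of the update rule such a helper typically implements — the standard exponential moving average of parameters; the body below is an assumption, not the repository's code:

```
import torch


@torch.no_grad()
def update_ema_sketch(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float) -> None:
    # Parameter-wise: ema_p <- decay * ema_p + (1 - decay) * p
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
```

With `ema_decay: 0.9999`, the EMA weights track a long running average of the trained weights, and checkpointing monitors the EMA copy's `ema_loss` rather than the raw `val_loss`.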
--------------------------------------------------------------------------------