├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── LICENSE ├── README.md ├── build.py ├── chytorch ├── nn │ ├── __init__.py │ ├── lora │ │ ├── __init__.py │ │ ├── embedding.py │ │ └── linear.py │ ├── losses.py │ ├── molecule │ │ ├── __init__.py │ │ ├── _embedding.py │ │ └── encoder.py │ ├── reaction.py │ ├── slicer.py │ ├── transformer │ │ ├── __init__.py │ │ ├── attention │ │ │ ├── __init__.py │ │ │ └── graphormer.py │ │ └── encoder.py │ └── voting │ │ ├── __init__.py │ │ ├── _kfold.py │ │ ├── binary.py │ │ ├── classifier.py │ │ └── regressor.py ├── utils │ ├── __init__.py │ └── data │ │ ├── __init__.py │ │ ├── _utils.py │ │ ├── lmdb.py │ │ ├── molecule │ │ ├── __init__.py │ │ ├── _unpack.pyx │ │ ├── conformer.py │ │ ├── dummy.py │ │ ├── encoder.py │ │ └── rdkit.py │ │ ├── product.py │ │ ├── reaction │ │ ├── __init__.py │ │ └── encoder.py │ │ ├── sampler.py │ │ ├── smiles.py │ │ └── unpack.py └── zoo │ └── README.md ├── examples └── reference_start.py └── pyproject.toml /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: Build Python packages 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | jobs: 9 | binary: 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | os: [windows-latest, macos-latest, ubuntu-20.04] 14 | python-version: ["3.8", "3.9", "3.10", "3.11"] 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip poetry twine 24 | - name: Build wheel 25 | run: | 26 | poetry build -f wheel 27 | - name: Publish package 28 | run: | 29 | twine upload -u __token__ -p ${{ secrets.PYPI_API_TOKEN }} --non-interactive --skip-existing dist/* 30 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .dvc/config.local 2 | .dvc/cache 3 | .dvc/plots 4 | .dvc/tmp 5 | 6 | __pycache__ 7 | *.py[cod] 8 | *.pt 9 | *.so 10 | *.dll 11 | *.dynlib 12 | *.c 13 | 14 | .idea 15 | venv 16 | zoo 17 | 18 | *.egg-info 19 | build 20 | dist 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Chytorch [kʌɪtɔːrtʃ] 2 | ==================== 3 | 4 | Library for modeling molecules and reactions in torch way. 
5 | 6 | Installation 7 | ------------ 8 | 9 | Use `pip install chytorch` to install the release version. 10 | 11 | Or `pip install .` in the source code directory to install the DEV version. 12 | 13 | Pretrained models 14 | ----------------- 15 | 16 | Chytorch main package doesn't include the models zoo. 17 | Each model has its own named package and can be installed separately. 18 | Installed models can be imported as `from chytorch.zoo.<model_name> import Model`. 19 | 20 | 21 | Usage 22 | ----- 23 | 24 | `chytorch.nn.MoleculeEncoder` - core graphormer layer for molecules encoding. 25 | API is a combination of `torch.nn.TransformerEncoderLayer` with `torch.nn.TransformerEncoder`. 26 | 27 | **Batch preparation:** 28 | 29 | `chytorch.utils.data.MoleculeDataset` - Map-like on-the-fly dataset generator for molecules. 30 | Supports `chython.MoleculeContainer` objects and PaCh structures. 31 | 32 | `chytorch.utils.data.collate_molecules` - collate function for `torch.utils.data.DataLoader`. 33 | 34 | Note: torch DataLoader automatically does proper collation since the 1.13 release. 35 | 36 | Example: 37 | 38 | from chytorch.utils.data import MoleculeDataset, SMILESDataset 39 | from torch.utils.data import DataLoader 40 | 41 | data = ['CCO', 'CC=O'] 42 | ds = MoleculeDataset(SMILESDataset(data, cache={})) 43 | dl = DataLoader(ds, batch_size=10) 44 | 45 | **Forward call:** 46 | 47 | Molecules are coded as tensors of: 48 | * atoms numbers shifted by 2 (e.g. hydrogen = 3). 49 | 0 - reserved for padding, 1 - reserved for CLS token, 2 - extra reservation. 50 | * neighbors count, including implicit hydrogens shifted by 2 (e.g. CO = CH3OH = [6, 4]). 51 | 0 - reserved for padding, 1 - extra reservation, 2 - no-neighbors, 3 - one neighbor. 52 | * topological distances' matrix shifted by 2 with upper limit. 53 | 0 - reserved for padding, 1 - reserved for not-connected graph components coding, 2 - self-loop, 3 - connected atoms. 
54 | 55 | from chytorch.nn import MoleculeEncoder 56 | 57 | encoder = MoleculeEncoder() 58 | for b in dl: 59 | encoder(b) 60 | 61 | **Combine molecules and labels:** 62 | 63 | `chytorch.utils.data.chained_collate` - helper for combining different data parts. Useful for tricky input. 64 | 65 | from torch import stack 66 | from torch.utils.data import DataLoader, TensorDataset 67 | from chytorch.utils.data import chained_collate, collate_molecules, MoleculeDataset 68 | 69 | dl = DataLoader(TensorDataset(MoleculeDataset(molecules_list), properties_tensor), 70 | collate_fn=chained_collate(collate_molecules, stack)) 71 | 72 | **Voting NN with single hidden layer:** 73 | 74 | `chytorch.nn.VotingClassifier`, `chytorch.nn.BinaryVotingClassifier` and `chytorch.nn.VotingRegressor` - speed-optimized multiple heads for ensemble predictions. 75 | 76 | **Helper Modules:** 77 | 78 | `chytorch.nn.Slicer` - does tensor slicing. Useful for transformer's CLS token extraction in `torch.nn.Sequential`. 79 | 80 | **Data Wrappers:** 81 | 82 | The `chytorch.utils.data` module provides different data wrappers for simplifying ML workflows. 83 | All wrappers have the `torch.utils.data.Dataset` interface. 84 | 85 | * `SizedList` - list wrapper with `size()` method. Useful with `torch.utils.data.TensorDataset`. 86 | * `SMILESDataset` - on-the-fly SMILES to `chython.MoleculeContainer` or `chython.ReactionContainer` parser. 87 | * `LMDBMapper` - LMDB KV storage to dataset mapper. 
88 | * `TensorUnpack`, `StructUnpack`, `PickleUnpack` - bytes to tensor/object unpackers 89 | 90 | 91 | Publications 92 | ------------ 93 | 94 | [1](https://doi.org/10.1021/acs.jcim.2c00344) Bidirectional Graphormer for Reactivity Understanding: Neural Network Trained to Reaction Atom-to-Atom Mapping Task 95 | -------------------------------------------------------------------------------- /build.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright 2023 Ramil Nugmanov 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the “Software”), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
# Build script: cythonizes the `_unpack` extension, builds it in place of a regular setup.py call and
# copies every produced artifact to the same path relative to the current working directory.
platform = get_platform()
# optimisation flag per known platform tag; any other platform gets no extra compile arguments
extra_compile_args = {'win-amd64': ['/O2'], 'linux-x86_64': ['-O3']}.get(platform, [])

unpack_extension = Extension(
    'chytorch.utils.data.molecule._unpack',
    ['chytorch/utils/data/molecule/_unpack.pyx'],
    extra_compile_args=extra_compile_args,
    include_dirs=[get_include()],  # numpy headers
)
extensions = [unpack_extension]

ext_modules = cythonize(extensions, language_level=3)
cmd = build_ext(Distribution({'ext_modules': ext_modules}))
cmd.ensure_finalized()
cmd.run()

# strip the temporary build directory prefix from each output path and copy the file there
for output in map(Path, cmd.get_outputs()):
    copyfile(output, output.relative_to(cmd.build_lib))
14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # 23 | from .losses import * 24 | from .molecule import * 25 | from .reaction import * 26 | from .slicer import * 27 | from .voting import * 28 | 29 | 30 | __all__ = ['MoleculeEncoder', 31 | 'ReactionEncoder', 32 | 'Slicer', 33 | 'VotingClassifier', 'VotingRegressor', 'BinaryVotingClassifier', 34 | 'MultiTaskLoss', 35 | 'CensoredLoss', 36 | 'MaskedNaNLoss', 37 | 'MSLELoss'] 38 | -------------------------------------------------------------------------------- /chytorch/nn/lora/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright 2023 Ramil Nugmanov 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the “Software”), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 
14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # 23 | from .embedding import * 24 | from .linear import * 25 | 26 | 27 | __all__ = ['Embedding', 'Linear'] 28 | -------------------------------------------------------------------------------- /chytorch/nn/lora/embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright 2023 Ramil Nugmanov 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the “Software”), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
class Embedding(tEmbedding):
    """
    LoRA wrapped Embedding layer.

    Works as a plain ``torch.nn.Embedding`` until `activate_lora` is called. After activation the main weight
    is frozen and the low-rank term ``lora_a[x] @ lora_b^T`` scaled by ``lora_alpha / lora_r`` is added to the
    output. `merge_lora` folds this term into the main weight and removes the LoRA parameters.
    """
    def __init__(self, *args, neg_inf_idx: Optional[int] = None, **kwargs):
        """
        :param neg_inf_idx: index of the embedding vector filled with -inf at construction. Skipped if None.

        See torch.nn.Embedding for other params
        """
        super().__init__(*args, **kwargs)
        self.neg_inf_idx = neg_inf_idx
        self.lora_r = 0  # zero rank marks inactive LoRA
        if neg_inf_idx is not None:
            with no_grad():
                self.weight[neg_inf_idx].fill_(-inf)

    def forward(self, x: Tensor) -> Tensor:
        """
        :param x: tensor of indices of any shape, including a scalar index.
        :return: embeddings of shape ``x.shape + (embedding_dim,)`` with the LoRA term added when active.
        """
        emb = super().forward(x)
        if not self.lora_r:
            return emb
        a = embedding(x, self.lora_a, self.padding_idx, self.max_norm,
                      self.norm_type, self.scale_grad_by_freq, self.sparse)
        # addmm needs 2d operands: collapse all leading dims (works for scalar index as well) and restore after
        return addmm(emb.reshape(-1, emb.size(-1)), a.reshape(-1, a.size(-1)), self.lora_b.transpose(0, 1),
                     alpha=self._lora_scaling).view(emb.shape)

    def activate_lora(self, lora_r: int = 0, lora_alpha: float = 1.):
        """
        Freeze main weights and create trainable LoRA factors.

        :param lora_r: LoRA factorization dimension
        :param lora_alpha: LoRA scaling factor
        """
        assert lora_r > 0, 'rank should be greater than zero'
        self.weight.requires_grad = False  # freeze main weights
        # factors should live on the device and in the dtype of the main weight, otherwise forward fails
        factory = {'device': self.weight.device, 'dtype': self.weight.dtype}
        self.lora_a = Parameter(init.zeros_(empty(self.num_embeddings, lora_r, **factory)))
        self.lora_b = Parameter(init.normal_(empty(self.embedding_dim, lora_r, **factory)))

        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self._lora_scaling = lora_alpha / lora_r

    def merge_lora(self):
        """
        Transform LoRA embedding to normal. Does nothing if LoRA is not active.
        """
        if not self.lora_r:
            return
        self.weight.data += (self.lora_a @ self.lora_b.transpose(0, 1)) * self._lora_scaling
        self.weight.requires_grad = True
        self.lora_r = 0
        del self.lora_a, self.lora_b, self.lora_alpha, self._lora_scaling

    def extra_repr(self) -> str:
        r = super().extra_repr()
        if self.neg_inf_idx is not None:
            r += f', neg_inf_idx={self.neg_inf_idx}'
        if self.lora_r:
            r += f', lora_r={self.lora_r}, lora_alpha={self.lora_alpha}'
        return r
class Linear(tLinear):
    """
    LoRA wrapped Linear layer.

    Works as a plain ``torch.nn.Linear`` until `activate_lora` is called. After activation the main weight is
    frozen and the low-rank term ``x @ lora_a^T @ lora_b^T`` scaled by ``lora_alpha / lora_r`` is added to the
    output. `merge_lora` folds this term into the main weight and removes the LoRA parameters.
    """
    def __init__(self, *args, **kwargs):
        """
        See torch.nn.Linear for params
        """
        super().__init__(*args, **kwargs)
        self.lora_r = 0  # zero rank marks inactive LoRA

    def forward(self, x: Tensor) -> Tensor:
        """
        :param x: input of shape ``(*, in_features)``, where ``*`` is any number of leading dims including none.
        :return: output of shape ``(*, out_features)`` with the LoRA term added when active.
        """
        out = super().forward(x)
        if not self.lora_r:
            return out
        if self.training and self.lora_dropout:
            # dropout affects only the LoRA branch; main projection above used the original input
            x = dropout(x, self.lora_dropout)
        a = x @ self.lora_a.transpose(0, 1)
        # addmm needs 2d operands: collapse all leading dims (works for 1d input as well) and restore after
        return addmm(out.reshape(-1, out.size(-1)), a.reshape(-1, a.size(-1)), self.lora_b.transpose(0, 1),
                     alpha=self._lora_scaling).view(out.shape)

    def activate_lora(self, lora_r: int = 0, lora_alpha: float = 1., lora_dropout: float = 0.):
        """
        Freeze main weights and create trainable LoRA factors.

        :param lora_r: LoRA factorization dimension
        :param lora_alpha: LoRA scaling factor
        :param lora_dropout: LoRA input dropout
        """
        assert lora_r > 0, 'rank should be greater than zero'
        self.weight.requires_grad = False  # freeze main weights
        # factors should live on the device and in the dtype of the main weight, otherwise forward fails
        factory = {'device': self.weight.device, 'dtype': self.weight.dtype}
        self.lora_a = Parameter(init.kaiming_uniform_(empty(lora_r, self.in_features, **factory), a=sqrt(5)))
        self.lora_b = Parameter(init.zeros_(empty(self.out_features, lora_r, **factory)))

        self.lora_r = lora_r
        self.lora_dropout = lora_dropout
        self.lora_alpha = lora_alpha
        self._lora_scaling = lora_alpha / lora_r

    def merge_lora(self):
        """
        Transform LoRA linear to normal. Does nothing if LoRA is not active.
        """
        if not self.lora_r:
            return
        self.weight.data += (self.lora_b @ self.lora_a) * self._lora_scaling
        self.weight.requires_grad = True
        self.lora_r = 0
        del self.lora_a, self.lora_b, self.lora_dropout, self.lora_alpha, self._lora_scaling

    def extra_repr(self) -> str:
        r = super().extra_repr()
        if self.lora_r:
            return r + f', lora_r={self.lora_r}, lora_alpha={self.lora_alpha}, lora_dropout={self.lora_dropout}'
        return r
class MultiTaskLoss(_Loss):
    """
    Auto-scalable loss for multitask training.

    Each task loss is divided by ``coefficient * exp(log)`` and regularized by ``log / 2``, where ``log`` is
    a trainable per-task parameter initialized with zeros and ``coefficient`` is 2 for regression
    and 1 for classification tasks.

    https://arxiv.org/abs/1705.07115
    """
    def __init__(self, loss_type: TensorType['loss_type', bool], *, reduction='mean'):
        """
        :param loss_type: vector equal to the number of tasks losses. True for regression and False for classification.
        :param reduction: `mean`, `sum` or any other value for returning non-reduced scaled losses.
        """
        super().__init__(reduction=reduction)
        self.log = Parameter(zeros_like(loss_type, dtype=float32))
        # True + 1 = 2 (regression), False + 1 = 1 (classification)
        self.register_buffer('coefficient', (loss_type + 1.).to(float32))

    def forward(self, x: TensorType['loss', float]):
        """
        :param x: 1d vector of losses or 2d matrix of batch X losses.
        """
        x = x / (self.coefficient * exp(self.log)) + self.log / 2

        if self.reduction == 'sum':
            return x.sum()
        elif self.reduction == 'mean':
            return x.mean()
        return x


class CensoredLoss(_Loss):
    """
    Loss wrapper masking input under target qualifier range.

    Masking strategy for different qualifiers (rows) and input-target relations (columns) described below:

       | I < T | I = T | I > T |
    ---|-------|-------|-------|
    -1 | Mask  |       |       |
     0 |       |       |       |
     1 |       |       | Mask  |

    Wrapped loss should not be configured with mean reduction.
    Wrapper does proper mean reduction by ignoring masked values.

    Note: wrapped loss should correctly treat zero-valued input and targets.
    """
    def __init__(self, loss: _Loss, reduction: str = 'mean', eps: float = 1e-5):
        """
        :param loss: wrapped loss with `none` or `sum` reduction.
        :param reduction: reduction of the wrapper: `mean`, `sum` or `none`.
        :param eps: value added to the number of non-masked elements in mean reduction to prevent zero division.
        """
        assert loss.reduction != 'mean', 'given loss should not be configured to `mean` reduction'
        super().__init__(reduction=reduction)
        self.loss = loss
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor, qualifier: Tensor) -> Tensor:
        """
        :param input: predicted values.
        :param target: target values.
        :param qualifier: negative for masking inputs below target, positive for masking inputs above target,
            zero for keeping the element unconditionally.
        """
        # True marks elements which contribute to the loss
        mask = ((qualifier >= 0) | (input >= target)) & ((qualifier <= 0) | (input <= target))
        # masked pairs are zeroed on both sides, thus wrapped loss sees them as (0, 0)
        loss = self.loss(input * mask, target * mask)
        if self.reduction == 'mean':
            if self.loss.reduction == 'none':
                loss = loss.sum()
            return loss / (mask.sum() + self.eps)
        elif self.reduction == 'sum':
            if self.loss.reduction == 'none':
                return loss.sum()
            return loss
        return loss  # reduction='none'


class MaskedNaNLoss(_Loss):
    """
    Loss wrapper masking nan targets and corresponding input values as zeros.
    Wrapped loss should not be configured with mean reduction.
    Wrapper does proper mean reduction by ignoring masked values.

    Note: wrapped loss should correctly treat zero-valued input and targets.
    """
    def __init__(self, loss: _Loss, reduction: str = 'mean', eps: float = 1e-5):
        """
        :param loss: wrapped loss with `none` or `sum` reduction.
        :param reduction: reduction of the wrapper: `mean`, `sum` or `none`.
        :param eps: value added to the number of non-masked elements in mean reduction to prevent zero division.
        """
        assert loss.reduction != 'mean', 'given loss should not be configured to `mean` reduction'
        super().__init__(reduction=reduction)
        self.loss = loss
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """
        :param input: predicted values.
        :param target: target values with nan in positions to ignore.
        """
        mask = ~target.isnan()  # True marks elements with known target
        loss = self.loss(input * mask, target.nan_to_num())
        if self.reduction == 'mean':
            if self.loss.reduction == 'none':
                loss = loss.sum()
            return loss / (mask.sum() + self.eps)
        elif self.reduction == 'sum':
            if self.loss.reduction == 'none':
                return loss.sum()
            return loss
        return loss  # reduction='none'


class MSLELoss(MSELoss):
    r"""
    Mean Squared Logarithmic Error:

    .. math:: \text{MSLE} = \frac{1}{N}\sum_i^N (\log_e(1 + y_i) - \log_e(1 + \hat{y_i}))^2

    Note: Works only for positive target values range. Implicitly clamps negative input.
    """
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # log1p keeps precision for values close to zero, where (x + 1).log() rounds to exact zero
        return super().forward(input.clamp(min=0).log1p(), target.log1p())
22 | # 23 | from .encoder import * 24 | 25 | 26 | __all__ = ['MoleculeEncoder'] 27 | -------------------------------------------------------------------------------- /chytorch/nn/molecule/_embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright 2024 Ramil Nugmanov 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the “Software”), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
class EmbeddingBag(Module):
    """
    Sum of atoms and neighbors embeddings with optional uniform noise added in training mode.
    """
    def __init__(self, max_neighbors: int = 14, d_model: int = 1024, perturbation: float = 0., max_tokens: int = 121):
        """
        :param max_neighbors: neighbors embedding table has max_neighbors + 3 rows.
        :param d_model: embedding vector size.
        :param perturbation: half-width of the uniform noise range. Zero disables noise.
        :param max_tokens: number of rows in the atoms embedding table. Should be at least 121.
        """
        assert perturbation >= 0, 'zero or positive perturbation expected'
        assert max_tokens >= 121, 'at least 121 tokens should be'
        super().__init__()
        self.max_neighbors = max_neighbors
        self.perturbation = perturbation
        self.max_tokens = max_tokens

        # index 0 is the padding index in both tables
        self.atoms_encoder = Embedding(max_tokens, d_model, 0)
        self.neighbors_encoder = Embedding(max_neighbors + 3, d_model, 0)

    def forward(self, atoms, neighbors):
        """
        :param atoms: atoms tokens indices.
        :param neighbors: neighbors counts indices of the same shape. Original note: cls token coded by 0.
        """
        atoms_emb = self.atoms_encoder(atoms)
        neighbors_emb = self.neighbors_encoder(neighbors)
        total = atoms_emb + neighbors_emb

        if not (self.perturbation and self.training):
            return total
        noise = empty_like(total).uniform_(-self.perturbation, self.perturbation)
        return total + noise
def _update(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
    """
    load_state_dict pre-hook which renames keys of old checkpoints in place.

    chytorch<1.37: centrality_encoder/spatial_encoder renamed to neighbors_encoder/distance_encoder.
    chytorch<1.61: atoms_encoder/neighbors_encoder moved into the `embedding` submodule.
    Only `state_dict` and `prefix` are used; other arguments are required by the hook signature.
    """
    if prefix + 'centrality_encoder.weight' in state_dict:
        warn('fixed chytorch<1.37 checkpoint', DeprecationWarning)
        state_dict[prefix + 'neighbors_encoder.weight'] = state_dict.pop(prefix + 'centrality_encoder.weight')
        state_dict[prefix + 'distance_encoder.weight'] = state_dict.pop(prefix + 'spatial_encoder.weight')
    # runs after the <1.37 fix, so keys renamed above are moved to the submodule too
    if prefix + 'atoms_encoder.weight' in state_dict:
        warn('fixed chytorch<1.61 checkpoint', DeprecationWarning)
        state_dict[prefix + 'embedding.atoms_encoder.weight'] = state_dict.pop(prefix + 'atoms_encoder.weight')
        state_dict[prefix + 'embedding.neighbors_encoder.weight'] = state_dict.pop(prefix + 'neighbors_encoder.weight')


class MoleculeEncoder(Module):
    """
    Inspired by https://arxiv.org/pdf/2106.05234.pdf
    """
    def __init__(self, max_neighbors: int = 14, max_distance: int = 10, d_model: int = 1024, nhead: int = 16,
                 num_layers: int = 8, dim_feedforward: int = 3072, shared_weights: bool = True,
                 shared_attention_bias: bool = True, dropout: float = 0.1, activation=GELU,
                 layer_norm_eps: float = 1e-5, norm_first: bool = False, post_norm: bool = False,
                 zero_bias: bool = False, perturbation: float = 0., max_tokens: int = 121,
                 projection_bias: bool = True, ff_bias: bool = True):
        """
        Molecule Graphormer from https://doi.org/10.1021/acs.jcim.2c00344.

        :param max_neighbors: maximum atoms neighbors count.
        :param max_distance: maximal distance between atoms.
        :param d_model: embedding size; passed to the embedding and to each encoder layer.
        :param nhead: number of attention heads; also the output size of the distance encoder.
        :param num_layers: number of encoder layer applications.
        :param dim_feedforward: feedforward size passed to each encoder layer.
        :param shared_weights: ALBERT-like encoder weights sharing.
        :param shared_attention_bias: use shared distance encoder or unique for each transformer layer.
        :param dropout: dropout value passed to each encoder layer.
        :param activation: activation passed to each encoder layer.
        :param layer_norm_eps: eps of layer normalization (encoder layers and output norm).
        :param norm_first: do pre-normalization in encoder layers.
        :param post_norm: do normalization of output. Works only when norm_first=True.
        :param zero_bias: use frozen zero bias of attention for non-reachable atoms.
        :param perturbation: add perturbation to embedding (https://aclanthology.org/2021.naacl-main.460.pdf).
            Disabled by default
        :param max_tokens: number of tokens in the atom encoder embedding layer.
        :param projection_bias: forwarded to EncoderLayer as `projection_bias`.
        :param ff_bias: forwarded to EncoderLayer as `ff_bias`.
        """
        super().__init__()
        self.embedding = EmbeddingBag(max_neighbors, d_model, perturbation, max_tokens)

        self.shared_attention_bias = shared_attention_bias
        if shared_attention_bias:
            self.distance_encoder = Embedding(max_distance + 3, nhead, int(zero_bias) or None, neg_inf_idx=0)
            # None filled encoders mean reusing previously calculated bias. possible manually create different arch.
            # this done for speedup in comparison to layer duplication.
            self.distance_encoders = [None] * num_layers
            self.distance_encoders[0] = self.distance_encoder  # noqa
        else:
            self.distance_encoders = ModuleList(Embedding(max_distance + 3, nhead,
                                                          int(zero_bias) or None, neg_inf_idx=0)
                                                for _ in range(num_layers))

        self.max_distance = max_distance
        self.max_neighbors = max_neighbors
        self.perturbation = perturbation
        self.num_layers = num_layers
        self.max_tokens = max_tokens
        self.post_norm = post_norm
        self.d_model = d_model
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.activation = activation
        self.layer_norm_eps = layer_norm_eps
        self.norm_first = norm_first
        self.zero_bias = zero_bias
        if post_norm:
            assert norm_first, 'post_norm requires norm_first'
            self.norm = LayerNorm(d_model, layer_norm_eps)

        self.shared_weights = shared_weights
        if shared_weights:
            self.layer = EncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps, norm_first,
                                      projection_bias=projection_bias, ff_bias=ff_bias)
            # plain list of references to the single registered layer
            self.layers = [self.layer] * num_layers
        else:
            # layers sharing scheme can be manually changed. e.g. pairs of shared encoders
            self.layers = ModuleList(EncoderLayer(d_model, nhead, dim_feedforward, dropout, activation,
                                                  layer_norm_eps, norm_first, projection_bias=projection_bias,
                                                  ff_bias=ff_bias) for _ in range(num_layers))
        self._register_load_state_dict_pre_hook(_update)

    def forward(self, batch: MoleculeDataBatch, /, *,
                cache: Optional[List[Tuple[TensorType['batch', 'atoms+conditions', 'embedding'],
                                           TensorType['batch', 'atoms+conditions', 'embedding']]]] = None) -> \
            TensorType['batch', 'atoms', 'embedding']:
        """
        Use 0 for padding.
        Atoms should be coded by atomic numbers + 2.
        Token 1 reserved for cls token, 2 reserved for molecule cls or training tricks like MLM.
        Neighbors should be coded from 2 (means no neighbors) to max neighbors + 2.
        Neighbors equal to 1 reserved for training tricks like MLM. Use 0 for cls.
        Distances should be coded from 2 (means self-loop) to max_distance + 2.
        Non-reachable atoms should be coded by 1.

        :param batch: tuple of atoms, neighbors and distances.
        :param cache: optional per-layer cache entries, each passed to the corresponding layer as `cache`.
        """
        # without cache every layer receives cache=None
        cache = repeat(None) if cache is None else iter(cache)
        atoms, neighbors, distances = batch

        x = self.embedding(atoms, neighbors)

        for lr, d, c in zip(self.layers, self.distance_encoders, cache):
            # the first distance encoder should not be None, otherwise d_mask is undefined here
            if d is not None:
                d_mask = d(distances).permute(0, 3, 1, 2)  # BxNxNxH > BxHxNxN
            # else: reuse previously calculated mask
            x, _ = lr(x, d_mask, cache=c)  # noqa

        if self.post_norm:
            return self.norm(x)
        return x

    @property
    def centrality_encoder(self):
        warn('centrality_encoder renamed to neighbors_encoder in chytorch 1.37', DeprecationWarning)
        return self.neighbors_encoder

    @property
    def spatial_encoder(self):
        warn('spatial_encoder renamed to distance_encoder in chytorch 1.37', DeprecationWarning)
        return self.distance_encoder

    @property
    def atoms_encoder(self):
        # message fixed: it previously named neighbors_encoder (copy-paste from the property below)
        warn('atoms_encoder moved to embedding submodule in chytorch 1.61', DeprecationWarning)
        return self.embedding.atoms_encoder

    @property
    def neighbors_encoder(self):
        warn('neighbors_encoder moved to embedding submodule in chytorch 1.61', DeprecationWarning)
        return self.embedding.neighbors_encoder
class ReactionEncoder(Module):
    """
    Two-stage encoder: a molecule-level encoder followed by shared reaction-level encoder layers.
    """
    def __init__(self, max_neighbors: int = 14, max_distance: int = 10, d_model: int = 1024, n_in_head: int = 16,
                 n_ex_head: int = 4, num_in_layers: int = 8, num_ex_layers: int = 8,
                 dim_feedforward: int = 3072, dropout: float = 0.1, activation=GELU, layer_norm_eps: float = 1e-5):
        """
        Reaction Graphormer from https://doi.org/10.1021/acs.jcim.2c00344.

        :param max_neighbors: maximum atoms neighbors count.
        :param max_distance: maximal distance between atoms.
        :param d_model: embedding size of both stages.
        :param n_in_head: heads count of the intramolecular encoder.
        :param n_ex_head: heads count of the reaction-level layers.
        :param num_in_layers: intramolecular layers count
        :param num_ex_layers: reaction-level layers count
        :param dim_feedforward: feedforward size of both stages.
        :param dropout: dropout value of both stages.
        :param activation: activation of both stages.
        :param layer_norm_eps: layer normalization eps of both stages.
        """
        super().__init__()
        self.molecule_encoder = MoleculeEncoder(max_neighbors=max_neighbors, max_distance=max_distance,
                                                d_model=d_model, nhead=n_in_head, num_layers=num_in_layers,
                                                dim_feedforward=dim_feedforward, dropout=dropout,
                                                activation=activation, layer_norm_eps=layer_norm_eps)
        self.role_encoder = Embedding(4, d_model, 0)
        # single reaction-level layer applied num_ex_layers times
        self.layer = EncoderLayer(d_model, n_ex_head, dim_feedforward, dropout, activation, layer_norm_eps)
        self.layers = [self.layer for _ in range(num_ex_layers)]
        self.nhead = n_ex_head

    @property
    def max_distance(self):
        """
        Distance cutoff in spatial encoder.
        """
        return self.molecule_encoder.max_distance

    def forward(self, batch: ReactionEncoderDataBatch) -> TensorType['batch', 'atoms', 'embedding']:
        """
        Use 0 for padding. Roles should be coded by 2 for reactants, 3 for products and 1 for special cls token.
        Distances - same as molecular encoder distances but batched diagonally.
        Used 0 for disabling sharing between molecules.
        """
        atoms, neighbors, distances, roles = batch
        size = atoms.size(1)

        # padding (role 0) gets -inf bias, everything else 0. BxN > Bx1x1xN > BxHxNxN
        padding_bias = zeros_like(roles, dtype=t_float).masked_fill_(roles == 0, -inf)
        attn_bias = padding_bias.view(-1, 1, 1, size).expand(-1, self.nhead, size, -1)

        # role is bert sentence encoder used to separate reactants from products and rxn CLS token coding.
        # multiplication by roles > 1 used to zeroing rxn cls token and padding. this zeroing gradients too.
        keep = (roles > 1).unsqueeze(-1)
        hidden = self.molecule_encoder((atoms, neighbors, distances)) * keep
        hidden = hidden + self.role_encoder(roles)

        for layer in self.layers:
            hidden, _ = layer(hidden, attn_bias)
        return hidden


class Slicer(Module):
    def __init__(self, *slc: Union[int, slice, Tuple[int, ...]]):
        """
        Slice input tensor. For use with Sequential.

        A single argument is used as index as is; several arguments are used together as a tuple index.
        E.g. Slicer(slice(None), 0) equal to Tensor[:, 0]
        """
        super().__init__()
        if len(slc) > 1:
            self.slice = slc
        else:
            self.slice = slc[0]

    def forward(self, x: Tensor):
        """
        Return `x` indexed by the stored slice.
        """
        return x[self.slice]
22 | # 23 | from .attention import * 24 | from .encoder import * 25 | 26 | 27 | __all__ = ['EncoderLayer', 28 | 'MLP', 29 | 'LLaMAMLP', 30 | 'GraphormerAttention'] 31 | -------------------------------------------------------------------------------- /chytorch/nn/transformer/attention/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright 2023 Ramil Nugmanov 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the “Software”), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | # 23 | from .graphormer import * 24 | 25 | 26 | __all__ = ['GraphormerAttention'] 27 | -------------------------------------------------------------------------------- /chytorch/nn/transformer/attention/graphormer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright 2023, 2024 Ramil Nugmanov 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the “Software”), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
def _update_unpacked(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
    """
    load_state_dict pre-hook for separated q/k/v projections.

    Renames chytorch<1.44 keys, then splits a packed `qkv_proj` weight (and bias, if present) into
    three equal chunks along dim 0. Only `state_dict` and `prefix` are used.
    """
    if prefix + 'in_proj_weight' in state_dict:
        warn('fixed chytorch<1.44 checkpoint', DeprecationWarning)
        state_dict[prefix + 'qkv_proj.weight'] = state_dict.pop(prefix + 'in_proj_weight')
        state_dict[prefix + 'qkv_proj.bias'] = state_dict.pop(prefix + 'in_proj_bias')
        state_dict[prefix + 'o_proj.weight'] = state_dict.pop(prefix + 'out_proj.weight')
        state_dict[prefix + 'o_proj.bias'] = state_dict.pop(prefix + 'out_proj.bias')

    if prefix + 'qkv_proj.weight' not in state_dict:
        return  # already unpacked, nothing to transform

    # transform packed projection
    packed_weight = state_dict.pop(prefix + 'qkv_proj.weight')
    (state_dict[prefix + 'q_proj.weight'],
     state_dict[prefix + 'k_proj.weight'],
     state_dict[prefix + 'v_proj.weight']) = packed_weight.chunk(3, dim=0)

    if prefix + 'qkv_proj.bias' in state_dict:
        packed_bias = state_dict.pop(prefix + 'qkv_proj.bias')
        (state_dict[prefix + 'q_proj.bias'],
         state_dict[prefix + 'k_proj.bias'],
         state_dict[prefix + 'v_proj.bias']) = packed_bias.chunk(3, dim=0)


def _update_packed(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
    """
    load_state_dict pre-hook for the packed qkv projection.

    Renames chytorch<1.44 keys, or concatenates separated q/k/v weights (and biases, if present)
    into a single `qkv_proj` entry. Only `state_dict` and `prefix` are used.
    """
    if prefix + 'in_proj_weight' in state_dict:
        warn('fixed chytorch<1.44 checkpoint', DeprecationWarning)
        state_dict[prefix + 'o_proj.weight'] = state_dict.pop(prefix + 'out_proj.weight')
        state_dict[prefix + 'o_proj.bias'] = state_dict.pop(prefix + 'out_proj.bias')

        state_dict[prefix + 'qkv_proj.weight'] = state_dict.pop(prefix + 'in_proj_weight')
        state_dict[prefix + 'qkv_proj.bias'] = state_dict.pop(prefix + 'in_proj_bias')
        return

    if prefix + 'q_proj.weight' not in state_dict:
        return  # already packed

    # transform unpacked projection; q, k, v order matters
    weights = [state_dict.pop(prefix + 'q_proj.weight'),
               state_dict.pop(prefix + 'k_proj.weight'),
               state_dict.pop(prefix + 'v_proj.weight')]
    state_dict[prefix + 'qkv_proj.weight'] = cat(weights)

    if prefix + 'q_proj.bias' in state_dict:
        biases = [state_dict.pop(prefix + 'q_proj.bias'),
                  state_dict.pop(prefix + 'k_proj.bias'),
                  state_dict.pop(prefix + 'v_proj.bias')]
        state_dict[prefix + 'qkv_proj.bias'] = cat(biases)


class GraphormerAttention(Module):
    """
    LoRA wrapped Multi-Head Attention
    """
    def __init__(self, embed_dim, num_heads, dropout: float = .1, bias: bool = True, separate_proj: bool = False):
        """
        :param embed_dim: the size of each embedding vector
        :param num_heads: number of heads
        :param dropout: attention dropout
        :param bias: passed as `bias` to all projection layers
        :param separate_proj: use separated projections calculations or optimized
        """
        assert not embed_dim % num_heads, 'embed_dim must be divisible by num_heads'
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.separate_proj = separate_proj
        # scores are scaled by 1 / sqrt(head size)
        self._scale = 1 / sqrt(embed_dim / num_heads)

        if separate_proj:
            self.q_proj = Linear(embed_dim, embed_dim, bias=bias)
            self.k_proj = Linear(embed_dim, embed_dim, bias=bias)
            self.v_proj = Linear(embed_dim, embed_dim, bias=bias)
            hook = _update_unpacked
        else:  # packed projection
            self.qkv_proj = Linear(embed_dim, 3 * embed_dim, bias=bias)
            hook = _update_packed
        self._register_load_state_dict_pre_hook(hook)
        self.o_proj = Linear(embed_dim, embed_dim, bias=bias)

    def forward(self, x: Tensor, attn_mask: Tensor, *,
                cache: Optional[Tuple[Tensor, Tensor]] = None,
                need_weights: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        """
        :param x: input embeddings, BxTxH*E.
        :param attn_mask: bias added to the scaled attention scores before softmax.
        :param cache: optional pair of key and value buffers which are updated in place.
        :param need_weights: also return attention weights averaged over heads.
        :return: projected attention output and averaged weights or None.
        """
        if self.separate_proj:
            # BxTxH*E each (KV seq len can differ from tgt_len with enabled cache trick)
            q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        else:  # optimized
            q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        if cache is not None:
            # inference caching. batch should be left padded. shape should be BxSxH*E
            bsz, tgt_len, _ = x.shape
            cached_k, cached_v = cache
            cached_k[:bsz, -tgt_len:] = k
            cached_v[:bsz, -tgt_len:] = v
            k = cached_k[:bsz]
            v = cached_v[:bsz]

        heads = (self.num_heads, -1)
        q = q.unflatten(2, heads).transpose(1, 2)  # BxTxH*E > BxTxHxE > BxHxTxE
        k = k.unflatten(2, heads).permute(0, 2, 3, 1)  # BxSxH*E > BxSxHxE > BxHxExS
        v = v.unflatten(2, heads).transpose(1, 2)  # BxSxH*E > BxSxHxE > BxHxSxE

        # BxHxTxE @ BxHxExS > BxHxTxS
        scores = (q @ k) * self._scale + attn_mask
        weights = softmax(scores, dim=-1)
        if self.training and self.dropout:
            weights = dropout(weights, self.dropout)

        # BxHxTxS @ BxHxSxE > BxHxTxE > BxTxHxE > BxTxH*E
        out = self.o_proj((weights @ v).transpose(1, 2).flatten(2))

        if not need_weights:
            return out, None
        # average over heads
        return out, weights.sum(dim=1) / self.num_heads
def _update(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
    """
    load_state_dict pre-hook: move chytorch<1.64 feedforward keys into the `mlp` submodule.

    Only `state_dict` and `prefix` are used; other arguments are required by the hook signature.
    """
    if prefix + 'linear1.weight' not in state_dict:
        return  # modern checkpoint

    warn('fixed chytorch<1.64 checkpoint', DeprecationWarning)
    state_dict[prefix + 'mlp.linear1.weight'] = state_dict.pop(prefix + 'linear1.weight')
    state_dict[prefix + 'mlp.linear2.weight'] = state_dict.pop(prefix + 'linear2.weight')
    if prefix + 'linear1.bias' in state_dict:  # biases are optional
        state_dict[prefix + 'mlp.linear1.bias'] = state_dict.pop(prefix + 'linear1.bias')
        state_dict[prefix + 'mlp.linear2.bias'] = state_dict.pop(prefix + 'linear2.bias')


class MLP(Module):
    """
    Two-layer feedforward block: linear1 > activation > dropout > linear2.
    """
    def __init__(self, d_model, dim_feedforward, dropout=0.1, activation=GELU, bias: bool = True):
        """
        :param d_model: input and output size.
        :param dim_feedforward: hidden size.
        :param dropout: dropout value applied after activation.
        :param activation: activation class or its name in torch.nn.
        :param bias: passed as `bias` to both linear layers.
        """
        super().__init__()
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias)
        self.linear2 = Linear(dim_feedforward, d_model, bias=bias)
        self.dropout = Dropout(dropout)

        # ad-hoc for resolving class from name
        activation_cls = getattr(nn, activation) if isinstance(activation, str) else activation
        self.activation = activation_cls()

    def forward(self, x):
        hidden = self.activation(self.linear1(x))
        return self.linear2(self.dropout(hidden))


class LLaMAMLP(Module):
    """
    Gated feedforward block: linear3(dropout(activation(linear1(x))) * linear2(x)).
    """
    def __init__(self, d_model, dim_feedforward, dropout=0.1, activation=SiLU, bias: bool = False):
        """
        :param d_model: input and output size.
        :param dim_feedforward: hidden size of both input projections.
        :param dropout: dropout value applied to the activated branch.
        :param activation: activation class or its name in torch.nn.
        :param bias: passed as `bias` to all linear layers.
        """
        super().__init__()
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias)
        self.linear2 = Linear(d_model, dim_feedforward, bias=bias)
        self.linear3 = Linear(dim_feedforward, d_model, bias=bias)
        self.dropout = Dropout(dropout)

        # ad-hoc for resolving class from name
        activation_cls = getattr(nn, activation) if isinstance(activation, str) else activation
        self.activation = activation_cls()

    def forward(self, x):
        gate = self.dropout(self.activation(self.linear1(x)))
        return self.linear3(gate * self.linear2(x))


class EncoderLayer(Module):
    r"""EncoderLayer based on torch.nn.TransformerEncoderLayer, but batch always first and returns also attention.

    :param d_model: the number of expected features in the input (required).
    :param nhead: the number of heads in the multiheadattention models (required).
    :param dim_feedforward: the dimension of the feedforward network model (required).
    :param dropout: the dropout value (default=0.1).
    :param activation: the activation function of the intermediate layer. Default: GELU.
    :param layer_norm_eps: the eps value in layer normalization components (default=1e-5).
    :param norm_first: if `True`, layer norm is done prior to self attention, multihead
        attention and feedforward operations, respectively. Otherwise, it's done after.
    :param attention: attention class, called as attention(d_model, nhead, dropout, projection_bias).
    :param mlp: feedforward class, called as mlp(d_model, dim_feedforward, dropout, activation, ff_bias).
    :param norm_layer: normalization class, called as norm_layer(d_model, eps=layer_norm_eps).
    :param projection_bias: bias flag passed to the attention class.
    :param ff_bias: bias flag passed to the feedforward class.
    """
    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1, activation=GELU, layer_norm_eps=1e-5,
                 norm_first: bool = False, attention: Type[Module] = GraphormerAttention, mlp: Type[Module] = MLP,
                 norm_layer: Type[Module] = LayerNorm, projection_bias: bool = True, ff_bias: bool = True):
        super().__init__()
        self.self_attn = attention(d_model, nhead, dropout, projection_bias)
        self.mlp = mlp(d_model, dim_feedforward, dropout, activation, ff_bias)

        self.norm1 = norm_layer(d_model, eps=layer_norm_eps)
        self.norm2 = norm_layer(d_model, eps=layer_norm_eps)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.norm_first = norm_first
        self._register_load_state_dict_pre_hook(_update)

    def forward(self, x: Tensor, attn_mask: Optional[Tensor], *,
                need_embedding: bool = True, need_weights: bool = False,
                **kwargs) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """
        :param x: input embeddings.
        :param attn_mask: attention mask passed to the attention module.
        :param need_embedding: if False, skip residual/feedforward part and return None instead of embeddings.
        :param need_weights: passed to the attention module.
        :param kwargs: extra keyword arguments passed to the attention module (e.g. cache).
        :return: tuple of embeddings (or None) and the second output of the attention module.
        """
        attn_input = self.norm1(x) if self.norm_first else x  # pre-norm or post-norm
        attn_out, weights = self.self_attn(attn_input, attn_mask, need_weights=need_weights, **kwargs)

        if not need_embedding:
            return None, weights

        x = x + self.dropout1(attn_out)
        if self.norm_first:
            return x + self.dropout2(self.mlp(self.norm2(x))), weights
        # else: post-norm
        x = self.norm1(x)
        return self.norm2(x + self.dropout2(self.mlp(x))), weights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # 23 | from .binary import * 24 | from .regressor import * 25 | from .classifier import * 26 | 27 | 28 | __all__ = ['VotingRegressor', 'VotingClassifier', 'BinaryVotingClassifier'] 29 | -------------------------------------------------------------------------------- /chytorch/nn/voting/_kfold.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright 2022, 2023 Ramil Nugmanov 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the “Software”), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 
@lru_cache()
def k_fold_mask(k_fold, ensemble, batch_size, train, device=None):
    """
    Build batch_size x ensemble mask of the k-fold trick.

    Batch rows and ensemble columns are both split into k_fold consecutive chunks; the n-th rows chunk
    paired with the n-th columns chunk forms the n-th fold block.

    The result is cached by arguments, so the same tensor object is returned for repeated calls.

    :param k_fold: number of folds
    :param ensemble: number of predicting heads
    :param batch_size: size of batch
    :param train: create train or test mask. train - ones with zeroed fold blocks (1/k_fold of entries is zero).
        test - zeros with fold blocks set to one ((k_fold - 1)/k_fold of entries is zero).
    :param device: device of mask
    """
    assert k_fold >= 3, 'k-fold should be at least 3'
    assert not ensemble % k_fold, 'ensemble should be divisible by k-fold'
    assert not batch_size % k_fold, 'batch size should be divisible by k-fold'

    # train: enabled everywhere except fold blocks; test/validation: the opposite
    factory, block_value = (ones, 0.) if train else (zeros, 1.)
    mask = factory(batch_size, ensemble, device=device)

    rows = batch_size // k_fold
    cols = ensemble // k_fold
    for fold in range(k_fold):
        top, left = fold * rows, fold * cols
        mask[top: top + rows, left: left + cols] = block_value
    return mask
class BinaryVotingClassifier(VotingRegressor):
    """
    Ensemble binary classifier: the two-layer voting perceptron of `VotingRegressor`
    whose outputs are treated as logits (default loss is binary cross entropy with logits).
    """
    def __init__(self, ensemble: int = 10, output: int = 1, hidden: int = 256, input: Optional[int] = None,
                 dropout: float = .5, activation=GELU, layer_norm_eps: float = 1e-5,
                 loss_function=binary_cross_entropy_with_logits, norm_first: bool = False):
        # same constructor as the regressor; only the default loss function differs
        super().__init__(ensemble=ensemble, output=output, hidden=hidden, input=input, dropout=dropout,
                         activation=activation, layer_norm_eps=layer_norm_eps, loss_function=loss_function,
                         norm_first=norm_first)

    @no_grad()
    def predict(self, x: TensorType['batch', 'embedding'], *,
                k_fold: Optional[int] = None) -> Union[TensorType['batch', int], TensorType['batch', 'output', int]]:
        """
        Predicted class: 1 if the averaged probability is above 0.5, otherwise 0.

        :param x: features
        :param k_fold: average ensemble according to k-fold trick described in the `loss` method.
        """
        proba = self.predict_proba(x, k_fold=k_fold)
        return (proba > .5).long()

    @no_grad()
    def predict_proba(self, x: TensorType['batch', 'embedding'], *,
                      k_fold: Optional[int] = None) -> Union[TensorType['batch', float],
                                                             TensorType['batch', 'output', float]]:
        """
        Probability averaged over ensemble heads.

        :param x: features
        :param k_fold: average ensemble according to k-fold trick described in the `loss` method.
        """
        proba = sigmoid(self.forward(x))  # B x E or B x O x E
        if k_fold is None:
            return proba.nanmean(-1)

        # True marks heads to be excluded from averaging for the given batch item
        excluded = k_fold_mask(k_fold, self._ensemble, x.size(0), True, proba.device).bool()  # B x E
        if self._output != 1:
            excluded = excluded.unsqueeze(1)  # B x 1 x E
        # excluded heads become NaN and are skipped by nanmean
        return proba.masked_fill_(excluded, nan).nanmean(-1)
class VotingClassifier(Module):
    """
    Simple two-layer perceptron with layer normalization and dropout adopted for effective ensemble classification.

    The first linear layer produces a separate hidden vector for each of `ensemble * output` heads.
    The weight of the second layer is applied per head by batched matrix multiplication,
    giving `n_classes` logits for every head.
    """
    def __init__(self, ensemble: int = 10, output: int = 1, n_classes: int = 2, hidden: int = 256,
                 input: Optional[int] = None, dropout: float = .5, activation=GELU, layer_norm_eps: float = 1e-5,
                 loss_function=cross_entropy, norm_first: bool = False):
        """
        :param ensemble: number of predictive heads per output
        :param output: number of predicted properties in multitask mode. By-default single task mode is active
        :param n_classes: number of classes
        :param hidden: hidden vector size of each head
        :param input: input features size. By-default do lazy initialization
        :param dropout: dropout probability applied after activation
        :param activation: activation module class. Instantiated without arguments
        :param layer_norm_eps: epsilon of layer normalizations
        :param loss_function: callable applied to flattened logits and targets in the `loss` method
        :param norm_first: do normalization of input
        """
        assert n_classes >= 2, 'number of classes should be higher or equal than 2'
        assert ensemble > 0 and output > 0, 'ensemble and output should be positive integers'
        super().__init__()
        if input is None:
            self.linear1 = LazyLinear(hidden * ensemble * output)
            assert not norm_first, 'input size required for prenormalization'
        else:
            if norm_first:
                self.norm_first = LayerNorm(input, layer_norm_eps)
            self.linear1 = Linear(input, hidden * ensemble * output)
        self.layer_norm = LayerNorm(hidden, layer_norm_eps)
        self.activation = activation()
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(hidden, ensemble * output * n_classes)
        self.loss_function = loss_function

        self._n_classes = n_classes
        self._ensemble = ensemble
        self._input = input
        self._hidden = hidden
        self._output = output
        self._norm_first = norm_first

    def forward(self, x):
        """
        Returns ensemble of predictions (logits) in shape [Batch x Ensemble x Classes]
        or [Batch x Output x Ensemble x Classes] in multitask mode.
        """
        if self._norm_first:
            x = self.norm_first(x)
        # B x E >> B x N*H >> B x N x H >> N x B x H
        x = self.linear1(x).view(-1, self._ensemble * self._output, self._hidden).transpose(0, 1)
        x = self.dropout(self.activation(self.layer_norm(x)))
        # N * C x H >> N x C x H >> N x H x C
        w = self.linear2.weight.view(-1, self._n_classes, self._hidden).transpose(1, 2)
        # N x B x C >> B x N x C
        x = bmm(x, w).transpose(0, 1).contiguous() + self.linear2.bias.view(-1, self._n_classes)
        if self._output != 1:  # MT mode
            return x.view(-1, self._output, self._ensemble, self._n_classes)  # B x O x E x C
        return x  # B x E x C

    def loss(self, x: TensorType['batch', 'embedding'],
             y: Union[TensorType['batch', 1, int], TensorType['batch', 'output', int]],
             k_fold: Optional[int] = None, ignore_index: int = -100) -> Tensor:
        """
        Apply loss function to ensemble of predictions.

        Note: y should be a column vector in single task or Batch x Output matrix in multitask mode.

        :param x: features
        :param y: properties
        :param k_fold: Cross-validation training procedure of ensemble model.
            Only k - 1 / k heads of ensemble trains on k - 1 / k items of each batch.
            On validation step 1 / k of same batches used to evaluate heads.
            Batch and ensemble sizes should be divisible by k. Disabled by default.
        :param ignore_index: target value written into entries disabled by the k-fold mask.
            A non-default value is also forwarded to the loss function as the `ignore_index` keyword,
            thus a custom loss function should accept this keyword in that case.
        """
        p = self.forward(x)  # B x E x C or B x O x E x C
        if self._output != 1:  # MT mode
            y = y.unsqueeze(-1).expand(-1, -1, self._ensemble)  # B x O > B x O x 1 > B x O x E
        else:
            y = y.expand(-1, self._ensemble)  # B x E

        if k_fold is not None:
            # True marks entries excluded from the loss: the inverted flag is passed because here mask ones
            # mean "ignore", whereas in k_fold_mask ones mean "keep".
            m = k_fold_mask(k_fold, self._ensemble, x.size(0), not self.training, p.device).bool()  # B x E
            if self._output != 1:
                # B x E > B x 1 x E
                m.unsqueeze_(1)
            y = y.masked_fill(m, ignore_index)

        # B x E x C >> B * E x C
        # B x O x E x C >> B * O * E x C
        p = p.flatten(end_dim=-2)
        # B x E >> B * E
        # B x O x E >> B * O * E
        y = y.flatten()
        if ignore_index == -100:
            # -100 is the default written above and the default of torch's cross_entropy.
            # Keep the two-argument call, so custom loss functions without `ignore_index` keyword still work.
            return self.loss_function(p, y)
        # without forwarding, the loss function would treat the masked targets as real class labels
        return self.loss_function(p, y, ignore_index=ignore_index)

    @no_grad()
    def predict(self, x: TensorType['batch', 'embedding'], *,
                k_fold: Optional[int] = None) -> Union[TensorType['batch', int], TensorType['batch', 'output', int]]:
        """
        Average class prediction

        :param x: features
        :param k_fold: average ensemble according to k-fold trick described in the `loss` method.
        """
        return self.predict_proba(x, k_fold=k_fold).argmax(-1)  # B or B x O

    @no_grad()
    def predict_proba(self, x: TensorType['batch', 'embedding'], *,
                      k_fold: Optional[int] = None) -> Union[TensorType['batch', 'classes', float],
                                                             TensorType['batch', 'output', 'classes', float]]:
        """
        Average probability

        :param x: features
        :param k_fold: average ensemble according to k-fold trick described in the `loss` method.
        """
        p = softmax(self.forward(x), -1)
        if k_fold is not None:
            # True marks heads excluded from averaging; .bool() returns a new tensor, in-place ops are safe
            m = k_fold_mask(k_fold, self._ensemble, x.size(0), True, p.device).bool().unsqueeze_(-1)  # B x E x 1
            if self._output != 1:
                m.unsqueeze_(1)  # B x 1 x E x 1
            p.masked_fill_(m, nan)  # NaNs are skipped by nanmean below
        return p.nanmean(-2)  # B x C or B x O x C
class VotingRegressor(Module):
    """
    Two-layer perceptron with layer normalization and dropout holding an ensemble of independent regression heads.
    """
    def __init__(self, ensemble: int = 10, output: int = 1, hidden: int = 256, input: Optional[int] = None,
                 dropout: float = .5, activation=GELU, layer_norm_eps: float = 1e-5, loss_function=smooth_l1_loss,
                 norm_first: bool = False):
        """
        :param ensemble: number of predictive heads per output
        :param output: number of predicted properties in multitask mode. By-default single task mode is active
        :param hidden: hidden vector size of each head
        :param input: input features size. By-default do lazy initialization
        :param dropout: dropout probability applied after activation
        :param activation: activation module class. Instantiated without arguments
        :param layer_norm_eps: epsilon of layer normalizations
        :param loss_function: callable applied to predictions and targets in the `loss` method
        :param norm_first: do normalization of input
        """
        assert ensemble > 0 and output > 0, 'ensemble and output should be positive integers'
        super().__init__()
        heads = ensemble * output
        if input is not None:
            if norm_first:
                self.norm_first = LayerNorm(input, layer_norm_eps)
            self.linear1 = Linear(input, hidden * heads)
        else:
            self.linear1 = LazyLinear(hidden * heads)
            assert not norm_first, 'input size required for prenormalization'

        self.layer_norm = LayerNorm(hidden, layer_norm_eps)
        self.activation = activation()
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(hidden, heads)
        self.loss_function = loss_function

        self._ensemble = ensemble
        self._input = input
        self._hidden = hidden
        self._output = output
        self._norm_first = norm_first

    def forward(self, x):
        """
        Returns ensemble of predictions in shape [Batch x Ensemble]
        or [Batch x Output x Ensemble] in multitask mode.
        """
        if self._norm_first:
            x = self.norm_first(x)
        heads = self._ensemble * self._output
        # B x E >> B x N*H >> B x N x H >> N x B x H
        h = self.linear1(x).view(-1, heads, self._hidden).transpose(0, 1)
        h = self.layer_norm(h)
        h = self.activation(h)
        h = self.dropout(h)
        # N x H >> N x H x 1
        weight = self.linear2.weight.unsqueeze(2)
        # N x B x 1 >> N x B >> B x N
        out = bmm(h, weight).squeeze(-1).transpose(0, 1).contiguous() + self.linear2.bias
        if self._output == 1:
            return out  # B x E
        return out.view(-1, self._output, self._ensemble)  # B x O x E

    def loss(self, x: TensorType['batch', 'embedding'],
             y: Union[TensorType['batch', 1, float], TensorType['batch', 'output', float]],
             k_fold: Optional[int] = None) -> Tensor:
        """
        Apply loss function to ensemble of predictions.

        Note: y should be a column vector in single task or Batch x Output matrix in multitask mode.

        :param x: features
        :param y: properties
        :param k_fold: Cross-validation training procedure of ensemble model.
            Only k - 1 / k heads of ensemble do train on k - 1 / k items of each batch.
            On validation step 1 / k of same batches used to evaluate heads.
            Batch and ensemble sizes should be divisible by k. Disabled by default.
        """
        predicted = self.forward(x)
        if self._output != 1:  # MT mode
            y = y.unsqueeze(-1)  # B x O > B x O x 1
        target = y.expand(predicted.size())  # B x E or B x O x E

        if k_fold is not None:
            mask = k_fold_mask(k_fold, self._ensemble, x.size(0), self.training, predicted.device)  # B x E
            if self._output != 1:
                mask = mask.unsqueeze(1)  # B x 1 x E
            # multiplication by zero disables gradients of masked predictions
            predicted = predicted * mask
            # masked targets are zeroed as well, so masked entries give no error
            target = target * mask
        return self.loss_function(predicted, target)

    @no_grad()
    def predict(self, x: TensorType['batch', 'embedding'], *,
                k_fold: Optional[int] = None) -> Union[TensorType['batch', float],
                                                       TensorType['batch', 'output', float]]:
        """
        Prediction averaged over ensemble heads.

        :param x: features.
        :param k_fold: average ensemble according to k-fold trick described in the `loss` method.
        """
        predicted = self.forward(x)
        if k_fold is None:
            return predicted.nanmean(-1)

        # True marks heads excluded from averaging for the given batch item
        excluded = k_fold_mask(k_fold, self._ensemble, x.size(0), True, predicted.device).bool()  # B x E
        if self._output != 1:
            excluded = excluded.unsqueeze(1)  # B x 1 x E
        # excluded heads become NaN and are skipped by nanmean
        return predicted.masked_fill_(excluded, nan).nanmean(-1)
notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # 23 | -------------------------------------------------------------------------------- /chytorch/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright 2021-2024 Ramil Nugmanov 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the “Software”), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # 23 | from .lmdb import * 24 | from .molecule import * 25 | from .product import * 26 | from .reaction import * 27 | from .sampler import * 28 | from .smiles import * 29 | from .unpack import * 30 | from ._utils import * 31 | 32 | 33 | __all__ = ['MoleculeDataset', 'collate_molecules', 'left_padded_collate_molecules', 34 | 'ConformerDataset', 'collate_conformers', 35 | 'ReactionEncoderDataset', 'collate_encoded_reactions', 36 | 'RDKitConformerDataset', 37 | 'SMILESDataset', 38 | 'ProductDataset', 39 | 'SuppressException', 'SizedList', 'ShuffledList', 40 | 'ByteRange', 41 | 'LMDBMapper', 42 | 'PickleUnpack', 43 | 'JsonUnpack', 44 | 'StructUnpack', 45 | 'TensorUnpack', 46 | 'Decompress', 47 | 'Decode', 48 | 'chained_collate', 49 | 'skip_none_collate', 50 | 'load_lmdb', 'load_lmdb_zstd_dict', 51 | 'StructureSampler', 52 | 'DistributedStructureSampler', 53 | 'thiacalix_n_arene_dataset'] 54 | -------------------------------------------------------------------------------- /chytorch/utils/data/_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright 2022, 2023 Ramil Nugmanov 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the “Software”), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this 
element = TypeVar('element')


class chained_collate:
    """
    Collate batch of tuples with different data structures by different collate functions.

    The n-th collate function receives the list of n-th elements of all kept batch entries.

    :param collate_fns: collate functions, one per position of the entry tuple
    :param skip_nones: ignore entities with Nones
    """
    def __init__(self, *collate_fns, skip_nones=True):
        self.collate_fns = collate_fns
        self.skip_nones = skip_nones

    def __call__(self, batch):
        sub_batches = [[] for _ in self.collate_fns]
        for x in batch:
            # Identity check instead of `None in x`: membership test falls back to `==`,
            # which can fail or be ambiguous for elements with custom equality (e.g. arrays).
            if self.skip_nones and (x is None or any(y is None for y in x)):
                continue
            for y, s in zip(x, sub_batches):
                s.append(y)
        return [f(x) for x, f in zip(sub_batches, self.collate_fns)]


def skip_none_collate(collate_fn):
    """
    Wrap collate function to drop None entries of the batch before collating.

    :param collate_fn: collate function to wrap
    """
    def w(batch):
        return collate_fn([x for x in batch if x is not None])
    return w


def load_lmdb(path, size=4, order='big'):
    """
    Helper for loading LMDB datasets with continuous integer keys.
    Note: keys of datapoints should be positive indices coded as bytes with size `size` and `order` endianness.

    Example structure of DB with a key size=2:
        0000 (0): first record
        0001 (1): second record
        ...
        ffff (65535): last record

    :param path: path to the database
    :param size: key size in bytes
    :param order: big or little endian
    """
    db = LMDBMapper(path)
    # replace the index built from a DB scan by a lazily computed range of keys
    db._mapping = ByteRange(len(db), size=size, order=order)
    return db


def load_lmdb_zstd_dict(path, size=4, order='big', key=b'\xff\xff\xff\xff'):
    """
    Helper for loading LMDB datasets with continuous integer keys compressed by zstd with external dictionary.
    Note: keys of datapoints should be positive indices coded as bytes with size `size` and `order` endianness.
    Database should contain one additional record with key `key` with decompression dictionary.

    Example structure of DB a key size=2:
        ffff (65535): zstd dict bytes
        0000 (0): first record
        0001 (1): second record
        ...
        fffe (65534): last record

    :param path: path to the database
    :param key: LMDB entry with dictionary data
    :param size: key size in bytes
    :param order: big or little endian
    """
    db = LMDBMapper(path)
    # one record is the dictionary itself, thus excluded from the index
    db._mapping = ByteRange(len(db) - 1, size=size, order=order)
    db[0]  # connect db
    dc = Decompress(db, 'zstd', db._tr.get(key))
    return dc


class SizedList(List):
    """
    List with tensor-like size method.
    """
    def size(self, dim=None):
        """
        Return the length for dim=0 or `torch.Size` with the length if dim is omitted.

        :raises IndexError: for any other dim
        """
        if dim == 0:
            return len(self)
        elif dim is None:
            return Size((len(self),))
        raise IndexError


class ByteRange:
    """
    Range returning values as bytes

    Positional and keyword arguments other than `size` and `order` are passed to `range`.

    :param size: length of produced bytes
    :param order: big or little endian
    """
    def __init__(self, *args, size=4, order='big', **kwargs):
        self.range = range(*args, **kwargs)
        self.size = size
        self.order = order

    def __getitem__(self, item):
        return self.range[item].to_bytes(self.size, self.order)

    def __len__(self):
        return len(self.range)


class ShuffledList(Dataset):
    """
    Returns randomly shuffled sequences
    """
    def __init__(self, data: Sequence[Sequence[element]]):
        self.data = data

    def __getitem__(self, item: int) -> List[element]:
        x = list(self.data[item])  # copy, thus wrapped data stays untouched
        shuffle(x)
        return x

    def __len__(self):
        return len(self.data)

    def size(self, dim=None):
        """
        Return the length for dim=0 or `torch.Size` with the length if dim is omitted.

        :raises IndexError: for any other dim
        """
        if dim == 0:
            return len(self)
        elif dim is None:
            return Size((len(self),))
        raise IndexError


class SuppressException(Dataset):
    """
    Catch exceptions in wrapped dataset and return None instead
    """
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, item):
        try:
            return self.dataset[item]
        except IndexError:
            # should be propagated: signals the end of the data in the sequence protocol
            raise
        except Exception:
            # deliberate best-effort: a broken entry is reported as None
            return None

    def __len__(self):
        return len(self.dataset)

    def size(self, dim=None):
        """
        Return the length for dim=0 or `torch.Size` with the length if dim is omitted.

        :raises IndexError: for any other dim
        """
        if dim == 0:
            return len(self)
        elif dim is None:
            return Size((len(self),))
        raise IndexError
class LMDBMapper(Dataset):
    """
    Map LMDB key-value storage to the Sequence Dataset of bytestrings.

    The database is opened lazily on the first item access.
    Only the db path and cache path are pickled, thus copies of the mapper open their own connection.
    """
    def __init__(self, db: str, *, cache: Union[Path, str, None] = None):
        """
        Note: mapper internally uses python list for index to bytes-key mapping and can be huge on big datasets.

        :param db: lmdb dir path
        :param cache: path to cache file for [re]storing index. caching disabled by default.
        """
        self.db = db
        self.cache = cache

        if cache is None:
            return
        if isinstance(cache, str):
            cache = Path(cache)
        if not cache.exists():
            return
        # load existing cache
        # NOTE: the cache is a pickle file; unpickling can execute arbitrary code. Use trusted cache files only.
        with cache.open('rb') as f:
            self._mapping = load(f)

    def __getitem__(self, item: int) -> bytes:
        try:
            tr = self._tr
        except AttributeError:
            # first access: open the environment and keep one read transaction
            from lmdb import Environment

            self._db = db = Environment(self.db, readonly=True, lock=False)
            self._tr = tr = db.begin()

        try:
            mapping = self._mapping
        except AttributeError:
            with tr.cursor() as c:
                # build mapping
                self._mapping = mapping = list(c.iternext(keys=True, values=False))
            if (cache := self.cache) is not None:  # save to cache
                if isinstance(cache, str):
                    cache = Path(cache)
                with cache.open('wb') as f:
                    dump(mapping, f)

        return tr.get(mapping[item])

    def __len__(self):
        try:
            return len(self._mapping)
        except AttributeError:
            # temporary open db
            from lmdb import Environment

            with Environment(self.db, readonly=True, lock=False) as f:
                return f.stat()['entries']

    def size(self, dim=None):
        """
        Return the length for dim=0 or `torch.Size` with the length if dim is omitted.

        :raises IndexError: for any other dim
        """
        if dim == 0:
            return len(self)
        elif dim is None:
            return Size((len(self),))
        raise IndexError

    def __del__(self):
        try:
            self._tr.commit()
            self._db.close()
        except AttributeError:
            # db was never opened
            pass
        else:
            del self._tr, self._db

    def __getstate__(self):
        # connection and index are not pickled; they are restored lazily
        return {'db': self.db, 'cache': self.cache}

    def __setstate__(self, state):
        self.__init__(**state)
Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the “Software”), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | # 23 | from .conformer import * 24 | from .dummy import * 25 | from .encoder import * 26 | from .rdkit import * 27 | 28 | 29 | __all__ = ['MoleculeDataset', 'MoleculeDataPoint', 'MoleculeDataBatch', 30 | 'collate_molecules', 'left_padded_collate_molecules', 31 | 'ConformerDataset', 'ConformerDataPoint', 'ConformerDataBatch', 'collate_conformers', 32 | 'RDKitConformerDataset', 33 | 'thiacalix_n_arene_dataset'] 34 | -------------------------------------------------------------------------------- /chytorch/utils/data/molecule/_unpack.pyx: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright 2023, 2024 Ramil Nugmanov 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the “Software”), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
import numpy as np

cimport cython
cimport numpy as cnp
from cpython.mem cimport PyMem_Malloc, PyMem_Free

cnp.import_array()
DTYPE = np.int32
ctypedef cnp.int32_t DTYPE_t


# Format specification::
#
#  Big endian bytes order
#  8 bit - 0x02 (current format specification)
#  12 bit - number of atoms
#  12 bit - cis/trans stereo block size
#
#  Atom block 9 bytes (repeated):
#  12 bit - atom number
#  4 bit - number of neighbors
#  2 bit tetrahedron sign (00 - not stereo, 10 or 11 - has stereo)
#  2 bit - allene sign
#  5 bit - isotope (00000 - not specified, over = isotope - common_isotope + 16)
#  7 bit - atomic number (<=118)
#  32 bit - XY float16 coordinates
#  3 bit - hydrogens (0-7). Note: 7 == None
#  4 bit - charge (charge + 4. possible range -4 - 4)
#  1 bit - radical state
#  Connection table: flattened list of neighbors. neighbors count stored in atom block.
#  For example CC(=O)O - {1: [2], 2: [1, 3, 4], 3: [2], 4: [2]} >> [2, 1, 3, 4, 2, 2].
#  Repeated block (equal to bonds count).
#  24 bit - paired 12 bit numbers.
#  Bonds order block 3 bit per bond zero-padded to full byte at the end.
#  Cis/trans data block (repeated):
#  24 bit - atoms pair
#  7 bit - zero padding. in future can be used for extra bond-level stereo, like atropoisomers.
#  1 bit - sign

@cython.nonecheck(False)
@cython.boundscheck(False)
@cython.cdivision(True)
@cython.wraparound(False)
def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsigned short symmetric_attention,
           unsigned short components_attention, DTYPE_t max_neighbors, DTYPE_t max_distance):
    """
    Optimized chython pack to graph tensor converter.
    Ignores charge, radicals, isotope, coordinates, bond order, and stereo info

    :param data: contiguous bytes buffer holding a pack in the format described above
    :param add_cls: if non-zero, reserve position 0 for a CLS token (atom code 1, neighbors code 0)
    :param symmetric_attention: if non-zero, atom-to-CLS distance code is 1, otherwise 0.
        Only used when add_cls is set.
    :param components_attention: distance code written for atom pairs that have no connecting path
    :param max_neighbors: upper limit for the neighbors + hydrogens count (before the +2 shift)
    :param max_distance: upper limit for the shortest path length (before the +2 shift)
    :return: tuple of (atoms, neighbors, distance, size):
        atoms - int32 vector of atomic numbers shifted by 2,
        neighbors - int32 vector of clipped neighbors + hydrogens counts shifted by 2,
        distance - int32 square matrix of clipped shortest path lengths shifted by 2,
        size - byte offset just past the atom, connection, bond order and cis/trans blocks.
    :raises ValueError: if the buffer is shorter than the 4-byte header or the version byte is not 2
    :raises MemoryError: if the temporary connections buffer can not be allocated
    """
    cdef unsigned char a, b, c, hydrogens, neighbors_count
    cdef unsigned char *connections

    cdef unsigned short atoms_count, bonds_count = 0, order_count = 0, cis_trans_count
    cdef unsigned short i, j, k, n, m
    cdef unsigned short[4096] mapping
    cdef unsigned int size, shift = 4

    cdef cnp.ndarray[DTYPE_t, ndim=1] atoms, neighbors
    cdef cnp.ndarray[DTYPE_t, ndim=2] distance
    cdef DTYPE_t d

    # read header
    # bounds checking is disabled for this function, so make sure the header is really there
    if data.shape[0] < 4:
        raise ValueError('invalid pack header')
    if data[0] != 2:
        raise ValueError('invalid pack version')

    a, b, c = data[1], data[2], data[3]
    atoms_count = (a << 4 | b >> 4) + add_cls
    cis_trans_count = (b & 0x0f) << 8 | c

    atoms = np.empty(atoms_count, dtype=DTYPE)
    neighbors = np.zeros(atoms_count, dtype=DTYPE)
    distance = np.full((atoms_count, atoms_count), 9999, dtype=DTYPE)  # fill with unreachable value

    # allocate memory: per-atom counter of not yet consumed connection table entries
    connections = <unsigned char*> PyMem_Malloc(atoms_count * sizeof(unsigned char))
    if not connections:
        raise MemoryError()

    try:  # guarantee that the temporary buffer is released even if something below raises
        if add_cls:
            atoms[0] = 1
            neighbors[0] = 0
            distance[0] = 1  # set CLS to all atoms attention

            if symmetric_attention:  # set all atoms to CLS attention
                distance[1:, 0] = 1
            else:  # disable atom to CLS attention
                distance[1:, 0] = 0

        # unpack atom block
        for i in range(add_cls, atoms_count):
            distance[i, i] = 0  # set diagonal to zero
            a, b = data[shift], data[shift + 1]
            n = a << 4 | b >> 4
            mapping[n] = i  # atom number from the pack -> row index in the output
            connections[i] = neighbors_count = b & 0x0f
            bonds_count += neighbors_count

            atoms[i] = (data[shift + 3] & 0x7f) + 2

            hydrogens = data[shift + 8] >> 5
            if hydrogens != 7:  # hydrogens is not None
                neighbors_count += hydrogens
            if neighbors_count > max_neighbors:
                neighbors_count = max_neighbors
            neighbors[i] = neighbors_count + 2  # neighbors + hydrogens
            shift += 9

        if bonds_count:
            bonds_count /= 2  # every bond was counted from both of its atoms

            # bonds order block: 3 bit per bond, rounded up to a full byte
            order_count = bonds_count * 3
            if order_count % 8:
                order_count = order_count / 8 + 1
            else:
                order_count /= 8

            # connection table: each 3-byte group stores two 12-bit neighbor numbers.
            # n is the atom whose neighbors list is currently consumed.
            n = add_cls
            for i in range(0, 2 * bonds_count, 2):
                a, b, c = data[shift], data[shift + 1], data[shift + 2]
                m = mapping[a << 4 | b >> 4]
                while not connections[n]:  # skip atoms with fully consumed neighbors list
                    n += 1
                connections[n] -= 1
                distance[n, m] = distance[m, n] = 1

                m = mapping[(b & 0x0f) << 8 | c]
                while not connections[n]:
                    n += 1
                connections[n] -= 1
                distance[n, m] = distance[m, n] = 1
                shift += 3

            # floyd-warshall algo
            for k in range(add_cls, atoms_count):
                for i in range(add_cls, atoms_count):
                    if i == k or distance[i, k] == 9999:
                        continue
                    for j in range(add_cls, atoms_count):
                        d = distance[i, k] + distance[k, j]
                        if d < distance[i, j]:
                            distance[i, j] = d

        # reset distances to proper values
        for i in range(add_cls, atoms_count):
            for j in range(i, atoms_count):
                d = distance[i, j]
                if d == 9999:
                    # set attention between subgraphs
                    distance[i, j] = distance[j, i] = components_attention
                elif d > max_distance:
                    distance[i, j] = distance[j, i] = max_distance + 2
                else:
                    distance[i, j] = distance[j, i] = d + 2

        size = shift + order_count + 4 * cis_trans_count
        return atoms, neighbors, distance, size
    finally:
        PyMem_Free(connections)
# ---- chytorch/utils/data/molecule/conformer.py ----
from chython import MoleculeContainer
from functools import cached_property
from numpy import empty, ndarray, sqrt, square, ones, digitize, arange, int32
from numpy.random import default_rng
from torch import IntTensor, Size, zeros, ones as t_ones, int32 as t_int32, eye
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import default_collate_fn_map
from torchtyping import TensorType
from typing import Sequence, Tuple, Union, NamedTuple


class ConformerDataPoint(NamedTuple):
    """
    Single encoded conformer: atoms vector, hydrogens vector and discretized distance matrix.
    """
    atoms: TensorType['atoms', int]
    hydrogens: TensorType['atoms', int]
    distances: TensorType['atoms', 'atoms', int]


class ConformerDataBatch(NamedTuple):
    """
    Padded batch of encoded conformers with device/dtype movement helpers.
    """
    atoms: TensorType['batch', 'atoms', int]
    hydrogens: TensorType['batch', 'atoms', int]
    distances: TensorType['batch', 'atoms', 'atoms', int]

    def to(self, *args, **kwargs):
        """Apply `Tensor.to` to every field and return a new batch."""
        return ConformerDataBatch(*(x.to(*args, **kwargs) for x in self))

    def cpu(self, *args, **kwargs):
        """Apply `Tensor.cpu` to every field and return a new batch."""
        return ConformerDataBatch(*(x.cpu(*args, **kwargs) for x in self))

    def cuda(self, *args, **kwargs):
        """Apply `Tensor.cuda` to every field and return a new batch."""
        return ConformerDataBatch(*(x.cuda(*args, **kwargs) for x in self))


def collate_conformers(batch, *, collate_fn_map=None) -> ConformerDataBatch:
    """
    Prepares batches of conformers.

    Atoms and hydrogens are right-padded with zeros. Distance matrices are copied into the top-left corner
    of an identity matrix, so padding rows keep a single non-zero entry on the diagonal.

    :param batch: iterable of (atoms, hydrogens, distances) triples
    :param collate_fn_map: unused; accepted for compatibility with the torch collate registry signature
    :return: atoms, hydrogens, distances.
    """
    atoms, hydrogens, distances = [], [], []

    for a, h, d in batch:
        atoms.append(a)
        hydrogens.append(h)
        distances.append(d)

    pa = pad_sequence(atoms, True)
    b, s = pa.shape
    tmp = eye(s, dtype=t_int32).repeat(b, 1, 1)  # prevent nan in MHA softmax on padding
    for i, d in enumerate(distances):
        s = d.size(0)
        tmp[i, :s, :s] = d
    return ConformerDataBatch(pa, pad_sequence(hydrogens, True), tmp)


default_collate_fn_map[ConformerDataPoint] = collate_conformers  # add auto_collation to the DataLoader


class ConformerDataset(Dataset):
    def __init__(self, molecules: Sequence[Union[MoleculeContainer, Tuple[ndarray, ndarray, ndarray]]], *,
                 short_cutoff: float = .9, long_cutoff: float = 5., precision: float = .05,
                 add_cls: bool = True, unpack: bool = True, xyz: bool = True, noisy_distance: bool = False):
        """
        convert molecules to tuple of:
            atoms vector with atomic numbers + 2,
            vector with implicit hydrogens count shifted by 2,
            matrix with the discretized Euclidean distances between atoms shifted by 3.

        Note: atoms shifted to differentiate from padding equal to zero, special cls token equal to 1,
            and reserved task specific token equal to 2.
            hydrogens shifted to differentiate from padding equal to zero and reserved task-specific token equal to 1.
            distances shifted to differentiate from padding equal to zero, from special distance equal to 1
            that code unreachable atoms/tokens, and self-attention of atoms equal to 2.

        :param molecules: map-like molecules collection or tuples of atomic numbers array,
            hydrogens array, and coordinates/distances array.
        :param short_cutoff: shortest possible distance between atoms
        :param long_cutoff: radius of visible neighbors sphere
        :param precision: discretized segment size
        :param add_cls: add special token at first position
        :param unpack: unpack coordinates from `chython.MoleculeContainer` (True) or use prepared data (False).
            predefined data structure: (vector of atomic numbers, vector of hydrogens counts,
            matrix of coordinates or distances).
        :param xyz: provided xyz or distance matrix if unpack=False
        :param noisy_distance: add integer noise in [-1, 1] range into discretized distance
            (applied only to codes greater than 2)
        """
        if unpack:
            assert xyz, 'xyz should be True if unpack True'
        assert precision > .01 and short_cutoff > .1 and long_cutoff > 1, 'invalid cutoff and precision'
        assert long_cutoff - short_cutoff > precision, 'precision should be less than cutoff interval'

        self.molecules = molecules
        self.short_cutoff = short_cutoff
        self.long_cutoff = long_cutoff
        self.precision = precision
        self.add_cls = add_cls
        self.unpack = unpack
        self.xyz = xyz
        self.noisy_distance = noisy_distance

        # discrete bins intervals. first 3 bins reserved for shifted coding
        self._bins = arange(short_cutoff - 3 * precision, long_cutoff, precision)
        self._bins[:3] = [-1, 0, .01]  # trick for self-loop coding
        self.max_distance = len(self._bins) - 2  # param for MoleculeEncoder

    def __getitem__(self, item: int) -> ConformerDataPoint:
        """
        Encode a single conformer.

        :param item: index in the molecules collection
        :return: atoms, hydrogens and discretized distances of the conformer
        """
        mol = self.molecules[item]
        if self.unpack:
            if self.add_cls:
                atoms = t_ones(len(mol) + 1, dtype=t_int32)  # cls token = 1
                hydrogens = zeros(len(mol) + 1, dtype=t_int32)  # cls centrality-encoder disabled by padding trick
            else:
                # uninitialized tensors; every position is filled in the loop below
                atoms = IntTensor(len(mol))
                hydrogens = IntTensor(len(mol))

            for i, (n, a) in enumerate(mol.atoms(), self.add_cls):  # atoms start after cls if it is present
                atoms[i] = a.atomic_number + 2
                hydrogens[i] = (a.implicit_hydrogens or 0) + 2

            xyz = empty((len(mol), 3))
            conformer = mol._conformers[0]  # noqa
            for i, n in enumerate(mol):
                xyz[i] = conformer[n]
        else:
            a, hgs, xyz = mol
            if self.add_cls:
                atoms = t_ones(len(a) + 1, dtype=t_int32)
                hydrogens = zeros(len(a) + 1, dtype=t_int32)

                atoms[1:] = IntTensor(a + 2)
                hydrogens[1:] = IntTensor(hgs + 2)
            else:
                atoms = IntTensor(a + 2)
                hydrogens = IntTensor(hgs + 2)

        if self.xyz:
            diff = xyz[None, :, :] - xyz[:, None, :]  # NxNx3
            dist = sqrt(square(diff).sum(axis=-1))  # NxN
        else:
            dist = xyz

        dist = digitize(dist, self._bins)
        if self.noisy_distance:
            # codes 0-2 are reserved, so only codes above 2 are perturbed
            dist += (dist > 2) * self.generator.integers(-1, 1, size=dist.shape, endpoint=True)
        if self.add_cls:  # cls row and column get distance code 1
            tmp = ones((len(atoms), len(atoms)), dtype=int32)
            tmp[1:, 1:] = dist
            dist = tmp
        return ConformerDataPoint(atoms, hydrogens, IntTensor(dist))

    def __len__(self):
        return len(self.molecules)

    def size(self, dim):
        """
        Tensor-like size accessor.

        :param dim: 0 for the dataset length or None for a `torch.Size` with the single dimension
        :raises IndexError: for any other dim
        """
        if dim == 0:
            return len(self.molecules)
        elif dim is None:
            return Size((len(self),))
        raise IndexError

    @cached_property
    def generator(self):
        """Lazily created numpy random generator used for distance noise."""
        return default_rng()


__all__ = ['ConformerDataset', 'ConformerDataPoint', 'ConformerDataBatch', 'collate_conformers']


# ---- chytorch/utils/data/molecule/dummy.py ----
from chython import smiles
from .encoder import MoleculeDataset


def thiacalix_n_arene_dataset(n=4, size=10_000, **kwargs):
    """
    Create a dummy dataset for testing purposes with thiacalix[n]arenes.

    The same parsed molecule object is repeated `size` times.

    :param n: number of macrocycle fragments (at least 3). Each fragment contains 12 atoms.
    :param size: dataset size
    :param kwargs: other params of MoleculeDataset
    :return: MoleculeDataset with `size` references to the thiacalix[n]arene
    """
    assert n >= 3, 'n must be at least 3'
    prefix = 'C12=CC(C(C)(C)C)=CC(=C2O)S'  # first fragment, opens macrocycle closure 1
    postfix = 'C2=CC(C(C)(C)C)=CC(=C2O)S1'  # last fragment, closes macrocycle closure 1
    chain = 'C2=CC(C(C)(C)C)=CC(=C2O)S' * (n - 2)  # inner fragments

    return MoleculeDataset([smiles(prefix + chain + postfix)] * size, **kwargs)


__all__ = ['thiacalix_n_arene_dataset']
from chython import MoleculeContainer
from functools import partial
from numpy import minimum, nan_to_num
from scipy.sparse.csgraph import shortest_path
from torch import IntTensor, Size, int32, ones, zeros, eye, empty, triu, Tensor, cat
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import default_collate_fn_map
from torchtyping import TensorType
from typing import Sequence, Union, NamedTuple, Tuple
from zlib import decompress


class MoleculeDataPoint(NamedTuple):
    """
    Single encoded molecule: atoms vector, neighbors vector and distance matrix.
    """
    atoms: TensorType['atoms', int]
    neighbors: TensorType['atoms', int]
    distances: TensorType['atoms', 'atoms', int]


class MoleculeDataBatch(NamedTuple):
    """
    Padded batch of encoded molecules with device/dtype movement helpers.
    """
    atoms: TensorType['batch', 'atoms', int]
    neighbors: TensorType['batch', 'atoms', int]
    distances: TensorType['batch', 'atoms', 'atoms', int]

    def to(self, *args, **kwargs):
        """Apply `Tensor.to` to every field and return a new batch."""
        moved = [field.to(*args, **kwargs) for field in self]
        return MoleculeDataBatch(*moved)

    def cpu(self, *args, **kwargs):
        """Apply `Tensor.cpu` to every field and return a new batch."""
        moved = [field.cpu(*args, **kwargs) for field in self]
        return MoleculeDataBatch(*moved)

    def cuda(self, *args, **kwargs):
        """Apply `Tensor.cuda` to every field and return a new batch."""
        moved = [field.cuda(*args, **kwargs) for field in self]
        return MoleculeDataBatch(*moved)


def collate_molecules(batch, *, padding_left: bool = False, collate_fn_map=None) -> MoleculeDataBatch:
    """
    Prepares batches of molecules.

    Vectors are zero-padded to the longest molecule of the batch, on the right side by default
    or on the left side if padding_left is set. Distance matrices are placed into the matching corner
    of an identity matrix.

    :param batch: iterable of (atoms, neighbors, distances) triples
    :param padding_left: pad from the left side instead of the right one
    :param collate_fn_map: unused; accepted for compatibility with the torch collate registry signature
    :return: atoms, neighbors, distances.
    """
    atom_seqs, neighbor_seqs, matrices = [], [], []

    for point_atoms, point_neighbors, point_distances in batch:
        # left padding is done as: flip, pad from the right, flip back
        atom_seqs.append(point_atoms.flipud() if padding_left else point_atoms)
        neighbor_seqs.append(point_neighbors.flipud() if padding_left else point_neighbors)
        matrices.append(point_distances)

    padded_atoms = pad_sequence(atom_seqs, True)
    batch_size, width = padded_atoms.shape
    # identity keeps the diagonal of padding rows non-zero: prevent nan in MHA softmax on padding
    padded_distances = eye(width, dtype=int32).repeat(batch_size, 1, 1)
    for idx, matrix in enumerate(matrices):
        n = matrix.size(0)
        if padding_left:
            padded_distances[idx, -n:, -n:] = matrix
        else:
            padded_distances[idx, :n, :n] = matrix

    padded_neighbors = pad_sequence(neighbor_seqs, True)
    if padding_left:
        return MoleculeDataBatch(padded_atoms.fliplr(), padded_neighbors.fliplr(), padded_distances)
    return MoleculeDataBatch(padded_atoms, padded_neighbors, padded_distances)


left_padded_collate_molecules = partial(collate_molecules, padding_left=True)
default_collate_fn_map[MoleculeDataPoint] = collate_molecules  # add auto_collation to the DataLoader


class MoleculeDataset(Dataset):
    def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
                 add_cls: bool = True, cls_token: Union[int, Tuple[int, ...], Sequence[int], Sequence[Tuple[int, ...]],
                 TensorType['cls', int], TensorType['dataset', 1, int], TensorType['dataset', 'cls', int]] = 1,
                 max_distance: int = 10, max_neighbors: int = 14,
                 attention_schema: str = 'bert', components_attention: bool = True,
                 unpack: bool = False, compressed: bool = True, distance_cutoff=None):
        """
        Dataset which encodes molecules into tuples of:
            atoms vector with atomic numbers + 2,
            neighbors vector with connected atoms count including implicit hydrogens, shifted by 2,
            distance matrix with the shortest paths between atoms shifted by 2.

        Note: atoms codes 0, 1 and 2 are reserved for padding, cls token and MLM task token.
            neighbors codes 0 and 1 are reserved for padding and MLM task token.
            distances codes 0 and 1 are reserved for padding and for unreachable atoms (e.g. salts).

        :param molecules: molecules collection
        :param add_cls: add special token at first position
        :param cls_token: idx of cls token (int), or multiple tokens for multi-prompt (tuple, 1d-tensor),
            or individual token per sample (Sequence, column-vector) or
            individual multi-prompt per sample (Sequence of tuples, 2d-tensor).
        :param max_distance: set distances greater than cutoff to cutoff value
        :param max_neighbors: set neighbors count greater than cutoff to cutoff value
        :param attention_schema: attention between CLS and atoms:
            bert - symmetrical without masks;
            causal - masked from atoms to cls, causal between cls multi-prompts (triangle mask) and full to atoms;
            directed - masked from atoms to cls and between cls, but full to atoms (self-attention of cls is kept);
        :param components_attention: enable or disable attention between subgraphs
        :param unpack: unpack molecules
        :param compressed: packed molecules are compressed
        :param distance_cutoff: deprecated alias of max_distance; takes precedence if it is not None
        """
        uses_default_cls = isinstance(cls_token, int) and cls_token == 1
        if not uses_default_cls:
            assert add_cls, 'non-default value of cls_token requires add_cls=True'
        assert attention_schema in ('bert', 'causal', 'directed'), 'Invalid attention schema'

        self.molecules = molecules
        self.add_cls = add_cls
        self.cls_token = cls_token
        # distance_cutoff is deprecated
        self.max_distance = max_distance if distance_cutoff is None else distance_cutoff
        self.max_neighbors = max_neighbors
        self.attention_schema = attention_schema
        self.components_attention = components_attention
        self.unpack = unpack
        self.compressed = compressed

    def _lookup_cls(self, item):
        """
        Resolve the cls setup of the given sample.

        :return: pair of cls tokens count and cls token value (int or tensor) for this sample
        """
        wrong_type = 'cls_token must be int, tuple of ints, sequence of ints or tuples of ints or 1,2-d tensor'

        cls_token = self.cls_token
        if not self.add_cls:
            return 0, cls_token
        if isinstance(cls_token, int):  # single shared token
            return 1, cls_token
        if isinstance(cls_token, tuple):  # shared multi-prompt
            cls_cnt = len(cls_token)
            assert cls_cnt > 1, 'wrong multi-prompt setup'
            return cls_cnt, IntTensor(cls_token)
        if isinstance(cls_token, Sequence):  # individual setup per sample
            ct = cls_token[item]
            if isinstance(ct, int):
                assert isinstance(cls_token[0], int), 'inconsistent cls_token data'
                return 1, ct
            if isinstance(ct, Sequence):
                cls_cnt = len(ct)
                assert cls_cnt > 1, 'wrong multi-prompt setup'
                assert isinstance(cls_token[0], Sequence) and cls_cnt == len(cls_token[0]), 'inconsistent cls_token data'
                return cls_cnt, IntTensor(ct)
            raise TypeError(wrong_type)
        if isinstance(cls_token, Tensor):
            if cls_token.dim() == 1:  # shared multi-prompt
                cls_cnt = cls_token.size(0)
                assert cls_cnt > 1, 'wrong multi-prompt setup'
                return cls_cnt, cls_token
            if cls_token.dim() == 2:  # row per sample
                return cls_token.size(1), cls_token[item]
            raise TypeError(wrong_type)
        raise TypeError(wrong_type)

    def __getitem__(self, item: int) -> MoleculeDataPoint:
        """
        Encode a single molecule.

        :param item: index in the molecules collection
        :return: atoms, neighbors and distances of the molecule, prefixed with cls tokens if enabled
        """
        mol = self.molecules[item]

        # cls setup lookup
        cls_cnt, cls_token = self._lookup_cls(item)

        if self.unpack:
            try:
                from ._unpack import unpack
            except ImportError:  # windows?
                # no compiled extension: restore the molecule and use the generic path below
                mol = MoleculeContainer.unpack(mol, compressed=self.compressed)
            else:
                if self.compressed:
                    mol = decompress(mol)
                # only single cls token supported by cython ext.
                # causal and directed have the same mask for 1 cls token case
                single_cls = cls_cnt == 1
                symmetric = self.attention_schema == 'bert'
                atoms, neighbors, distances, _ = unpack(mol, single_cls, symmetric, self.components_attention,
                                                        self.max_neighbors, self.max_distance)
                atoms = IntTensor(atoms)
                neighbors = IntTensor(neighbors)
                distances = IntTensor(distances)
                if single_cls:
                    # token already pre-allocated
                    if isinstance(cls_token, Tensor) or cls_token != 1:
                        # change default value (1)
                        atoms[0] = cls_token
                elif cls_cnt:  # expand atoms with cls tokens
                    atoms = cat([cls_token, atoms])
                    neighbors = cat([zeros(cls_cnt, dtype=int32), neighbors])
                    distances = self._add_cls_to_distances(distances, cls_cnt)
                return MoleculeDataPoint(atoms, neighbors, distances)

        total = len(mol) + cls_cnt
        atoms = empty(total, dtype=int32)
        neighbors = zeros(total, dtype=int32)  # cls centrality-encoder disabled by padding trick

        limit = self.max_neighbors
        bonds = mol._bonds  # noqa speedup
        for idx, (number, atom) in enumerate(mol.atoms(), cls_cnt):
            atoms[idx] = atom.atomic_number + 2
            # treat bad valence as 0-hydrogen
            count = len(bonds[number]) + (atom.implicit_hydrogens or 0)
            neighbors[idx] = min(count, limit) + 2

        paths = shortest_path(mol.adjacency_matrix(), method='FW', directed=False, unweighted=True) + 2
        # unreachable pairs are infinite; replace them with the components attention code
        nan_to_num(paths, copy=False, posinf=self.components_attention)
        minimum(paths, self.max_distance + 2, out=paths)
        distances = IntTensor(paths)

        if cls_cnt:
            atoms[:cls_cnt] = cls_token
            distances = self._add_cls_to_distances(distances, cls_cnt)
        return MoleculeDataPoint(atoms, neighbors, distances)

    def __len__(self):
        return len(self.molecules)

    def size(self, dim):
        """
        Tensor-like size accessor.

        :param dim: 0 for the dataset length or None for a `torch.Size` with the single dimension
        :raises IndexError: for any other dim
        """
        if dim is None:
            return Size((len(self),))
        if dim == 0:
            return len(self)
        raise IndexError

    def _add_cls_to_distances(self, distances, cls_cnt):
        """
        Extend the atoms distance matrix with rows and columns of cls tokens according to the attention schema.

        :param distances: square atoms distance matrix
        :param cls_cnt: number of cls tokens to prepend
        :return: new square matrix with the atoms block in the bottom-right corner
        """
        schema = self.attention_schema
        side = distances.size(0) + cls_cnt
        if schema == 'bert':  # everything to everything
            extended = ones(side, side, dtype=int32)
        elif schema == 'causal':
            extended = triu(ones(side, side, dtype=int32))
        else:  # CLS to atoms but not back
            extended = eye(side, dtype=int32)  # self attention of cls tokens
            extended[:cls_cnt, cls_cnt:] = 1  # cls to atom attention
        extended[cls_cnt:, cls_cnt:] = distances
        return extended


__all__ = ['MoleculeDataset', 'MoleculeDataPoint', 'MoleculeDataBatch',
           'collate_molecules', 'left_padded_collate_molecules']
from numpy import array
from random import choice
from torch import Size
from torch.utils.data import Dataset
from typing import Sequence, Optional, Dict


def generate(smiles, num_conf=10, max_attempts=100, prune=.2):
    """
    Embed conformers of a molecule with RDKit.

    :param smiles: SMILES string of the molecule
    :param num_conf: numConfs parameter in EmbedMultipleConfs
    :param max_attempts: maxAttempts parameter in EmbedMultipleConfs
    :param prune: pruneRmsThresh parameter in EmbedMultipleConfs
    :return: list of (atomic numbers array, hydrogens array, coordinates) tuples, one per conformer.
        The first two arrays are shared between all tuples.
    """
    from rdkit.Chem import MolFromSmiles, AddHs, RemoveAllHs
    from rdkit.Chem.AllChem import EmbedMultipleConfs

    embedding = dict(numConfs=num_conf, maxAttempts=max_attempts, pruneRmsThresh=prune)

    mol = AddHs(MolFromSmiles(smiles))
    embedded = EmbedMultipleConfs(mol, **embedding)
    if not embedded:
        # try again ignoring chirality
        EmbedMultipleConfs(mol, **embedding, enforceChirality=False)
    mol = RemoveAllHs(mol)

    heavy = list(mol.GetAtoms())
    numbers = array([atom.GetAtomicNum() for atom in heavy])
    # NOTE(review): explicit hydrogens count is used here - confirm it matches what the consumer expects
    hydrogens = array([atom.GetNumExplicitHs() for atom in heavy])
    return [(numbers, hydrogens, conf.GetPositions()) for conf in mol.GetConformers()]


class RDKitConformerDataset(Dataset):
    """
    Random conformer generator dataset
    """
    def __init__(self, molecules: Sequence[str], num_conf=10, max_attempts=100, prune=.2,
                 cache: Optional[Dict[int, Sequence]] = None):
        """
        :param molecules: list of molecules' SMILES strings
        :param num_conf: numConfs parameter in EmbedMultipleConfs
        :param max_attempts: maxAttempts parameter in EmbedMultipleConfs
        :param prune: pruneRmsThresh parameter in EmbedMultipleConfs
        :param cache: cache for generated conformers
        """
        self.molecules = molecules
        self.cache = cache
        self.num_conf, self.max_attempts, self.prune = num_conf, max_attempts, prune

    def __getitem__(self, item: int):
        """
        Return a randomly chosen conformer of the molecule.

        Conformers are taken from the cache when available, otherwise generated and stored into the cache
        (if it is provided).

        :raises ValueError: if no conformers were generated
        """
        cache = self.cache
        cached = cache is not None and item in cache
        if cached:
            return choice(cache[item])

        confs = generate(self.molecules[item], self.num_conf, self.max_attempts, self.prune)
        if not confs:
            raise ValueError('conformer generation failed')

        if cache is not None:
            cache[item] = confs
        return choice(confs)

    def __len__(self):
        return len(self.molecules)

    def size(self, dim):
        """
        Tensor-like size accessor.

        :param dim: 0 for the dataset length or None for a `torch.Size` with the single dimension
        :raises IndexError: for any other dim
        """
        if dim is None:
            return Size((len(self),))
        if dim == 0:
            return len(self.molecules)
        raise IndexError


__all__ = ['RDKitConformerDataset']
from math import floor  # no longer used by ProductDataset; kept so the module namespace stays unchanged
from torch import Size
from torch.utils.data import Dataset
from typing import Sequence, List, TypeVar


element = TypeVar('element')


class ProductDataset(Dataset):
    """
    Lazy product enumeration dataset for combinatorial libraries.

    Items are enumerated in the same order as `itertools.product(*sets)`: the last set changes fastest.
    Nothing is materialized, only per-set divisors and sizes are stored.
    """
    def __init__(self, *sets: Sequence[element]):
        """
        :param sets: sequences to combine; each one has to support `len` and integer indexing
        """
        self.sets = sets

        # calculate lazy product metadata:
        # for every set the divisor is the product of sizes of all sets to the right of it
        divs = []
        mods = []
        factor = 1
        for x in reversed(sets):
            divs.append(factor)
            mods.append(len(x))
            factor *= mods[-1]
        divs.reverse()
        mods.reverse()

        self._divs = divs
        self._mods = mods
        self._size = factor

    def __getitem__(self, item: int) -> List[element]:
        """
        Return the item-th combination.

        :param item: index of the combination; negative values count from the end
        :return: list with one element of every set
        :raises IndexError: if the index is out of range
        """
        if item < 0:
            item += self._size
        if item < 0 or item >= self._size:
            raise IndexError

        # integer floor division keeps the result exact for arbitrarily large products;
        # true division goes through float and loses precision for indices above 2**53
        return [s[int(item // d) % m] for s, d, m in zip(self.sets, self._divs, self._mods)]

    def __len__(self):
        return self._size

    def size(self, dim):
        """
        Tensor-like size accessor.

        :param dim: 0 for the dataset length or None for a `torch.Size` with the single dimension
        :raises IndexError: for any other dim
        """
        if dim == 0:
            return len(self)
        elif dim is None:
            return Size((len(self),))
        raise IndexError


__all__ = ['ProductDataset']
granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the “Software”), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | # 23 | from .encoder import * 24 | 25 | 26 | # reverse compatibility 27 | ReactionDataset = ReactionEncoderDataset 28 | collate_reactions = collate_encoded_reactions 29 | 30 | 31 | __all__ = ['ReactionEncoderDataset', 'ReactionEncoderDataPoint', 'ReactionEncoderDataBatch', 32 | 'collate_encoded_reactions', 33 | # reverse compatibility 34 | 'ReactionDataset', 'collate_reactions'] 35 | -------------------------------------------------------------------------------- /chytorch/utils/data/reaction/encoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright 2021-2024 Ramil Nugmanov 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the “Software”), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is furnished 10 | # to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
class ReactionEncoderDataBatch(NamedTuple):
    # batch-level counterpart of a single encoded reaction: four fields stacked along the batch axis
    atoms: TensorType['batch', 'atoms', int]
    neighbors: TensorType['batch', 'atoms', int]
    distances: TensorType['batch', 'atoms', 'atoms', int]
    roles: TensorType['batch', 'atoms', int]

    def to(self, *args, **kwargs):
        """
        Apply ``Tensor.to`` with the given arguments to every field and wrap the results in a new batch.
        """
        moved = [field.to(*args, **kwargs) for field in self]
        return ReactionEncoderDataBatch(*moved)

    def cpu(self, *args, **kwargs):
        """
        Apply ``Tensor.cpu`` with the given arguments to every field and wrap the results in a new batch.
        """
        moved = [field.cpu(*args, **kwargs) for field in self]
        return ReactionEncoderDataBatch(*moved)

    def cuda(self, *args, **kwargs):
        """
        Apply ``Tensor.cuda`` with the given arguments to every field and wrap the results in a new batch.
        """
        moved = [field.cuda(*args, **kwargs) for field in self]
        return ReactionEncoderDataBatch(*moved)
class ReactionEncoderDataset(Dataset):
    def __init__(self, reactions: Sequence[Union[ReactionContainer, bytes]], *, max_distance: int = 10,
                 max_neighbors: int = 14, add_cls: bool = True, add_molecule_cls: bool = True,
                 hide_molecule_cls: bool = True, unpack: bool = False, distance_cutoff=None, compressed: bool = True):
        """
        Encode reactions as (atoms, neighbors, distances, roles) tensors.

        atoms, neighbors and distances are built the same way as in the molecule dataset.
        distances - per-molecule distance matrices merged into one block-diagonal matrix;
        cells between different molecules stay zero to isolate attention.
        roles: 2 reactants, 3 products, 0 padding, 1 cls token.

        :param reactions: reactions collection
        :param max_distance: set distances greater than cutoff to cutoff value
        :param max_neighbors: set neighbors count greater than cutoff to cutoff value
        :param add_cls: add special token at first position
        :param add_molecule_cls: add special token at first position of each molecule
        :param hide_molecule_cls: disable molecule cls in reaction lvl (mark as padding)
        :param unpack: unpack reactions
        :param distance_cutoff: deprecated alias of max_distance; takes precedence when given
        :param compressed: packed reactions are compressed
        """
        # hiding molecule cls tokens makes sense only when these tokens exist
        assert add_molecule_cls or not hide_molecule_cls, 'add_molecule_cls should be True if hide_molecule_cls is True'
        self.reactions = reactions
        if distance_cutoff is not None:  # deprecated argument wins over max_distance
            self.max_distance = distance_cutoff
        else:
            self.max_distance = max_distance
        self.add_cls = add_cls
        self.add_molecule_cls = add_molecule_cls
        self.hide_molecule_cls = hide_molecule_cls
        self.max_neighbors = max_neighbors
        self.unpack = unpack
        self.compressed = compressed

    def __getitem__(self, item: int) -> ReactionEncoderDataPoint:
        """
        Encode a single reaction.

        :param item: index of the reaction
        :return: atoms, neighbors, merged distances and atom roles of the reaction
        """
        reaction = self.reactions[item]
        if self.unpack:
            reaction = ReactionContainer.unpack(reaction, compressed=self.compressed)
        encoder = MoleculeDataset(reaction.reactants + reaction.products, max_distance=self.max_distance,
                                  max_neighbors=self.max_neighbors, add_cls=self.add_molecule_cls)

        atoms, neighbors, roles, blocks = [], [], [], []
        if self.add_cls:
            # reaction cls token: zero atom/neighbors codes (disabled in molecules encoder), role 1
            atoms.append(IntTensor([0]))
            neighbors.append(IntTensor([0]))
            roles.append(1)

        # molecules are indexed in the encoder in the same order: reactants first, then products
        tagged = [(mol, 2) for mol in reaction.reactants]
        tagged += [(mol, 3) for mol in reaction.products]
        for index, (mol, role) in enumerate(tagged):
            mol_atoms, mol_neighbors, mol_distances = encoder[index]
            atoms.append(mol_atoms)
            neighbors.append(mol_neighbors)
            blocks.append(mol_distances)
            if self.add_molecule_cls:
                # molecule cls token is either hidden as padding (0) or gets the role of its molecule
                roles.append(0 if self.hide_molecule_cls else role)
            roles.extend([role] * len(mol))

        total = len(roles)
        merged = zeros(total, total, dtype=int32)
        offset = 0
        if self.add_cls:
            merged[0, 0] = 1  # prevent nan in MHA softmax.
            offset = 1
        # place molecular matrices one after another on the main diagonal
        for block in blocks:
            end = offset + block.size(0)
            merged[offset:end, offset:end] = block
            offset = end
        return ReactionEncoderDataPoint(cat(atoms), cat(neighbors), merged, IntTensor(roles))

    def __len__(self):
        """
        Number of reactions in the dataset.
        """
        return len(self.reactions)

    def size(self, dim):
        """
        Tensor-like size accessor: int for ``dim == 0``, ``torch.Size`` for ``dim is None``.

        :raises IndexError: for any other dimension.
        """
        if dim == 0:
            return len(self)
        if dim is None:
            return Size((len(self),))
        raise IndexError
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # 23 | import torch.distributed as dist 24 | from chython import MoleculeContainer 25 | from collections import defaultdict 26 | from itertools import chain, islice 27 | from math import ceil 28 | from torch import Generator, randperm 29 | from torch.utils.data import Sampler 30 | from typing import Optional, List, Iterator 31 | from .molecule import MoleculeDataset 32 | 33 | 34 | def _build_index(dataset, sizes): 35 | if not isinstance(dataset, MoleculeDataset): 36 | raise TypeError('Unsupported Dataset') 37 | if sizes is not None: 38 | return sizes 39 | elif dataset.unpack: 40 | return [MoleculeContainer.pack_len(m) for m in dataset.molecules] 41 | else: 42 | return [len(m) for m in dataset.molecules] 43 | 44 | 45 | def _indices(order, sizes, batch_size): 46 | groups = defaultdict(list) 47 | for n in order: 48 | groups[sizes[n]].append(n) 49 | 50 | iterable_groups = {k: iter(v) for k, v in groups.items()} 51 | sorted_groups = sorted(groups) 52 | chained_groups = {} 53 | for k in groups: 54 | chained_groups[k] = chain(*(iterable_groups[x] for x in sorted_groups[sorted_groups.index(k)::-1]), 55 | *(iterable_groups[x] for x in sorted_groups[sorted_groups.index(k) + 1:])) 56 | 57 | indices = [] 58 | seen = set() 59 | for n in order: 60 | if n not in seen: 61 | for x in islice(chained_groups[sizes[n]], batch_size): 62 | indices.append(x) 63 | if x != n: 64 | seen.add(x) 65 | if len(indices) == len(sizes): 66 | break 67 | else: 68 | seen.discard(n) 69 | return indices 70 | 71 | 72 | class StructureSampler(Sampler[List[int]]): 73 | def __init__(self, dataset: MoleculeDataset, batch_size: int, shuffle: bool = True, seed: int = 0, *, 74 | sizes: Optional[List[int]] = None): 75 | """ 76 
class DistributedStructureSampler(Sampler[List[int]]):
    def __init__(self, dataset: MoleculeDataset, batch_size: int, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True, seed: int = 0, *,
                 sizes: Optional[List[int]] = None):
        """
        Batch sampler for distributed runs: molecules are locally grouped by size
        to reduce idle calculations on paddings, then split between replicas.

        :param dataset: molecules dataset
        :param batch_size: expected batch size
        :param sizes: precalculated sizes of molecules.
        :param num_replicas, rank, shuffle, seed: see torch.utils.data.DistributedSampler for details.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        self.sizes = _build_index(dataset, sizes)

        # adapted from torch/utils/data/distributed.py
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError('Requires distributed package to be available')
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError('Requires distributed package to be available')
            rank = dist.get_rank()
        if not 0 <= rank < num_replicas:
            raise ValueError(f'Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]')

        self.num_replicas = num_replicas
        self.rank = rank
        # every replica gets the same number of samples; the total is rounded up to a multiple of replicas
        self.num_samples = ceil(len(self.sizes) / num_replicas)
        self.total_size = self.num_samples * num_replicas

    def __len__(self) -> int:
        """
        Number of batches produced for this replica.
        """
        return ceil(self.num_samples / self.batch_size)

    def __iter__(self) -> Iterator[List[int]]:
        """
        Yield lists of dataset indices assigned to this replica, at most `batch_size` indices each.
        """
        count = len(self.sizes)
        if self.shuffle:
            # seed depends on the epoch, so replicas agree on one permutation which changes between epochs
            rng = Generator()
            rng.manual_seed(self.seed + self.epoch)
            order = randperm(count, generator=rng).tolist()
        else:
            order = range(count)
        grouped = _indices(order, self.sizes, self.batch_size)

        # adapted from torch/utils/data/distributed.py
        # repeat indices from the head to make the list evenly divisible between replicas
        shortfall = self.total_size - len(grouped)
        if shortfall <= len(grouped):
            grouped += grouped[:shortfall]
        else:
            repeats = ceil(shortfall / len(grouped))
            grouped += (grouped * repeats)[:shortfall]
        assert len(grouped) == self.total_size

        # subsample: this replica takes every num_replicas-th index starting from its rank
        shard = grouped[self.rank:self.total_size:self.num_replicas]
        assert len(shard) == self.num_samples

        stream = iter(shard)
        batch = list(islice(stream, self.batch_size))
        while batch:
            yield batch
            batch = list(islice(stream, self.batch_size))

    def set_epoch(self, epoch: int):
        """
        Set the epoch number which is added to the seed of the shuffling generator.
        """
        self.epoch = epoch
class SMILESDataset(Dataset):
    def __init__(self, data: Sequence[str], *, canonicalize: bool = False, cache: Optional[Dict[int, bytes]] = None,
                 dtype: Union[Type[MoleculeContainer], Type[ReactionContainer]] = MoleculeContainer,
                 unpack: bool = True, ignore_stereo: bool = True, ignore_bad_isotopes: bool = False,
                 keep_implicit: bool = False, ignore_carbon_radicals: bool = False):
        """
        Dataset which parses SMILES into chython containers on-the-fly.
        Note: SMILES strings or coded structures can be invalid and lead to exception raising.
        Make sure you have validated input.

        :param data: smiles dataset
        :param canonicalize: do standardization (slow, better to prepare data in advance and keep in kekule form)
        :param cache: dict-like object for caching processed data. caching disabled by default.
        :param dtype: expected type of smiles (reaction or molecule)
        :param unpack: return unpacked structure or chython pack
        :param ignore_stereo: Ignore stereo data.
        :param ignore_bad_isotopes: reset invalid isotope mark to non-isotopic.
        :param keep_implicit: keep given in smiles implicit hydrogen count, otherwise ignore on valence error.
        :param ignore_carbon_radicals: fill carbon radicals with hydrogen (X[C](X)X case).
        """
        self.data = data
        self.canonicalize = canonicalize
        self.cache = cache
        self.dtype = dtype
        self.unpack = unpack
        self.ignore_stereo = ignore_stereo
        self.ignore_bad_isotopes = ignore_bad_isotopes
        self.keep_implicit = keep_implicit
        self.ignore_carbon_radicals = ignore_carbon_radicals

    def __getitem__(self, item: int) -> Union[MoleculeContainer, ReactionContainer, bytes]:
        """
        Parse (or take from the cache) a single structure.

        :param item: index of the SMILES string
        :return: container if `unpack` is set, packed bytes otherwise
        :raises TypeError: if the parsed structure is not an instance of `dtype`
        """
        cache = self.cache
        if cache is not None and item in cache:
            # cache keeps packed structures only
            packed = cache[item]
            return self.dtype.unpack(packed) if self.unpack else packed

        structure = smiles(self.data[item], ignore_stereo=self.ignore_stereo,
                           ignore_bad_isotopes=self.ignore_bad_isotopes, keep_implicit=self.keep_implicit,
                           ignore_carbon_radicals=self.ignore_carbon_radicals)
        if not isinstance(structure, self.dtype):
            raise TypeError(f'invalid SMILES: {self.dtype} expected, but {type(structure)} given')
        if self.canonicalize:
            structure.canonicalize()

        if cache is None:
            # without cache packing is required only when packed output is requested
            return structure if self.unpack else structure.pack()

        packed = structure.pack()
        cache[item] = packed
        return structure if self.unpack else packed

    def __len__(self):
        """
        Number of SMILES strings in the dataset.
        """
        return len(self.data)

    def size(self, dim):
        """
        Tensor-like size accessor: int for ``dim == 0``, ``torch.Size`` for ``dim is None``.

        :raises IndexError: for any other dimension.
        """
        if dim == 0:
            return len(self)
        if dim is None:
            return Size((len(self),))
        raise IndexError
class StructUnpack(Dataset):
    def __init__(self, data: List[bytes], format_spec: str, dtype=float32, shape: Tuple[int, ...] = None):
        """
        Dataset which converts python.struct packed records to tensors.
        Useful in case of highly compressed data.

        :param data: packed data
        :param format_spec: python.struct format for unpacking data
            (e.g. '>bbl' - 2 one-byte ints and 1 big-endian 4 byte int)
        :param dtype: output tensor dtype
        :param shape: reshape unpacked 1-D tensor
        """
        self.data = data
        self.format_spec = format_spec
        self.dtype = dtype
        self.shape = shape
        self._struct = Struct(format_spec)  # format is compiled once and reused for every record

    def __getitem__(self, item: int) -> Tensor:
        """
        :param item: index of the record
        :return: unpacked values as a 1-D tensor, reshaped to `shape` when it is given
        """
        values = self._struct.unpack(self.data[item])
        flat = tensor(values, dtype=self.dtype)
        if self.shape is None:
            return flat
        return flat.reshape(self.shape)

    def __len__(self):
        """
        Number of packed records.
        """
        return len(self.data)

    def size(self, dim):
        """
        Tensor-like size accessor: int for ``dim == 0``, ``torch.Size`` for ``dim is None``.

        :raises IndexError: for any other dimension.
        """
        if dim == 0:
            return len(self)
        if dim is None:
            return Size((len(self),))
        raise IndexError
class Decompress(Dataset):
    def __init__(self, data: List[bytes], method: str = 'zlib', zdict: bytes = None):
        """
        Dataset which decompresses zipped records on access.

        :param data: compressed data
        :param method: zlib or zstd
        :param zdict: zstd decompression dictionary
        """
        assert method in ('zlib', 'zstd')
        self.data = data
        self.method = method
        self.zdict = zdict

    def __getitem__(self, item: int) -> bytes:
        """
        :param item: index of the record
        :return: decompressed bytes
        """
        compressed = self.data[item]
        return self.decompress(compressed)

    @cached_property
    def decompress(self):
        """
        Decompression function for the selected method; resolved on first access and then reused.
        """
        if self.method != 'zlib':
            # zstd: optional dependency, imported only when this method is actually used
            from pyzstd import decompress as zstd_decompress, ZstdDict

            if self.zdict is None:
                return zstd_decompress
            return partial(zstd_decompress, zstd_dict=ZstdDict(self.zdict))
        return decompress

    def __len__(self):
        """
        Number of compressed records.
        """
        return len(self.data)

    def size(self, dim):
        """
        Tensor-like size accessor: int for ``dim == 0``, ``torch.Size`` for ``dim is None``.

        :raises IndexError: for any other dimension.
        """
        if dim == 0:
            return len(self)
        if dim is None:
            return Size((len(self),))
        raise IndexError
import click
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
from chython import smiles
from chython.exceptions import InvalidAromaticRing
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from typing import Optional, Union

from chytorch.nn import MoleculeEncoder, Slicer
from chytorch.utils.data import MoleculeDataset, chained_collate, collate_molecules

torch.manual_seed(1)

# check GPU
# NOTE: informational only, nothing below is moved to this device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")


class PandasData(pl.LightningDataModule):
    """
    Data module which reads a CSV table with SMILES, a target property and a split label,
    and serves the train/validation/test subsets as data loaders.
    """

    def __init__(
        self,
        csv: str,
        structure: str,
        property: str,
        dataset_type: str,
        prepared_df_path: str,
        batch_size: int = 32,
    ):
        """
        :param csv: path to the input CSV file
        :param structure: name of the column with SMILES
        :param property: name of the column with the target value
        :param dataset_type: name of the column with split labels ("train", "validation", "test")
        :param prepared_df_path: path of the intermediate pickle with parsed molecules
        :param batch_size: batch size of all data loaders
        """
        super().__init__()
        self.train_x = None
        self.train_y = None
        self.test_x = None
        self.test_y = None
        self.validation_x = None
        self.validation_y = None
        self.prepared_df_path = prepared_df_path
        self.csv = csv
        self.structure = structure
        self.property = property
        self.dataset_type = dataset_type
        self.batch_size = batch_size

    @staticmethod
    def prepare_mol(mol_smi):
        """
        Parse SMILES and convert the molecule to kekule form.

        :return: parsed molecule or None if parsing or kekulization failed
        """
        try:
            mol = smiles(mol_smi)
            mol.kekule()
        except InvalidAromaticRing:
            # structure is parsed, but its aromatic rings can't be kekulized
            return None
        except Exception:
            # best effort: any unparsable record is dropped later in prepare_data
            return None
        return mol

    def prepare_data(self):
        """
        Parse structures from the CSV file, drop invalid records and store the table as a pickle.
        """
        df = pd.read_csv(self.csv)
        # explicit copy: the column assignment below must not write into a view of the source frame
        df = df[[self.structure, self.property, self.dataset_type]].copy()
        df[self.structure] = df[self.structure].apply(self.prepare_mol)
        df.dropna(inplace=True)
        df.to_pickle(self.prepared_df_path)

    def _subset(self, df, label: str):
        """
        Build molecules dataset and target tensor from the rows marked with the given split label.
        """
        part = df[df[self.dataset_type] == label]
        mols = part[self.structure].to_list()
        return MoleculeDataset(mols), torch.Tensor(part[self.property].to_numpy())

    def setup(self, stage: Optional[str] = None):
        """
        Load prepared data and build datasets for the requested stage.

        :param stage: "fit", "validation", "test" or None for all of them
        """
        df = pd.read_pickle(self.prepared_df_path)
        # the split column is configurable, so rows are selected by self.dataset_type, not by a fixed name
        if stage == "fit" or stage is None:
            self.train_x, self.train_y = self._subset(df, "train")

        if stage == "validation" or stage is None:
            self.validation_x, self.validation_y = self._subset(df, "validation")

        if stage == "test" or stage is None:
            self.test_x, self.test_y = self._subset(df, "test")

    def train_dataloader(self):
        """
        Shuffled loader of the training subset.
        """
        return DataLoader(
            dataset=TensorDataset(self.train_x, self.train_y),
            collate_fn=chained_collate(collate_molecules, torch.stack),
            batch_size=self.batch_size,
            shuffle=True,
        )

    def validation_dataloader(self):
        """
        Loader of the validation subset.
        """
        # validation must be done on the held-out subset, not on the training data;
        # shuffling is not needed for averaged metrics
        return DataLoader(
            dataset=TensorDataset(self.validation_x, self.validation_y),
            collate_fn=chained_collate(collate_molecules, torch.stack),
            batch_size=self.batch_size,
        )

    def test_dataloader(self):
        """
        Loader of the test subset.
        """
        return DataLoader(
            dataset=TensorDataset(self.test_x, self.test_y),
            collate_fn=chained_collate(collate_molecules, torch.stack),
            batch_size=self.batch_size,
        )


class Modeler:
    """
    Minimal training driver: builds a binary classifier on top of MoleculeEncoder and runs train/validation loops.
    """

    def __init__(
        self,
        loss_function,
        epochs: int,
        learning_rate: Union[float, int],
        model_path: Optional[str] = None,
    ):
        """
        :param loss_function: callable taking (predictions, targets) and returning a loss tensor
        :param epochs: number of training epochs
        :param learning_rate: learning rate of the Adam optimizer
        :param model_path: where to store the trained weights; nothing is saved if empty
        """
        self.network = None
        self.optimizer = None
        self.loss_function = loss_function
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.model_path = model_path

    def train_loop(self, dataset_loader: DataLoader):
        """
        Run one training epoch and print the loss every 10 batches.
        """
        size = len(dataset_loader.dataset)
        # validation_loop switches the network to eval mode; switch back, otherwise dropout
        # stays disabled in all epochs after the first one
        self.network.train()
        for batch, (X, y) in enumerate(dataset_loader):
            # compute prediction and loss
            predictions = self.network(X)
            loss = self.loss_function(predictions.squeeze(-1), y)
            # backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if batch % 10 == 0:
                loss_v, current = loss.item(), batch * len(X[0])
                print(f"loss: {loss_v:>3f} [{current:>5d}/{size:>5d}]")

    def validation_loop(self, dataset_loader: DataLoader):
        """
        Evaluate the network and print accuracy and average loss.
        """
        size = len(dataset_loader.dataset)
        num_batches = len(dataset_loader)
        self.network.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataset_loader:
                predictions = self.network(X).squeeze(-1)
                test_loss += self.loss_function(predictions, y).item()
                # the network has a single sigmoid output: the class is defined by the 0.5 threshold
                # (argmax over one column is always 0 and can't be used as a predicted label)
                correct += ((predictions > 0.5) == (y > 0.5)).type(torch.float).sum().item()
        test_loss = test_loss / num_batches
        correct = correct / size
        print(
            f"Validation Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
        )

    def save(self):
        """
        Store the network weights in model_path.
        """
        torch.save(self.network.state_dict(), self.model_path)

    def fit(self, dataset):
        """
        Run model training

        :param dataset: object providing train_dataloader() and validation_dataloader()
        """
        self.network = nn.Sequential(
            MoleculeEncoder(),
            Slicer(slice(None), 0),
            nn.Linear(in_features=1024, out_features=512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(in_features=512, out_features=1),
            nn.Sigmoid(),
        )
        self.optimizer = Adam(self.network.parameters(), lr=self.learning_rate)
        for epoch in range(self.epochs):
            print(f"Epoch {epoch + 1}\n------")
            self.train_loop(dataset.train_dataloader())
            self.validation_loop(dataset.validation_dataloader())
        if self.model_path:
            self.save()


@click.command()
@click.option(
    "-d", "--path_to_csv", type=click.Path(), help="Path to csv file with data."
)
@click.option(
    "-i",
    "--path_to_interm_dataset",
    type=click.Path(),
    help="Path to pickle with intermediate data.",
)
@click.option("-m", "--path_to_model", type=click.Path(), help="Path to model.pt.")
def train(path_to_csv, path_to_interm_dataset, path_to_model):
    """
    Prepare data from the CSV file, train the reference model and optionally save its weights.
    """
    dataset = PandasData(
        csv=path_to_csv,
        structure="std_smiles",
        property="activity",
        dataset_type="dataset",
        prepared_df_path=path_to_interm_dataset,
        batch_size=10,
    )
    dataset.prepare_data()
    dataset.setup()
    modeler = Modeler(
        loss_function=nn.BCELoss(),
        epochs=3,
        learning_rate=2e-5,
        model_path=path_to_model,
    )
    modeler.fit(dataset)


if __name__ == "__main__":
    train()
'wheel'}, 27 | {path = 'chytorch/utils/data/molecule/*.so', format = 'wheel'} 28 | ] 29 | 30 | [tool.poetry.dependencies] 31 | python = '>=3.8,<3.12' 32 | torchtyping = '^0.1.4' 33 | chython = '^1.70' 34 | scipy = '^1.10' 35 | torch = '>=1.8' 36 | lmdb = {version='^1.4.1', optional = true} 37 | psycopg2-binary = {version='^2.9', optional = true} 38 | rdkit = {version = '^2023.9.1', optional = true} 39 | pyzstd = {version = '^0.15.9', optional = true} 40 | 41 | [tool.poetry.extras] 42 | lmdb = ['lmdb'] 43 | postgres = ['psycopg2-binary'] 44 | rdkit = ['rdkit'] 45 | zstd = ['pyzstd'] 46 | 47 | [build-system] 48 | requires = ['poetry-core', 'setuptools', 'cython>=3.0.5', 'numpy>=1.23.3'] 49 | build-backend = 'poetry.core.masonry.api' 50 | 51 | [tool.poetry.build] 52 | script = 'build.py' 53 | generate-setup-file = false 54 | --------------------------------------------------------------------------------