├── .github
│   └── workflows
│       └── python-package.yml
├── .gitignore
├── LICENSE
├── README.md
├── build.py
├── chytorch
│   ├── nn
│   │   ├── __init__.py
│   │   ├── lora
│   │   │   ├── __init__.py
│   │   │   ├── embedding.py
│   │   │   └── linear.py
│   │   ├── losses.py
│   │   ├── molecule
│   │   │   ├── __init__.py
│   │   │   ├── _embedding.py
│   │   │   └── encoder.py
│   │   ├── reaction.py
│   │   ├── slicer.py
│   │   ├── transformer
│   │   │   ├── __init__.py
│   │   │   ├── attention
│   │   │   │   ├── __init__.py
│   │   │   │   └── graphormer.py
│   │   │   └── encoder.py
│   │   └── voting
│   │       ├── __init__.py
│   │       ├── _kfold.py
│   │       ├── binary.py
│   │       ├── classifier.py
│   │       └── regressor.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── data
│   │       ├── __init__.py
│   │       ├── _utils.py
│   │       ├── lmdb.py
│   │       ├── molecule
│   │       │   ├── __init__.py
│   │       │   ├── _unpack.pyx
│   │       │   ├── conformer.py
│   │       │   ├── dummy.py
│   │       │   ├── encoder.py
│   │       │   └── rdkit.py
│   │       ├── product.py
│   │       ├── reaction
│   │       │   ├── __init__.py
│   │       │   └── encoder.py
│   │       ├── sampler.py
│   │       ├── smiles.py
│   │       └── unpack.py
│   └── zoo
│       └── README.md
├── examples
│   └── reference_start.py
└── pyproject.toml
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | name: Build Python packages
2 |
3 | on:
4 | release:
5 | types: [published]
6 | workflow_dispatch:
7 |
8 | jobs:
9 | binary:
10 | runs-on: ${{ matrix.os }}
11 | strategy:
12 | matrix:
13 | os: [windows-latest, macos-latest, ubuntu-20.04]
14 | python-version: ["3.8", "3.9", "3.10", "3.11"]
15 | steps:
16 | - uses: actions/checkout@v3
17 | - name: Set up Python ${{ matrix.python-version }}
18 | uses: actions/setup-python@v3
19 | with:
20 | python-version: ${{ matrix.python-version }}
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip poetry twine
24 | - name: Build wheel
25 | run: |
26 | poetry build -f wheel
27 | - name: Publish package
28 | run: |
29 | twine upload -u __token__ -p ${{ secrets.PYPI_API_TOKEN }} --non-interactive --skip-existing dist/*
30 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .dvc/config.local
2 | .dvc/cache
3 | .dvc/plots
4 | .dvc/tmp
5 |
6 | __pycache__
7 | *.py[cod]
8 | *.pt
9 | *.so
10 | *.dll
11 | *.dylib
12 | *.c
13 |
14 | .idea
15 | venv
16 | zoo
17 |
18 | *.egg-info
19 | build
20 | dist
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | Chytorch [kʌɪtɔːrtʃ]
 2 | ====================
3 |
4 | Library for modeling molecules and reactions in torch way.
5 |
6 | Installation
7 | ------------
8 |
 9 | Use `pip install chytorch` to install the release version.
10 |
11 | Or run `pip install .` in the source code directory to install the development version.
12 |
13 | Pretrained models
14 | -----------------
15 |
16 | The main chytorch package doesn't include a model zoo.
17 | Each model has its own named package and can be installed separately.
18 | Installed models can be imported as `from chytorch.zoo.<model> import Model`.
19 |
20 |
21 | Usage
22 | -----
23 |
24 | `chytorch.nn.MoleculeEncoder` - the core Graphormer layer for molecule encoding.
25 | Its API combines `torch.nn.TransformerEncoderLayer` with `torch.nn.TransformerEncoder`.
26 |
27 | **Batch preparation:**
28 |
29 | `chytorch.utils.data.MoleculeDataset` - map-style on-the-fly dataset generator for molecules.
30 | Supports `chython.MoleculeContainer` objects and packed (PaCh) structures.
31 |
32 | `chytorch.utils.data.collate_molecules` - collate function for `torch.utils.data.DataLoader`.
33 |
34 | Note: the torch DataLoader performs proper collation automatically since the 1.13 release.
35 |
36 | Example:
37 |
38 | from chytorch.utils.data import MoleculeDataset, SMILESDataset
39 | from torch.utils.data import DataLoader
40 |
41 | data = ['CCO', 'CC=O']
42 | ds = MoleculeDataset(SMILESDataset(data, cache={}))
43 | dl = DataLoader(ds, batch_size=10)
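
For torch releases older than 1.13, the collate function has to be passed explicitly (a minimal sketch reusing `ds` from above):

    from chytorch.utils.data import collate_molecules

    dl = DataLoader(ds, batch_size=10, collate_fn=collate_molecules)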
44 |
45 | **Forward call:**
46 |
47 | Molecules are coded as tensors of (a worked illustration follows the list):
48 | * atomic numbers shifted by 2 (e.g. hydrogen = 3).
49 |   0 - reserved for padding, 1 - reserved for the CLS token, 2 - extra reservation.
50 | * neighbor counts, including implicit hydrogens, shifted by 2 (e.g. CO = CH3OH = [6, 4]).
51 |   0 - reserved for padding, 1 - extra reservation, 2 - no neighbors, 3 - one neighbor.
52 | * topological distance matrix shifted by 2, with an upper limit.
53 |   0 - reserved for padding, 1 - reserved for coding not-connected graph components, 2 - self-loop, 3 - connected atoms.
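
For example, ethanol (`CCO`, CLS token omitted) would be coded roughly as below - a hand-worked illustration of the scheme above, not the exact collator output:

    atoms:     [8, 8, 10]      # C = 6+2, C = 6+2, O = 8+2
    neighbors: [6, 6, 4]       # CH3 = 4+2, CH2 = 4+2, OH = 2+2
    distances: [[2, 3, 4],     # self-loop = 2, bonded = 3, two bonds apart = 4
                [3, 2, 3],
                [4, 3, 2]]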
54 |
55 | from chytorch.nn import MoleculeEncoder
56 |
57 | encoder = MoleculeEncoder()
58 | for b in dl:
59 | encoder(b)
60 |
61 | **Combine molecules and labels:**
62 |
63 | `chytorch.utils.data.chained_collate` - helper for combining different data parts. Useful for tricky input.
64 |
65 | from torch import stack
66 | from torch.utils.data import DataLoader, TensorDataset
67 | from chytorch.utils.data import chained_collate, collate_molecules, MoleculeDataset
68 |
69 | dl = DataLoader(TensorDataset(MoleculeDataset(molecules_list), properties_tensor),
70 | collate_fn=chained_collate(collate_molecules, stack))
71 |
72 | **Voting NN with single hidden layer:**
73 |
74 | `chytorch.nn.VotingClassifier`, `chytorch.nn.BinaryVotingClassifier` and `chytorch.nn.VotingRegressor` - speed-optimized multiple heads for ensemble predictions.
75 |
76 | **Helper Modules:**
77 |
78 | `chytorch.nn.Slicer` - performs tensor slicing. Useful for extracting the transformer's CLS token in `torch.nn.Sequential`.
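
A minimal usage sketch (based on the `Slicer` docstring, where `Slicer(slice(None), 0)` is equivalent to `tensor[:, 0]`):

    from torch.nn import Sequential
    from chytorch.nn import MoleculeEncoder, Slicer

    # encode a batch and keep only the CLS token embedding
    model = Sequential(MoleculeEncoder(), Slicer(slice(None), 0))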
79 |
80 | **Data Wrappers:**
81 |
82 | The `chytorch.utils.data` module provides data wrappers that simplify ML workflows.
83 | All wrappers have the `torch.utils.data.Dataset` interface.
84 |
85 | * `SizedList` - list wrapper with a `size()` method. Useful with `torch.utils.data.TensorDataset`.
86 | * `SMILESDataset` - on-the-fly SMILES parser producing `chython.MoleculeContainer` or `chython.ReactionContainer` objects.
87 | * `LMDBMapper` - maps an LMDB key-value storage to a dataset.
88 | * `TensorUnpack`, `StructUnpack`, `PickleUnpack` - bytes-to-tensor/object unpackers.
89 |
90 |
91 | Publications
92 | ------------
93 |
94 | [1](https://doi.org/10.1021/acs.jcim.2c00344) Bidirectional Graphormer for Reactivity Understanding: Neural Network Trained to Reaction Atom-to-Atom Mapping Task
95 |
--------------------------------------------------------------------------------
/build.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from Cython.Build import build_ext, cythonize
24 | from numpy import get_include
25 | from pathlib import Path
26 | from setuptools import Extension
27 | from setuptools.dist import Distribution
28 | from shutil import copyfile
29 | from sysconfig import get_platform
30 |
31 |
32 | platform = get_platform()
33 | if platform == 'win-amd64':
34 | extra_compile_args = ['/O2']
35 | elif platform == 'linux-x86_64':
36 | extra_compile_args = ['-O3']
37 | else:
38 | extra_compile_args = []
39 |
40 | extensions = [
41 | Extension('chytorch.utils.data.molecule._unpack',
42 | ['chytorch/utils/data/molecule/_unpack.pyx'],
43 | extra_compile_args=extra_compile_args,
44 | include_dirs=[get_include()]),
45 | ]
46 |
47 | ext_modules = cythonize(extensions, language_level=3)
48 | cmd = build_ext(Distribution({'ext_modules': ext_modules}))
49 | cmd.ensure_finalized()
50 | cmd.run()
51 |
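# copy compiled extensions from the build directory back into the source tree (in-place build)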
52 | for output in cmd.get_outputs():
53 | output = Path(output)
54 | copyfile(output, output.relative_to(cmd.build_lib))
55 |
--------------------------------------------------------------------------------
/chytorch/nn/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2021-2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from .losses import *
24 | from .molecule import *
25 | from .reaction import *
26 | from .slicer import *
27 | from .voting import *
28 |
29 |
30 | __all__ = ['MoleculeEncoder',
31 | 'ReactionEncoder',
32 | 'Slicer',
33 | 'VotingClassifier', 'VotingRegressor', 'BinaryVotingClassifier',
34 | 'MultiTaskLoss',
35 | 'CensoredLoss',
36 | 'MaskedNaNLoss',
37 | 'MSLELoss']
38 |
--------------------------------------------------------------------------------
/chytorch/nn/lora/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from .embedding import *
24 | from .linear import *
25 |
26 |
27 | __all__ = ['Embedding', 'Linear']
28 |
--------------------------------------------------------------------------------
/chytorch/nn/lora/embedding.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from math import inf
24 | from torch import empty, no_grad, addmm, Tensor
25 | from torch.nn import Embedding as tEmbedding, Parameter, init
26 | from torch.nn.functional import embedding
27 | from typing import Optional
28 |
29 |
30 | class Embedding(tEmbedding):
31 | """
32 | LoRA wrapped Embedding layer.
33 | """
34 | def __init__(self, *args, neg_inf_idx: Optional[int] = None, **kwargs):
35 | """
36 | :param neg_inf_idx: index of a frozen embedding vector filled with -inf
37 |
38 | See torch.nn.Embedding for other params
39 | """
40 | super().__init__(*args, **kwargs)
41 | self.neg_inf_idx = neg_inf_idx
42 | self.lora_r = 0
43 | if neg_inf_idx is not None:
44 | with no_grad():
45 | self.weight[neg_inf_idx].fill_(-inf)
46 |
47 | def forward(self, x: Tensor) -> Tensor:
48 | emb = super().forward(x)
49 | if self.lora_r:
50 | a = embedding(x, self.lora_a, self.padding_idx, self.max_norm,
51 | self.norm_type, self.scale_grad_by_freq, self.sparse)
52 | return addmm(emb.flatten(end_dim=-2), a.flatten(end_dim=-2), self.lora_b.transpose(0, 1),
53 | alpha=self._lora_scaling).view(emb.shape)
54 | return emb
55 |
56 | def activate_lora(self, lora_r: int = 0, lora_alpha: float = 1.):
57 | """
58 | :param lora_r: LoRA factorization dimension
59 | :param lora_alpha: LoRA scaling factor
60 | """
61 | assert lora_r > 0, 'rank should be greater than zero'
62 | self.weight.requires_grad = False # freeze main weights
63 | self.lora_a = Parameter(init.zeros_(empty(self.num_embeddings, lora_r)))
64 | self.lora_b = Parameter(init.normal_(empty(self.embedding_dim, lora_r)))
65 |
66 | self.lora_r = lora_r
67 | self.lora_alpha = lora_alpha
68 | self._lora_scaling = lora_alpha / lora_r
69 |
70 | def merge_lora(self):
71 | """
72 | Transform LoRA embedding to normal
73 | """
74 | if not self.lora_r:
75 | return
76 | self.weight.data += (self.lora_a @ self.lora_b.transpose(0, 1)) * self._lora_scaling
77 | self.weight.requires_grad = True
78 | self.lora_r = 0
79 | del self.lora_a, self.lora_b, self.lora_alpha, self._lora_scaling
80 |
81 | def extra_repr(self) -> str:
82 | r = super().extra_repr()
83 | if self.neg_inf_idx is not None:
84 | r += f', neg_inf_idx={self.neg_inf_idx}'
85 | if self.lora_r:
86 | r += f', lora_r={self.lora_r}, lora_alpha={self.lora_alpha}'
87 | return r
88 |
89 |
90 | __all__ = ['Embedding']
91 |
--------------------------------------------------------------------------------
/chytorch/nn/lora/linear.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023, 2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from math import sqrt
24 | from torch import empty, addmm, Tensor
25 | from torch.nn import Linear as tLinear, Parameter, init
26 | from torch.nn.functional import dropout
27 |
28 |
29 | class Linear(tLinear):
30 | """
31 | LoRA wrapped Linear layer.
32 | """
33 | def __init__(self, *args, **kwargs):
34 | super().__init__(*args, **kwargs)
35 | self.lora_r = 0
36 |
37 | def forward(self, x: Tensor) -> Tensor:
38 | out = super().forward(x)
39 | if self.lora_r:
40 | if self.training and self.lora_dropout:
41 | x = dropout(x, self.lora_dropout)
42 | a = x @ self.lora_a.transpose(0, 1)
43 | return addmm(out.flatten(end_dim=-2), a.flatten(end_dim=-2), self.lora_b.transpose(0, 1),
44 | alpha=self._lora_scaling).view(out.shape)
45 | return out
46 |
47 | def activate_lora(self, lora_r: int = 0, lora_alpha: float = 1., lora_dropout: float = 0.):
48 | """
49 | :param lora_r: LoRA factorization dimension
50 | :param lora_alpha: LoRA scaling factor
51 | :param lora_dropout: LoRA input dropout
52 | """
53 | assert lora_r > 0, 'rank should be greater than zero'
54 | self.weight.requires_grad = False # freeze main weights
55 | self.lora_a = Parameter(init.kaiming_uniform_(empty(lora_r, self.in_features), a=sqrt(5)))
56 | self.lora_b = Parameter(init.zeros_(empty(self.out_features, lora_r)))
57 |
58 | self.lora_r = lora_r
59 | self.lora_dropout = lora_dropout
60 | self.lora_alpha = lora_alpha
61 | self._lora_scaling = lora_alpha / lora_r
62 |
63 | def merge_lora(self):
64 | """
65 | Transform LoRA linear to normal
66 | """
67 | if not self.lora_r:
68 | return
69 | self.weight.data += (self.lora_b @ self.lora_a) * self._lora_scaling
70 | self.weight.requires_grad = True
71 | self.lora_r = 0
72 | del self.lora_a, self.lora_b, self.lora_dropout, self.lora_alpha, self._lora_scaling
73 |
74 | def extra_repr(self) -> str:
75 | r = super().extra_repr()
76 | if self.lora_r:
77 | return r + f', lora_r={self.lora_r}, lora_alpha={self.lora_alpha}, lora_dropout={self.lora_dropout}'
78 | return r
79 |
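# Usage sketch (illustrative comment, not part of the original file):
#   layer = Linear(512, 512)
#   layer.activate_lora(lora_r=8, lora_alpha=16., lora_dropout=0.1)
#   # ...train: only lora_a and lora_b receive gradients, the main weight is frozen...
#   layer.merge_lora()  # bake the low-rank update back into the main weight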
80 |
81 | __all__ = ['Linear']
82 |
--------------------------------------------------------------------------------
/chytorch/nn/losses.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023, 2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from torch import float32, zeros_like, exp, Tensor
24 | from torch.nn import Parameter, MSELoss
25 | from torch.nn.modules.loss import _Loss
26 | from torchtyping import TensorType
27 |
28 |
29 | class MultiTaskLoss(_Loss):
30 | """
31 | Auto-scalable loss for multitask training.
32 |
33 | https://arxiv.org/abs/1705.07115
34 | """
35 | def __init__(self, loss_type: TensorType['loss_type', bool], *, reduction='mean'):
36 | """
37 | :param loss_type: vector equal to the number of tasks losses. True for regression and False for classification.
38 | """
39 | super().__init__(reduction=reduction)
40 | self.log = Parameter(zeros_like(loss_type, dtype=float32))
41 | self.register_buffer('coefficient', (loss_type + 1.).to(float32))
42 |
43 | def forward(self, x: TensorType['loss', float]):
44 | """
45 | :param x: 1d vector of losses or 2d matrix of batch X losses.
46 | """
47 | x = x / (self.coefficient * exp(self.log)) + self.log / 2
48 |
49 | if self.reduction == 'sum':
50 | return x.sum()
51 | elif self.reduction == 'mean':
52 | return x.mean()
53 | return x
54 |
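# Usage sketch (illustrative comment, not part of the original file):
#   mtl = MultiTaskLoss(tensor([True, False]))  # one regression and one classification task
#   total = mtl(stack([reg_loss, cls_loss]))    # scalar ready for backward()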
55 |
56 | class CensoredLoss(_Loss):
57 | """
58 | Loss wrapper masking input under target qualifier range.
59 |
60 | Masking strategy for different qualifiers (rows) and input-target relations (columns) described below:
61 |
62 | | I < T | I = T | I > T |
63 | ---|-------|-------|-------|
64 | -1 | Mask | | |
65 | 0 | | | |
66 | 1 | | | Mask |
67 |
68 | Wrapped loss should not be configured with mean reduction.
69 | Wrapper does proper mean reduction by ignoring masked values.
70 |
71 | Note: wrapped loss should correctly treat zero-valued input and targets.
72 | """
73 | def __init__(self, loss: _Loss, reduction: str = 'mean', eps: float = 1e-5):
74 | assert loss.reduction != 'mean', 'given loss should not be configured to `mean` reduction'
75 | super().__init__(reduction=reduction)
76 | self.loss = loss
77 | self.eps = eps
78 |
79 | def forward(self, input: Tensor, target: Tensor, qualifier: Tensor) -> Tensor:
80 | mask = ((qualifier >= 0) | (input >= target)) & ((qualifier <= 0) | (input <= target))
81 | loss = self.loss(input * mask, target * mask)
82 | if self.reduction == 'mean':
83 | if self.loss.reduction == 'none':
84 | loss = loss.sum()
85 | return loss / (mask.sum() + self.eps)
86 | elif self.reduction == 'sum':
87 | if self.loss.reduction == 'none':
88 | return loss.sum()
89 | return loss
90 | return loss # reduction='none'
91 |
92 |
93 | class MaskedNaNLoss(_Loss):
94 | """
95 | Loss wrapper masking nan targets and corresponding input values as zeros.
96 | Wrapped loss should not be configured with mean reduction.
97 | Wrapper does proper mean reduction by ignoring masked values.
98 |
99 | Note: wrapped loss should correctly treat zero-valued input and targets.
100 | """
101 | def __init__(self, loss: _Loss, reduction: str = 'mean', eps: float = 1e-5):
102 | assert loss.reduction != 'mean', 'given loss should not be configured to `mean` reduction'
103 | super().__init__(reduction=reduction)
104 | self.loss = loss
105 | self.eps = eps
106 |
107 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
108 | mask = ~target.isnan()
109 | loss = self.loss(input * mask, target.nan_to_num())
110 | if self.reduction == 'mean':
111 | if self.loss.reduction == 'none':
112 | loss = loss.sum()
113 | return loss / (mask.sum() + self.eps)
114 | elif self.reduction == 'sum':
115 | if self.loss.reduction == 'none':
116 | return loss.sum()
117 | return loss
118 | return loss # reduction='none'
119 |
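# Usage sketch (illustrative comment, not part of the original file):
#   loss = MaskedNaNLoss(MSELoss(reduction='none'))
#   loss(pred, target)  # NaN targets and the corresponding inputs are masked; the mean ignores them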
120 |
121 | class MSLELoss(MSELoss):
122 | r"""
123 | Mean Squared Logarithmic Error:
124 |
125 | .. math:: \text{MSLE} = \frac{1}{N}\sum_i^N (\log_e(1 + y_i) - \log_e(1 + \hat{y_i}))^2
126 |
127 | Note: Works only for positive target values range. Implicitly clamps negative input.
128 | """
129 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
130 | return super().forward((input.clamp(min=0) + 1).log(), (target + 1).log())
131 |
132 |
133 | __all__ = ['MultiTaskLoss',
134 | 'CensoredLoss',
135 | 'MaskedNaNLoss',
136 | 'MSLELoss']
137 |
--------------------------------------------------------------------------------
/chytorch/nn/molecule/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023, 2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from .encoder import *
24 |
25 |
26 | __all__ = ['MoleculeEncoder']
27 |
--------------------------------------------------------------------------------
/chytorch/nn/molecule/_embedding.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from torch import empty_like
24 | from torch.nn import Module
25 | from ..lora import Embedding
26 |
27 |
28 | class EmbeddingBag(Module):
29 | def __init__(self, max_neighbors: int = 14, d_model: int = 1024, perturbation: float = 0., max_tokens: int = 121):
30 | assert perturbation >= 0, 'zero or positive perturbation expected'
31 | assert max_tokens >= 121, 'at least 121 tokens are required'
32 | super().__init__()
33 | self.atoms_encoder = Embedding(max_tokens, d_model, 0)
34 | self.neighbors_encoder = Embedding(max_neighbors + 3, d_model, 0)
35 |
36 | self.max_neighbors = max_neighbors
37 | self.perturbation = perturbation
38 | self.max_tokens = max_tokens
39 |
40 | def forward(self, atoms, neighbors):
41 | # cls token in neighbors coded by 0
42 | x = self.atoms_encoder(atoms) + self.neighbors_encoder(neighbors)
43 |
44 | if self.perturbation and self.training:
45 | x = x + empty_like(x).uniform_(-self.perturbation, self.perturbation)
46 | return x
47 |
48 |
49 | __all__ = ['EmbeddingBag']
50 |
--------------------------------------------------------------------------------
/chytorch/nn/molecule/encoder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2021-2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from itertools import repeat
24 | from torch.nn import GELU, Module, ModuleList, LayerNorm
25 | from torchtyping import TensorType
26 | from typing import Tuple, Optional, List
27 | from warnings import warn
28 | from ._embedding import EmbeddingBag
29 | from ..lora import Embedding
30 | from ..transformer import EncoderLayer
31 | from ...utils.data import MoleculeDataBatch
32 |
33 |
34 | def _update(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
35 | if prefix + 'centrality_encoder.weight' in state_dict:
36 | warn('fixed chytorch<1.37 checkpoint', DeprecationWarning)
37 | state_dict[prefix + 'neighbors_encoder.weight'] = state_dict.pop(prefix + 'centrality_encoder.weight')
38 | state_dict[prefix + 'distance_encoder.weight'] = state_dict.pop(prefix + 'spatial_encoder.weight')
39 | if prefix + 'atoms_encoder.weight' in state_dict:
40 | warn('fixed chytorch<1.61 checkpoint', DeprecationWarning)
41 | state_dict[prefix + 'embedding.atoms_encoder.weight'] = state_dict.pop(prefix + 'atoms_encoder.weight')
42 | state_dict[prefix + 'embedding.neighbors_encoder.weight'] = state_dict.pop(prefix + 'neighbors_encoder.weight')
43 |
44 |
45 | class MoleculeEncoder(Module):
46 | """
47 | Inspired by https://arxiv.org/pdf/2106.05234.pdf
48 | """
49 | def __init__(self, max_neighbors: int = 14, max_distance: int = 10, d_model: int = 1024, nhead: int = 16,
50 | num_layers: int = 8, dim_feedforward: int = 3072, shared_weights: bool = True,
51 | shared_attention_bias: bool = True, dropout: float = 0.1, activation=GELU,
52 | layer_norm_eps: float = 1e-5, norm_first: bool = False, post_norm: bool = False,
53 | zero_bias: bool = False, perturbation: float = 0., max_tokens: int = 121,
54 | projection_bias: bool = True, ff_bias: bool = True):
55 | """
56 | Molecule Graphormer from https://doi.org/10.1021/acs.jcim.2c00344.
57 |
58 | :param max_neighbors: maximum number of atom neighbors.
59 | :param max_distance: maximal distance between atoms.
60 | :param shared_weights: ALBERT-like encoder weights sharing.
61 | :param norm_first: do pre-normalization in encoder layers.
62 | :param post_norm: do normalization of output. Works only when norm_first=True.
63 | :param zero_bias: use frozen zero bias of attention for non-reachable atoms.
64 | :param perturbation: add perturbation to embedding (https://aclanthology.org/2021.naacl-main.460.pdf).
65 | Disabled by default
66 | :param shared_attention_bias: use shared distance encoder or unique for each transformer layer.
67 | :param max_tokens: number of tokens in the atom encoder embedding layer.
68 | """
69 | super().__init__()
70 | self.embedding = EmbeddingBag(max_neighbors, d_model, perturbation, max_tokens)
71 |
72 | self.shared_attention_bias = shared_attention_bias
73 | if shared_attention_bias:
74 | self.distance_encoder = Embedding(max_distance + 3, nhead, int(zero_bias) or None, neg_inf_idx=0)
75 | # None-filled encoders mean the previously calculated bias is reused. A different architecture can be created manually.
76 | # This is done for speedup compared to layer duplication.
77 | self.distance_encoders = [None] * num_layers
78 | self.distance_encoders[0] = self.distance_encoder # noqa
79 | else:
80 | self.distance_encoders = ModuleList(Embedding(max_distance + 3, nhead,
81 | int(zero_bias) or None, neg_inf_idx=0)
82 | for _ in range(num_layers))
83 |
84 | self.max_distance = max_distance
85 | self.max_neighbors = max_neighbors
86 | self.perturbation = perturbation
87 | self.num_layers = num_layers
88 | self.max_tokens = max_tokens
89 | self.post_norm = post_norm
90 | self.d_model = d_model
91 | self.nhead = nhead
92 | self.dim_feedforward = dim_feedforward
93 | self.dropout = dropout
94 | self.activation = activation
95 | self.layer_norm_eps = layer_norm_eps
96 | self.norm_first = norm_first
97 | self.zero_bias = zero_bias
98 | if post_norm:
99 | assert norm_first, 'post_norm requires norm_first'
100 | self.norm = LayerNorm(d_model, layer_norm_eps)
101 |
102 | self.shared_weights = shared_weights
103 | if shared_weights:
104 | self.layer = EncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps, norm_first,
105 | projection_bias=projection_bias, ff_bias=ff_bias)
106 | self.layers = [self.layer] * num_layers
107 | else:
108 | # the layer-sharing scheme can be changed manually, e.g. to pairs of shared encoders
109 | self.layers = ModuleList(EncoderLayer(d_model, nhead, dim_feedforward, dropout, activation,
110 | layer_norm_eps, norm_first, projection_bias=projection_bias,
111 | ff_bias=ff_bias) for _ in range(num_layers))
112 | self._register_load_state_dict_pre_hook(_update)
113 |
114 | def forward(self, batch: MoleculeDataBatch, /, *,
115 | cache: Optional[List[Tuple[TensorType['batch', 'atoms+conditions', 'embedding'],
116 | TensorType['batch', 'atoms+conditions', 'embedding']]]] = None) -> \
117 | TensorType['batch', 'atoms', 'embedding']:
118 | """
119 | Use 0 for padding.
120 | Atoms should be coded by atomic numbers + 2.
121 | Token 1 reserved for cls token, 2 reserved for molecule cls or training tricks like MLM.
122 | Neighbors should be coded from 2 (means no neighbors) to max neighbors + 2.
123 | Neighbors equal to 1 reserved for training tricks like MLM. Use 0 for cls.
124 | Distances should be coded from 2 (means self-loop) to max_distance + 2.
125 | Non-reachable atoms should be coded by 1.
126 | """
127 | cache = repeat(None) if cache is None else iter(cache)
128 | atoms, neighbors, distances = batch
129 |
130 | x = self.embedding(atoms, neighbors)
131 |
132 | for lr, d, c in zip(self.layers, self.distance_encoders, cache):
133 | if d is not None:
134 | d_mask = d(distances).permute(0, 3, 1, 2) # BxNxNxH > BxHxNxN
135 | # else: reuse previously calculated mask
136 | x, _ = lr(x, d_mask, cache=c) # noqa
137 |
138 | if self.post_norm:
139 | return self.norm(x)
140 | return x
141 |
142 | @property
143 | def centrality_encoder(self):
144 | warn('centrality_encoder renamed to neighbors_encoder in chytorch 1.37', DeprecationWarning)
145 | return self.neighbors_encoder
146 |
147 | @property
148 | def spatial_encoder(self):
149 | warn('spatial_encoder renamed to distance_encoder in chytorch 1.37', DeprecationWarning)
150 | return self.distance_encoder
151 |
152 | @property
153 | def atoms_encoder(self):
154 | warn('atoms_encoder moved to embedding submodule in chytorch 1.61', DeprecationWarning)
155 | return self.embedding.atoms_encoder
156 |
157 | @property
158 | def neighbors_encoder(self):
159 | warn('neighbors_encoder moved to embedding submodule in chytorch 1.61', DeprecationWarning)
160 | return self.embedding.neighbors_encoder
161 |
162 |
163 | __all__ = ['MoleculeEncoder']
164 |
--------------------------------------------------------------------------------
/chytorch/nn/reaction.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2021-2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from math import inf
24 | from torch import zeros_like, float as t_float
25 | from torch.nn import Embedding, GELU, Module
26 | from torchtyping import TensorType
27 | from .molecule import MoleculeEncoder
28 | from .transformer import EncoderLayer
29 | from ..utils.data import ReactionEncoderDataBatch
30 |
31 |
32 | class ReactionEncoder(Module):
33 | def __init__(self, max_neighbors: int = 14, max_distance: int = 10, d_model: int = 1024, n_in_head: int = 16,
34 | n_ex_head: int = 4, num_in_layers: int = 8, num_ex_layers: int = 8,
35 | dim_feedforward: int = 3072, dropout: float = 0.1, activation=GELU, layer_norm_eps: float = 1e-5):
36 | """
37 | Reaction Graphormer from https://doi.org/10.1021/acs.jcim.2c00344.
38 |
39 | :param max_neighbors: maximum atoms neighbors count.
40 | :param max_distance: maximal distance between atoms.
41 | :param num_in_layers: intramolecular layers count
42 | :param num_ex_layers: reaction-level layers count
43 | """
44 | super().__init__()
45 | self.molecule_encoder = MoleculeEncoder(max_neighbors=max_neighbors, max_distance=max_distance, d_model=d_model,
46 | nhead=n_in_head, num_layers=num_in_layers,
47 | dim_feedforward=dim_feedforward, dropout=dropout, activation=activation,
48 | layer_norm_eps=layer_norm_eps)
49 | self.role_encoder = Embedding(4, d_model, 0)
50 | self.layer = EncoderLayer(d_model, n_ex_head, dim_feedforward, dropout, activation, layer_norm_eps)
51 | self.layers = [self.layer] * num_ex_layers
52 | self.nhead = n_ex_head
53 |
54 | @property
55 | def max_distance(self):
56 | """
57 | Distance cutoff in spatial encoder.
58 | """
59 | return self.molecule_encoder.max_distance
60 |
61 | def forward(self, batch: ReactionEncoderDataBatch) -> TensorType['batch', 'atoms', 'embedding']:
62 | """
63 | Use 0 for padding. Roles should be coded by 2 for reactants, 3 for products and 1 for the special cls token.
64 | Distances - same as the molecule encoder distances, but batched diagonally.
65 | 0 is used to disable attention sharing between molecules.
66 | """
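# e.g. a single reaction with two reactant atoms and two product atoms (illustrative comment, not part of the original file):
#   roles = [[1, 2, 2, 3, 3, 0]]  # cls, reactants, products, padding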
67 | atoms, neighbors, distances, roles = batch
68 | n = atoms.size(1)
69 | d_mask = zeros_like(roles, dtype=t_float).masked_fill_(roles == 0, -inf).view(-1, 1, 1, n) # BxN > Bx1x1xN >
70 | d_mask = d_mask.expand(-1, self.nhead, n, -1) # > BxHxNxN
71 |
72 | # the role embedding is a BERT-like sentence embedding used to separate reactants from products and to code the rxn CLS token.
73 | # multiplication by roles > 1 zeroes the rxn cls token and padding; this zeroes their gradients too.
74 | x = self.molecule_encoder((atoms, neighbors, distances)) * (roles > 1).unsqueeze_(-1)
75 | x = x + self.role_encoder(roles)
76 |
77 | for lr in self.layers:
78 | x, _ = lr(x, d_mask)
79 | return x
80 |
81 |
82 | __all__ = ['ReactionEncoder']
83 |
--------------------------------------------------------------------------------
/chytorch/nn/slicer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2022-2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from torch import Tensor
24 | from torch.nn import Module
25 | from typing import Tuple, Union
26 |
27 |
28 | class Slicer(Module):
29 | def __init__(self, *slc: Union[int, slice, Tuple[int, ...]]):
30 | """
31 | Slice input tensor. For use with Sequential.
32 |
33 | E.g. Slicer(slice(None), 0) is equal to Tensor[:, 0]
34 | """
35 | super().__init__()
36 | self.slice = slc if len(slc) > 1 else slc[0]
37 |
38 | def forward(self, x: Tensor):
39 | return x.__getitem__(self.slice)
40 |
41 |
42 | __all__ = ['Slicer']
43 |
--------------------------------------------------------------------------------
/chytorch/nn/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023, 2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from .attention import *
24 | from .encoder import *
25 |
26 |
27 | __all__ = ['EncoderLayer',
28 | 'MLP',
29 | 'LLaMAMLP',
30 | 'GraphormerAttention']
31 |
--------------------------------------------------------------------------------
/chytorch/nn/transformer/attention/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from .graphormer import *
24 |
25 |
26 | __all__ = ['GraphormerAttention']
27 |
--------------------------------------------------------------------------------
/chytorch/nn/transformer/attention/graphormer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023, 2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from math import sqrt
24 | from torch import softmax, cat, Tensor
25 | from torch.nn import Module
26 | from torch.nn.functional import dropout
27 | from typing import Optional, Tuple
28 | from warnings import warn
29 | from ...lora import Linear
30 |
31 |
32 | def _update_unpacked(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
33 | if prefix + 'in_proj_weight' in state_dict:
34 | warn('fixed chytorch<1.44 checkpoint', DeprecationWarning)
35 | state_dict[prefix + 'qkv_proj.weight'] = state_dict.pop(prefix + 'in_proj_weight')
36 | state_dict[prefix + 'qkv_proj.bias'] = state_dict.pop(prefix + 'in_proj_bias')
37 | state_dict[prefix + 'o_proj.weight'] = state_dict.pop(prefix + 'out_proj.weight')
38 | state_dict[prefix + 'o_proj.bias'] = state_dict.pop(prefix + 'out_proj.bias')
39 |
40 | if prefix + 'qkv_proj.weight' in state_dict: # transform packed projection
41 | q_w, k_w, v_w = state_dict.pop(prefix + 'qkv_proj.weight').chunk(3, dim=0)
42 | state_dict[prefix + 'q_proj.weight'] = q_w
43 | state_dict[prefix + 'k_proj.weight'] = k_w
44 | state_dict[prefix + 'v_proj.weight'] = v_w
45 |
46 | if prefix + 'qkv_proj.bias' in state_dict:
47 | q_b, k_b, v_b = state_dict.pop(prefix + 'qkv_proj.bias').chunk(3, dim=0)
48 | state_dict[prefix + 'q_proj.bias'] = q_b
49 | state_dict[prefix + 'k_proj.bias'] = k_b
50 | state_dict[prefix + 'v_proj.bias'] = v_b
51 |
52 |
53 | def _update_packed(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
54 | if prefix + 'in_proj_weight' in state_dict:
55 | warn('fixed chytorch<1.44 checkpoint', DeprecationWarning)
56 | state_dict[prefix + 'o_proj.weight'] = state_dict.pop(prefix + 'out_proj.weight')
57 | state_dict[prefix + 'o_proj.bias'] = state_dict.pop(prefix + 'out_proj.bias')
58 |
59 | state_dict[prefix + 'qkv_proj.weight'] = state_dict.pop(prefix + 'in_proj_weight')
60 | state_dict[prefix + 'qkv_proj.bias'] = state_dict.pop(prefix + 'in_proj_bias')
61 | elif prefix + 'q_proj.weight' in state_dict: # transform unpacked projection
62 | q_w = state_dict.pop(prefix + 'q_proj.weight')
63 | k_w = state_dict.pop(prefix + 'k_proj.weight')
64 | v_w = state_dict.pop(prefix + 'v_proj.weight')
65 | state_dict[prefix + 'qkv_proj.weight'] = cat([q_w, k_w, v_w])
66 |
67 | if prefix + 'q_proj.bias' in state_dict:
68 | q_b = state_dict.pop(prefix + 'q_proj.bias')
69 | k_b = state_dict.pop(prefix + 'k_proj.bias')
70 | v_b = state_dict.pop(prefix + 'v_proj.bias')
71 | state_dict[prefix + 'qkv_proj.bias'] = cat([q_b, k_b, v_b])
72 |
73 |
74 | class GraphormerAttention(Module):
75 | """
76 | LoRA wrapped Multi-Head Attention
77 | """
78 | def __init__(self, embed_dim, num_heads, dropout: float = .1, bias: bool = True, separate_proj: bool = False):
79 | """
80 | :param embed_dim: the size of each embedding vector
81 | :param num_heads: number of heads
82 | :param dropout: attention dropout
83 | :param separate_proj: use separate Q/K/V projections instead of a single packed (optimized) one
84 | """
85 | assert not embed_dim % num_heads, 'embed_dim must be divisible by num_heads'
86 | super().__init__()
87 | self.embed_dim = embed_dim
88 | self.num_heads = num_heads
89 | self.dropout = dropout
90 | self.separate_proj = separate_proj
91 | self._scale = 1 / sqrt(embed_dim / num_heads)
92 |
93 | if separate_proj:
94 | self.q_proj = Linear(embed_dim, embed_dim, bias=bias)
95 | self.k_proj = Linear(embed_dim, embed_dim, bias=bias)
96 | self.v_proj = Linear(embed_dim, embed_dim, bias=bias)
97 | self._register_load_state_dict_pre_hook(_update_unpacked)
98 | else: # packed projection
99 | self.qkv_proj = Linear(embed_dim, 3 * embed_dim, bias=bias)
100 | self._register_load_state_dict_pre_hook(_update_packed)
101 | self.o_proj = Linear(embed_dim, embed_dim, bias=bias)
102 |
103 | def forward(self, x: Tensor, attn_mask: Tensor, *,
104 | cache: Optional[Tuple[Tensor, Tensor]] = None,
105 | need_weights: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
106 | if self.separate_proj:
107 | q = self.q_proj(x) # BxTxH*E
108 | k = self.k_proj(x) # BxSxH*E (KV seq len can differ from tgt_len with enabled cache trick)
109 | v = self.v_proj(x) # BxSxH*E
110 | else: # optimized
111 | q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
112 |
113 | if cache is not None:
114 | # inference caching. batch should be left padded. shape should be BxSxH*E
115 | bsz, tgt_len, _ = x.shape
116 | ck, cv = cache
117 | ck[:bsz, -tgt_len:] = k
118 | cv[:bsz, -tgt_len:] = v
119 | k, v = ck[:bsz], cv[:bsz]
120 |
121 | # BxTxH*E > BxTxHxE > BxHxTxE
122 | q = q.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
123 | # BxSxH*E > BxSxHxE > BxHxExS
124 | k = k.unflatten(2, (self.num_heads, -1)).permute(0, 2, 3, 1)
125 | # BxSxH*E > BxSxHxE > BxHxSxE
126 | v = v.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
127 |
128 | # BxHxTxE @ BxHxExS > BxHxTxS
129 | a = (q @ k) * self._scale + attn_mask
130 | a = softmax(a, dim=-1)
131 | if self.training and self.dropout:
132 | a = dropout(a, self.dropout)
133 |
134 | # BxHxTxS @ BxHxSxE > BxHxTxE > BxTxHxE > BxTxH*E
135 | o = (a @ v).transpose(1, 2).flatten(2)
136 | o = self.o_proj(o)
137 |
138 | if need_weights:
139 | a = a.sum(dim=1) / self.num_heads
140 | return o, a
141 | else:
142 | return o, None
143 |
144 |
145 | __all__ = ['GraphormerAttention']
146 |
--------------------------------------------------------------------------------
/chytorch/nn/transformer/encoder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2021-2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from torch import Tensor, nn
24 | from torch.nn import Dropout, GELU, LayerNorm, Module, SiLU
25 | from typing import Tuple, Optional, Type
26 | from warnings import warn
27 | from .attention import GraphormerAttention
28 | from ..lora import Linear
29 |
30 |
31 | def _update(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
32 | if prefix + 'linear1.weight' in state_dict:
33 | warn('fixed chytorch<1.64 checkpoint', DeprecationWarning)
34 | state_dict[prefix + 'mlp.linear1.weight'] = state_dict.pop(prefix + 'linear1.weight')
35 | state_dict[prefix + 'mlp.linear2.weight'] = state_dict.pop(prefix + 'linear2.weight')
36 | if prefix + 'linear1.bias' in state_dict:
37 | state_dict[prefix + 'mlp.linear1.bias'] = state_dict.pop(prefix + 'linear1.bias')
38 | state_dict[prefix + 'mlp.linear2.bias'] = state_dict.pop(prefix + 'linear2.bias')
39 |
40 |
41 | class MLP(Module):
42 | def __init__(self, d_model, dim_feedforward, dropout=0.1, activation=GELU, bias: bool = True):
43 | super().__init__()
44 | self.linear1 = Linear(d_model, dim_feedforward, bias=bias)
45 | self.linear2 = Linear(dim_feedforward, d_model, bias=bias)
46 | self.dropout = Dropout(dropout)
47 |
48 | # ad-hoc for resolving class from name
49 | if isinstance(activation, str):
50 | activation = getattr(nn, activation)
51 | self.activation = activation()
52 |
53 | def forward(self, x):
54 | return self.linear2(self.dropout(self.activation(self.linear1(x))))
55 |
56 |
57 | class LLaMAMLP(Module):
58 | def __init__(self, d_model, dim_feedforward, dropout=0.1, activation=SiLU, bias: bool = False):
59 | super().__init__()
60 | self.linear1 = Linear(d_model, dim_feedforward, bias=bias)
61 | self.linear2 = Linear(d_model, dim_feedforward, bias=bias)
62 | self.linear3 = Linear(dim_feedforward, d_model, bias=bias)
63 | self.dropout = Dropout(dropout)
64 |
65 | # ad-hoc for resolving class from name
66 | if isinstance(activation, str):
67 | activation = getattr(nn, activation)
68 | self.activation = activation()
69 |
70 | def forward(self, x):
71 | return self.linear3(self.dropout(self.activation(self.linear1(x))) * self.linear2(x))
72 |
73 |
74 | class EncoderLayer(Module):
75 | r"""EncoderLayer based on torch.nn.TransformerEncoderLayer, but batch always first and returns also attention.
76 |
77 | :param d_model: the number of expected features in the input (required).
78 | :param nhead: the number of heads in the multiheadattention models (required).
79 | :param dim_feedforward: the dimension of the feedforward network model (required).
80 | :param dropout: the dropout value (default=0.1).
81 | :param activation: the activation function of the intermediate layer. Default: GELU.
82 | :param layer_norm_eps: the eps value in layer normalization components (default=1e-5).
83 | :param norm_first: if `True`, layer norm is applied before the self-attention and feedforward
84 | operations; otherwise it is applied after them.
85 | """
86 | def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1, activation=GELU, layer_norm_eps=1e-5,
87 | norm_first: bool = False, attention: Type[Module] = GraphormerAttention, mlp: Type[Module] = MLP,
88 | norm_layer: Type[Module] = LayerNorm, projection_bias: bool = True, ff_bias: bool = True):
89 | super().__init__()
90 | self.self_attn = attention(d_model, nhead, dropout, projection_bias)
91 | self.mlp = mlp(d_model, dim_feedforward, dropout, activation, ff_bias)
92 |
93 | self.norm1 = norm_layer(d_model, eps=layer_norm_eps)
94 | self.norm2 = norm_layer(d_model, eps=layer_norm_eps)
95 | self.dropout1 = Dropout(dropout)
96 | self.dropout2 = Dropout(dropout)
97 | self.norm_first = norm_first
98 | self._register_load_state_dict_pre_hook(_update)
99 |
100 | def forward(self, x: Tensor, attn_mask: Optional[Tensor], *,
101 | need_embedding: bool = True, need_weights: bool = False,
102 | **kwargs) -> Tuple[Optional[Tensor], Optional[Tensor]]:
103 | nx = self.norm1(x) if self.norm_first else x # pre-norm or post-norm
104 | e, a = self.self_attn(nx, attn_mask, need_weights=need_weights, **kwargs)
105 |
106 | if need_embedding:
107 | x = x + self.dropout1(e)
108 | if self.norm_first:
109 | return x + self.dropout2(self.mlp(self.norm2(x))), a
110 | # else: post-norm
111 | x = self.norm1(x)
112 | return self.norm2(x + self.dropout2(self.mlp(x))), a
113 | return None, a
114 |
115 |
116 | __all__ = ['EncoderLayer', 'MLP', 'LLaMAMLP']
117 |
--------------------------------------------------------------------------------
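A minimal usage sketch for the `EncoderLayer` above (the shapes are assumptions: `attn_mask` is an additive mask broadcastable to the `B x H x S x S` attention scores, per `GraphormerAttention`):

    import torch
    from chytorch.nn.transformer.encoder import EncoderLayer, LLaMAMLP

    layer = EncoderLayer(d_model=64, nhead=4, dim_feedforward=256, norm_first=True)
    x = torch.randn(2, 10, 64)             # batch x sequence x features
    attn_mask = torch.zeros(2, 1, 10, 10)  # additive mask, broadcast over heads
    emb, attn = layer(x, attn_mask, need_weights=True)
    print(emb.shape, attn.shape)           # (2, 10, 64), (2, 10, 10)

    # LLaMA-style variant: SwiGLU feed-forward without biases
    llama = EncoderLayer(64, 4, 256, activation='SiLU', mlp=LLaMAMLP,
                         projection_bias=False, ff_bias=False)
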
/chytorch/nn/voting/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2022, 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from .binary import *
24 | from .regressor import *
25 | from .classifier import *
26 |
27 |
28 | __all__ = ['VotingRegressor', 'VotingClassifier', 'BinaryVotingClassifier']
29 |
--------------------------------------------------------------------------------
/chytorch/nn/voting/_kfold.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2022, 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from functools import lru_cache
24 | from torch import ones, zeros
25 |
26 |
27 | @lru_cache()
28 | def k_fold_mask(k_fold, ensemble, batch_size, train, device=None):
29 | """
30 | :param k_fold: number of folds
31 | :param ensemble: number of predicting heads
32 | :param batch_size: size of batch
33 | :param train: create train or test mask. train - disables 1/k_fold of the batch for each head group.
34 | test - disables all but the held-out 1/k_fold.
35 | :param device: device of mask
36 | """
37 | assert k_fold >= 3, 'k-fold should be at least 3'
38 | assert not ensemble % k_fold, 'ensemble should be divisible by k-fold'
39 | assert not batch_size % k_fold, 'batch size should be divisible by k-fold'
40 |
41 | if train:
42 | m = ones(batch_size, ensemble, device=device) # k-th fold mask
43 | disable = 0.
44 | else: # test/validation
45 | m = zeros(batch_size, ensemble, device=device) # k-th fold mask
46 | disable = 1.
47 |
48 | batch_size //= k_fold
49 | ensemble //= k_fold
50 | for n in range(k_fold): # disable folds
51 | m[n * batch_size: n * batch_size + batch_size, n * ensemble: n * ensemble + ensemble] = disable
52 | return m
53 |
54 |
55 | __all__ = ['k_fold_mask']
56 |
--------------------------------------------------------------------------------
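A worked illustration of `k_fold_mask` above: with `k_fold=3`, `ensemble=3`, `batch_size=3` each fold block is 1x1, so the pattern reduces to the diagonal:

    from chytorch.nn.voting._kfold import k_fold_mask

    print(k_fold_mask(3, 3, 3, train=True))
    # tensor([[0., 1., 1.],
    #         [1., 0., 1.],
    #         [1., 1., 0.]])  each head group skips its own fold of the batch
    print(k_fold_mask(3, 3, 3, train=False))
    # tensor([[1., 0., 0.],
    #         [0., 1., 0.],
    #         [0., 0., 1.]])  each head group sees only its held-out fold
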
/chytorch/nn/voting/binary.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2022, 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from math import nan
24 | from torch import sigmoid, no_grad
25 | from torch.nn import GELU
26 | from torch.nn.functional import binary_cross_entropy_with_logits
27 | from torchtyping import TensorType
28 | from typing import Union, Optional
29 | from ._kfold import k_fold_mask
30 | from .regressor import VotingRegressor
31 |
32 |
33 | class BinaryVotingClassifier(VotingRegressor):
34 | """
34 | Simple two-layer perceptron with layer normalization and dropout adapted for effective
35 | ensemble binary classification tasks.
37 | """
38 | def __init__(self, ensemble: int = 10, output: int = 1, hidden: int = 256, input: Optional[int] = None,
39 | dropout: float = .5, activation=GELU, layer_norm_eps: float = 1e-5,
40 | loss_function=binary_cross_entropy_with_logits, norm_first: bool = False):
41 | super().__init__(ensemble, output, hidden, input, dropout, activation,
42 | layer_norm_eps, loss_function, norm_first)
43 |
44 | @no_grad()
45 | def predict(self, x: TensorType['batch', 'embedding'], *,
46 | k_fold: Optional[int] = None) -> Union[TensorType['batch', int], TensorType['batch', 'output', int]]:
47 | """
48 | Average class prediction
49 |
50 | :param x: features
51 | :param k_fold: average ensemble according to k-fold trick described in the `loss` method.
52 | """
53 | return (self.predict_proba(x, k_fold=k_fold) > .5).long()
54 |
55 | @no_grad()
56 | def predict_proba(self, x: TensorType['batch', 'embedding'], *,
57 | k_fold: Optional[int] = None) -> Union[TensorType['batch', float],
58 | TensorType['batch', 'output', float]]:
59 | """
60 | Average probability
61 |
62 | :param x: features
63 | :param k_fold: average ensemble according to k-fold trick described in the `loss` method.
64 | """
65 | p = sigmoid(self.forward(x))
66 | if k_fold is not None:
67 | m = k_fold_mask(k_fold, self._ensemble, x.size(0), True, p.device).bool() # B x E
68 | if self._output != 1:
69 | m.unsqueeze_(1) # B x 1 x E
70 | p.masked_fill_(m, nan)
71 | return p.nanmean(-1)
72 |
73 |
74 | __all__ = ['BinaryVotingClassifier']
75 |
--------------------------------------------------------------------------------
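A short sketch of the `BinaryVotingClassifier` above: `predict` thresholds the head-averaged sigmoid at 0.5 (the input size of 16 is arbitrary):

    import torch
    from chytorch.nn.voting import BinaryVotingClassifier

    clf = BinaryVotingClassifier(ensemble=10, input=16).eval()
    x = torch.randn(4, 16)
    print(clf.predict_proba(x).shape)  # torch.Size([4])
    print(clf.predict(x))              # 0/1 labels, shape [4]
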
/chytorch/nn/voting/classifier.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2022, 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from math import nan
24 | from torch import bmm, no_grad, Tensor
25 | from torch.nn import Dropout, GELU, LayerNorm, LazyLinear, Linear, Module
26 | from torch.nn.functional import cross_entropy, softmax
27 | from torchtyping import TensorType
28 | from typing import Optional, Union
29 | from ._kfold import k_fold_mask
30 |
31 |
32 | class VotingClassifier(Module):
33 | """
34 | Simple two-layer perceptron with layer normalization and dropout adapted for effective ensemble classification.
35 | """
36 | def __init__(self, ensemble: int = 10, output: int = 1, n_classes: int = 2, hidden: int = 256,
37 | input: Optional[int] = None, dropout: float = .5, activation=GELU, layer_norm_eps: float = 1e-5,
38 | loss_function=cross_entropy, norm_first: bool = False):
39 | """
40 | :param ensemble: number of predictive heads per output
41 | :param input: input features size. By default, lazy initialization is used
42 | :param output: number of predicted properties in multitask mode. By default, single-task mode is active
43 | :param n_classes: number of classes
44 | :param norm_first: apply normalization to the input
45 | """
46 | assert n_classes >= 2, 'number of classes should be greater than or equal to 2'
47 | assert ensemble > 0 and output > 0, 'ensemble and output should be positive integers'
48 | super().__init__()
49 | if input is None:
50 | self.linear1 = LazyLinear(hidden * ensemble * output)
51 | assert not norm_first, 'input size required for prenormalization'
52 | else:
53 | if norm_first:
54 | self.norm_first = LayerNorm(input, layer_norm_eps)
55 | self.linear1 = Linear(input, hidden * ensemble * output)
56 | self.layer_norm = LayerNorm(hidden, layer_norm_eps)
57 | self.activation = activation()
58 | self.dropout = Dropout(dropout)
59 | self.linear2 = Linear(hidden, ensemble * output * n_classes)
60 | self.loss_function = loss_function
61 |
62 | self._n_classes = n_classes
63 | self._ensemble = ensemble
64 | self._input = input
65 | self._hidden = hidden
66 | self._output = output
67 | self._norm_first = norm_first
68 |
69 | def forward(self, x):
70 | """
71 | Returns ensemble of predictions in shape [Batch x Ensemble x Classes].
72 | """
73 | if self._norm_first:
74 | x = self.norm_first(x)
75 | # B x E >> B x N*H >> B x N x H >> N x B x H
76 | x = self.linear1(x).view(-1, self._ensemble * self._output, self._hidden).transpose(0, 1)
77 | x = self.dropout(self.activation(self.layer_norm(x)))
78 | # N * C x H >> N x C x H >> N x H x C
79 | w = self.linear2.weight.view(-1, self._n_classes, self._hidden).transpose(1, 2)
80 | # N x B x C >> B x N x C
81 | x = bmm(x, w).transpose(0, 1).contiguous() + self.linear2.bias.view(-1, self._n_classes)
82 | if self._output != 1: # MT mode
83 | return x.view(-1, self._output, self._ensemble, self._n_classes) # B x O x E x C
84 | return x # B x E x C
85 |
86 | def loss(self, x: TensorType['batch', 'embedding'],
87 | y: Union[TensorType['batch', 1, int], TensorType['batch', 'output', int]],
88 | k_fold: Optional[int] = None, ignore_index: int = -100) -> Tensor:
89 | """
90 | Apply loss function to ensemble of predictions.
91 |
92 | Note: y should be a column vector in single-task mode or a Batch x Output matrix in multitask mode.
93 |
94 | :param x: features
95 | :param y: properties
96 | :param k_fold: Cross-validation training procedure of the ensemble model.
97 | Each group of heads trains on (k - 1) / k items of each batch.
98 | On the validation step, the held-out 1 / k of the same batches is used to evaluate the heads.
99 | Batch and ensemble sizes should be divisible by k. Disabled by default.
100 | """
101 | p = self.forward(x) # B x E x C or B x O x E x C
102 | if self._output != 1: # MT mode
103 | y = y.unsqueeze(-1).expand(-1, -1, self._ensemble) # B x O > B x O x 1 > B x O x E
104 | else:
105 | y = y.expand(-1, self._ensemble) # B x E
106 |
107 | if k_fold is not None:
108 | m = k_fold_mask(k_fold, self._ensemble, x.size(0), not self.training, p.device).bool() # B x E
109 | if self._output != 1:
110 | # B x E > B x 1 x E
111 | m.unsqueeze_(1)
112 | y = y.masked_fill(m, ignore_index)
113 |
114 | # B x E x C >> B * E x C
115 | # B x O x E x C >> B * O * E x C
116 | p = p.flatten(end_dim=-2)
117 | # B x E >> B * E
118 | # B x O x E >> B * O * E
119 | y = y.flatten()
120 | return self.loss_function(p, y)
121 |
122 | @no_grad()
123 | def predict(self, x: TensorType['batch', 'embedding'], *,
124 | k_fold: Optional[int] = None) -> Union[TensorType['batch', int], TensorType['batch', 'output', int]]:
125 | """
126 | Average class prediction
127 |
128 | :param k_fold: average ensemble according to k-fold trick described in the `loss` method.
129 | """
130 | return self.predict_proba(x, k_fold=k_fold).argmax(-1) # B or B x O
131 |
132 | @no_grad()
133 | def predict_proba(self, x: TensorType['batch', 'embedding'], *,
134 | k_fold: Optional[int] = None) -> Union[TensorType['batch', 'classes', float],
135 | TensorType['batch', 'output', 'classes', float]]:
136 | """
137 | Average probability
138 |
139 | :param x: features
140 | :param k_fold: average ensemble according to k-fold trick described in the `loss` method.
141 | """
142 | p = softmax(self.forward(x), -1)
143 | if k_fold is not None:
144 | m = k_fold_mask(k_fold, self._ensemble, x.size(0), True, p.device).bool().unsqueeze_(-1) # B x E x 1
145 | if self._output != 1:
146 | m.unsqueeze_(1) # B x 1 x E x 1
147 | p.masked_fill_(m, nan)
148 | return p.nanmean(-2) # B x C or B x O x C
149 |
150 |
151 | __all__ = ['VotingClassifier']
152 |
--------------------------------------------------------------------------------
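A training-step sketch for the `VotingClassifier` above using the k-fold trick from `loss` (batch size 8 and ensemble 4 are both divisible by k_fold=4; all sizes are illustrative):

    import torch
    from chytorch.nn.voting import VotingClassifier

    model = VotingClassifier(ensemble=4, n_classes=3, input=16)
    x = torch.randn(8, 16)           # batch x embedding
    y = torch.randint(0, 3, (8, 1))  # column vector of class labels
    model.loss(x, y, k_fold=4).backward()

    model.eval()
    proba = model.predict_proba(x, k_fold=4)  # B x C, averaged over held-out heads
    labels = model.predict(x, k_fold=4)       # B
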
/chytorch/nn/voting/regressor.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2022, 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from math import nan
24 | from torch import bmm, no_grad, Tensor
25 | from torch.nn import Dropout, GELU, LayerNorm, LazyLinear, Linear, Module
26 | from torch.nn.functional import smooth_l1_loss
27 | from torchtyping import TensorType
28 | from typing import Optional, Union
29 | from ._kfold import k_fold_mask
30 |
31 |
32 | class VotingRegressor(Module):
33 | """
34 | Simple two-layer perceptron with layer normalization and dropout adapted for effective ensemble regression modeling.
35 | """
36 | def __init__(self, ensemble: int = 10, output: int = 1, hidden: int = 256, input: Optional[int] = None,
37 | dropout: float = .5, activation=GELU, layer_norm_eps: float = 1e-5, loss_function=smooth_l1_loss,
38 | norm_first: bool = False):
39 | """
40 | :param ensemble: number of predictive heads per output
41 | :param input: input features size. By default, lazy initialization is used
42 | :param output: number of predicted properties in multitask mode. By default, single-task mode is active
43 | :param norm_first: apply normalization to the input
44 | """
45 | assert ensemble > 0 and output > 0, 'ensemble and output should be positive integers'
46 | super().__init__()
47 | if input is None:
48 | self.linear1 = LazyLinear(hidden * ensemble * output)
49 | assert not norm_first, 'input size required for prenormalization'
50 | else:
51 | if norm_first:
52 | self.norm_first = LayerNorm(input, layer_norm_eps)
53 | self.linear1 = Linear(input, hidden * ensemble * output)
54 |
55 | self.layer_norm = LayerNorm(hidden, layer_norm_eps)
56 | self.activation = activation()
57 | self.dropout = Dropout(dropout)
58 | self.linear2 = Linear(hidden, ensemble * output)
59 | self.loss_function = loss_function
60 |
61 | self._ensemble = ensemble
62 | self._input = input
63 | self._hidden = hidden
64 | self._output = output
65 | self._norm_first = norm_first
66 |
67 | def forward(self, x):
68 | """
69 | Returns ensemble of predictions in shape [Batch x Output*Ensemble].
70 | """
71 | if self._norm_first:
72 | x = self.norm_first(x)
73 | # B x E >> B x N*H >> B x N x H >> N x B x H
74 | x = self.linear1(x).view(-1, self._ensemble * self._output, self._hidden).transpose(0, 1)
75 | x = self.dropout(self.activation(self.layer_norm(x)))
76 | # N x H >> N x H x 1
77 | w = self.linear2.weight.unsqueeze(2)
78 | # N x B x 1 >> N x B >> B x N
79 | x = bmm(x, w).squeeze(-1).transpose(0, 1).contiguous() + self.linear2.bias
80 | if self._output != 1:
81 | return x.view(-1, self._output, self._ensemble) # B x O x E
82 | return x # B x E
83 |
84 | def loss(self, x: TensorType['batch', 'embedding'],
85 | y: Union[TensorType['batch', 1, float], TensorType['batch', 'output', float]],
86 | k_fold: Optional[int] = None) -> Tensor:
87 | """
88 | Apply loss function to ensemble of predictions.
89 |
90 | Note: y should be a column vector in single-task mode or a Batch x Output matrix in multitask mode.
91 |
92 | :param x: features
93 | :param y: properties
94 | :param k_fold: Cross-validation training procedure of the ensemble model.
95 | Each group of heads trains on (k - 1) / k items of each batch.
96 | On the validation step, the held-out 1 / k of the same batches is used to evaluate the heads.
97 | Batch and ensemble sizes should be divisible by k. Disabled by default.
98 | """
99 | p = self.forward(x)
100 | if self._output != 1: # MT mode
101 | y = y.unsqueeze(-1) # B x O > B x O x 1
102 | y = y.expand(p.size()) # B x E or B x O x E
103 |
104 | if k_fold is not None:
105 | m = k_fold_mask(k_fold, self._ensemble, x.size(0), self.training, p.device) # B x E
106 | if self._output != 1:
107 | m = m.unsqueeze(1) # B x 1 x E
108 | p = p * m # zeros in mask disable gradients
109 | y = y * m # disable errors in test/val loss
110 | return self.loss_function(p, y)
111 |
112 | @no_grad()
113 | def predict(self, x: TensorType['batch', 'embedding'], *,
114 | k_fold: Optional[int] = None) -> Union[TensorType['batch', float],
115 | TensorType['batch', 'output', float]]:
116 | """
117 | Average prediction
118 |
119 | :param x: features.
120 | :param k_fold: average ensemble according to k-fold trick described in the `loss` method.
121 | """
122 | p = self.forward(x)
123 | if k_fold is not None:
124 | m = k_fold_mask(k_fold, self._ensemble, x.size(0), True, p.device).bool() # B x E
125 | if self._output != 1:
126 | m = m.unsqueeze(1) # B x 1 x E
127 | p.masked_fill_(m, nan)
128 | return p.nanmean(-1)
129 |
130 |
131 | __all__ = ['VotingRegressor']
132 |
--------------------------------------------------------------------------------
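A shape sketch for the multitask mode of the `VotingRegressor` above (5 heads for each of 3 outputs; the numbers are arbitrary):

    import torch
    from chytorch.nn.voting import VotingRegressor

    reg = VotingRegressor(ensemble=5, output=3, input=16)
    x = torch.randn(4, 16)
    print(reg(x).shape)          # torch.Size([4, 3, 5])  B x O x E
    print(reg.predict(x).shape)  # torch.Size([4, 3])     averaged over heads
    loss = reg.loss(x, torch.randn(4, 3))
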
/chytorch/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2021-2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 |
--------------------------------------------------------------------------------
/chytorch/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2021-2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from .lmdb import *
24 | from .molecule import *
25 | from .product import *
26 | from .reaction import *
27 | from .sampler import *
28 | from .smiles import *
29 | from .unpack import *
30 | from ._utils import *
31 |
32 |
33 | __all__ = ['MoleculeDataset', 'collate_molecules', 'left_padded_collate_molecules',
34 | 'ConformerDataset', 'collate_conformers',
35 | 'ReactionEncoderDataset', 'collate_encoded_reactions',
36 | 'RDKitConformerDataset',
37 | 'SMILESDataset',
38 | 'ProductDataset',
39 | 'SuppressException', 'SizedList', 'ShuffledList',
40 | 'ByteRange',
41 | 'LMDBMapper',
42 | 'PickleUnpack',
43 | 'JsonUnpack',
44 | 'StructUnpack',
45 | 'TensorUnpack',
46 | 'Decompress',
47 | 'Decode',
48 | 'chained_collate',
49 | 'skip_none_collate',
50 | 'load_lmdb', 'load_lmdb_zstd_dict',
51 | 'StructureSampler',
52 | 'DistributedStructureSampler',
53 | 'thiacalix_n_arene_dataset']
54 |
--------------------------------------------------------------------------------
/chytorch/utils/data/_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2022, 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from random import shuffle
24 | from torch import Size
25 | from torch.utils.data import Dataset
26 | from typing import List, Sequence, TypeVar
27 | from .lmdb import LMDBMapper
28 | from .unpack import Decompress
29 |
30 |
31 | element = TypeVar('element')
32 |
33 |
34 | class chained_collate:
35 | """
36 | Collate batch of tuples with different data structures by different collate functions.
37 |
38 | :param skip_nones: skip entries that are None or contain a None
39 | """
40 | def __init__(self, *collate_fns, skip_nones=True):
41 | self.collate_fns = collate_fns
42 | self.skip_nones = skip_nones
43 |
44 | def __call__(self, batch):
45 | sub_batches = [[] for _ in self.collate_fns]
46 | for x in batch:
47 | if self.skip_nones and (x is None or None in x):
48 | continue
49 | for y, s in zip(x, sub_batches):
50 | s.append(y)
51 | return [f(x) for x, f in zip(sub_batches, self.collate_fns)]
52 |
53 |
54 | def skip_none_collate(collate_fn):
55 | def w(batch):
56 | return collate_fn([x for x in batch if x is not None])
57 | return w
58 |
59 |
60 | def load_lmdb(path, size=4, order='big'):
61 | """
62 | Helper for loading LMDB datasets with contiguous integer keys.
63 | Note: keys of datapoints should be positive indices coded as bytes with size `size` and `order` endianness.
64 |
65 | Example structure of DB with a key size=2:
66 | 0000 (0): first record
67 | 0001 (1): second record
68 | ...
69 | ffff (65535): last record
70 |
71 | :param path: path to the database
72 | :param size: key size in bytes
73 | :param order: big or little endian
74 | """
75 | db = LMDBMapper(path)
76 | db._mapping = ByteRange(len(db), size=size, order=order)
77 | return db
78 |
79 |
80 | def load_lmdb_zstd_dict(path, size=4, order='big', key=b'\xff\xff\xff\xff'):
81 | """
82 | Helper for loading LMDB datasets with contiguous integer keys compressed by zstd with an external dictionary.
83 | Note: keys of datapoints should be positive indices coded as bytes with size `size` and `order` endianness.
84 | Database should contain one additional record with key `key` holding the decompression dictionary.
85 |
86 | Example structure of DB with a key size=2:
87 | ffff (65535): zstd dict bytes
88 | 0000 (0): first record
89 | 0001 (1): second record
90 | ...
91 | fffe (65534): last record
92 |
93 | :param path: path to the database
94 | :param key: LMDB entry with dictionary data
95 | :param size: key size in bytes
96 | :param order: big or little endian
97 | """
98 | db = LMDBMapper(path)
99 | db._mapping = ByteRange(len(db) - 1, size=size, order=order)
100 | db[0] # connect db
101 | dc = Decompress(db, 'zstd', db._tr.get(key))
102 | return dc
103 |
104 |
105 | class SizedList(List):
106 | """
107 | List with tensor-like size method.
108 | """
109 | def size(self, dim=None):
110 | if dim == 0:
111 | return len(self)
112 | elif dim is None:
113 | return Size((len(self),))
114 | raise IndexError
115 |
116 |
117 | class ByteRange:
118 | """
119 | Range returning values as bytes
120 | """
121 | def __init__(self, *args, size=4, order='big', **kwargs):
122 | self.range = range(*args, **kwargs)
123 | self.size = size
124 | self.order = order
125 |
126 | def __getitem__(self, item):
127 | return self.range[item].to_bytes(self.size, self.order)
128 |
129 | def __len__(self):
130 | return len(self.range)
131 |
132 |
133 | class ShuffledList(Dataset):
134 | """
135 | Returns randomly shuffled sequences
136 | """
137 | def __init__(self, data: Sequence[Sequence[element]]):
138 | self.data = data
139 |
140 | def __getitem__(self, item: int) -> List[element]:
141 | x = list(self.data[item])
142 | shuffle(x)
143 | return x
144 |
145 | def __len__(self):
146 | return len(self.data)
147 |
148 | def size(self, dim):
149 | if dim == 0:
150 | return len(self)
151 | elif dim is None:
152 | return Size((len(self),))
153 | raise IndexError
154 |
155 |
156 | class SuppressException(Dataset):
157 | """
158 | Catch exceptions in wrapped dataset and return None instead
159 | """
160 | def __init__(self, dataset):
161 | self.dataset = dataset
162 |
163 | def __getitem__(self, item):
164 | try:
165 | return self.dataset[item]
166 | except IndexError:
167 | raise
168 | except Exception:
169 | pass
170 |
171 | def __len__(self):
172 | return len(self.dataset)
173 |
174 | def size(self, dim):
175 | if dim == 0:
176 | return len(self)
177 | elif dim is None:
178 | return Size((len(self),))
179 | raise IndexError
180 |
181 |
182 | __all__ = ['SizedList',
183 | 'ShuffledList',
184 | 'SuppressException',
185 | 'ByteRange',
186 | 'chained_collate', 'skip_none_collate',
187 | 'load_lmdb', 'load_lmdb_zstd_dict']
188 |
--------------------------------------------------------------------------------
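A quick sketch of `chained_collate` above on (feature, target) samples; `default_collate` is torch's stock collate function (importable from `torch.utils.data` in recent versions):

    from torch.utils.data import default_collate
    from chytorch.utils.data import chained_collate

    collate = chained_collate(default_collate, default_collate)
    xs, ys = collate([(0.5, 0), (1.5, 1), None])  # the None entry is skipped
    print(xs, ys)  # tensor([0.5000, 1.5000], dtype=torch.float64) tensor([0, 1])
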
/chytorch/utils/data/lmdb.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2022, 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from pathlib import Path
24 | from pickle import load, dump
25 | from torch import Size
26 | from torch.utils.data import Dataset
27 | from typing import Union
28 |
29 |
30 | class LMDBMapper(Dataset):
31 | """
32 | Map LMDB key-value storage to the Sequence Dataset of bytestrings.
33 | """
34 | def __init__(self, db: str, *, cache: Union[Path, str, None] = None):
35 | """
36 | Note: the mapper internally uses a python list for the index to bytes-key mapping, which can be huge on big datasets.
37 |
38 | :param db: lmdb dir path
39 | :param cache: path to cache file for [re]storing index. caching disabled by default.
40 | """
41 | self.db = db
42 | self.cache = cache
43 |
44 | if cache is None:
45 | return
46 | if isinstance(cache, str):
47 | cache = Path(cache)
48 | if not cache.exists():
49 | return
50 | # load existing cache
51 | with cache.open('rb') as f:
52 | self._mapping = load(f)
53 |
54 | def __getitem__(self, item: int) -> bytes:
55 | try:
56 | tr = self._tr
57 | except AttributeError:
58 | from lmdb import Environment
59 |
60 | self._db = db = Environment(self.db, readonly=True, lock=False)
61 | self._tr = tr = db.begin()
62 |
63 | try:
64 | mapping = self._mapping
65 | except AttributeError:
66 | with tr.cursor() as c:
67 | # build mapping
68 | self._mapping = mapping = list(c.iternext(keys=True, values=False))
69 | if (cache := self.cache) is not None: # save to cache
70 | if isinstance(cache, str):
71 | cache = Path(cache)
72 | with cache.open('wb') as f:
73 | dump(mapping, f)
74 |
75 | return tr.get(mapping[item])
76 |
77 | def __len__(self):
78 | try:
79 | return len(self._mapping)
80 | except AttributeError:
81 | # temporary open db
82 | from lmdb import Environment
83 |
84 | with Environment(self.db, readonly=True, lock=False) as f:
85 | return f.stat()['entries']
86 |
87 | def size(self, dim):
88 | if dim == 0:
89 | return len(self)
90 | elif dim is None:
91 | return Size((len(self),))
92 | raise IndexError
93 |
94 | def __del__(self):
95 | try:
96 | self._tr.commit()
97 | self._db.close()
98 | except AttributeError:
99 | pass
100 | else:
101 | del self._tr, self._db
102 |
103 | def __getstate__(self):
104 | return {'db': self.db, 'cache': self.cache}
105 |
106 | def __setstate__(self, state):
107 | self.__init__(**state)
108 |
109 |
110 | __all__ = ['LMDBMapper']
111 |
--------------------------------------------------------------------------------
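A usage sketch for `LMDBMapper` via the `load_lmdb` helper above (the path is hypothetical; the database must follow the integer-key layout described in `_utils.py`):

    from chytorch.utils.data import load_lmdb

    db = load_lmdb('data/molecules.lmdb', size=4, order='big')
    print(len(db))  # number of records
    raw = db[0]     # bytestring stored under key b'\x00\x00\x00\x00'
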
/chytorch/utils/data/molecule/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2022-2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from .conformer import *
24 | from .dummy import *
25 | from .encoder import *
26 | from .rdkit import *
27 |
28 |
29 | __all__ = ['MoleculeDataset', 'MoleculeDataPoint', 'MoleculeDataBatch',
30 | 'collate_molecules', 'left_padded_collate_molecules',
31 | 'ConformerDataset', 'ConformerDataPoint', 'ConformerDataBatch', 'collate_conformers',
32 | 'RDKitConformerDataset',
33 | 'thiacalix_n_arene_dataset']
34 |
--------------------------------------------------------------------------------
/chytorch/utils/data/molecule/_unpack.pyx:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023, 2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | import numpy as np
24 |
25 | cimport cython
26 | cimport numpy as cnp
27 | from cpython.mem cimport PyMem_Malloc, PyMem_Free
28 |
29 | cnp.import_array()
30 | DTYPE = np.int32
31 | ctypedef cnp.int32_t DTYPE_t
32 |
33 |
34 | # Format specification::
35 | #
36 | # Big endian bytes order
37 | # 8 bit - 0x02 (current format specification)
38 | # 12 bit - number of atoms
39 | # 12 bit - cis/trans stereo block size
40 | #
41 | # Atom block 9 bytes (repeated):
42 | # 12 bit - atom number
43 | # 4 bit - number of neighbors
44 | # 2 bit tetrahedron sign (00 - not stereo, 10 or 11 - has stereo)
45 | # 2 bit - allene sign
46 | # 5 bit - isotope (00000 - not specified, otherwise isotope - common_isotope + 16)
47 | # 7 bit - atomic number (<=118)
48 | # 32 bit - XY float16 coordinates
49 | # 3 bit - hydrogens (0-7). Note: 7 == None
50 | # 4 bit - charge (charge + 4; possible range -4 to 4)
51 | # 1 bit - radical state
52 | # Connection table: flatten list of neighbors. neighbors count stored in atom block.
53 | # For example CC(=O)O - {1: [2], 2: [1, 3, 4], 3: [2], 4: [2]} >> [2, 1, 3, 4, 2, 2].
54 | # Repeated block (equal to bonds count).
55 | # 24 bit - paired 12 bit numbers.
56 | # Bonds order block 3 bit per bond zero-padded to full byte at the end.
57 | # Cis/trans data block (repeated):
58 | # 24 bit - atoms pair
59 | # 7 bit - zero padding. in future can be used for extra bond-level stereo, like atropisomers.
60 | # 1 bit - sign
61 |
62 | @cython.nonecheck(False)
63 | @cython.boundscheck(False)
64 | @cython.cdivision(True)
65 | @cython.wraparound(False)
66 | def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsigned short symmetric_attention,
67 | unsigned short components_attention, DTYPE_t max_neighbors, DTYPE_t max_distance):
68 | """
69 | Optimized chython pack to graph tensor converter.
70 | Ignores charge, radicals, isotope, coordinates, bond order, and stereo info.
71 | """
72 | cdef unsigned char a, b, c, hydrogens, neighbors_count
73 | cdef unsigned char *connections
74 |
75 | cdef unsigned short atoms_count, bonds_count = 0, order_count = 0, cis_trans_count
76 | cdef unsigned short i, j, k, n, m
77 | cdef unsigned short[4096] mapping
78 | cdef unsigned int size, shift = 4
79 |
80 | cdef cnp.ndarray[DTYPE_t, ndim=1] atoms, neighbors
81 | cdef cnp.ndarray[DTYPE_t, ndim=2] distance
82 | cdef DTYPE_t d
83 |
84 | # read header
85 | if data[0] != 2:
86 | raise ValueError('invalid pack version')
87 |
88 | a, b, c = data[1], data[2], data[3]
89 | atoms_count = (a << 4 | b >> 4) + add_cls
90 | cis_trans_count = (b & 0x0f) << 8 | c
91 |
92 | atoms = np.empty(atoms_count, dtype=DTYPE)
93 | neighbors = np.zeros(atoms_count, dtype=DTYPE)
94 | distance = np.full((atoms_count, atoms_count), 9999, dtype=DTYPE) # fill with unreachable value
95 |
96 | # allocate memory
97 | connections = PyMem_Malloc(atoms_count * sizeof(unsigned char))
98 | if not connections:
99 | raise MemoryError()
100 |
101 | if add_cls:
102 | atoms[0] = 1
103 | neighbors[0] = 0
104 | distance[0] = 1 # set CLS to all atoms attention
105 |
106 | if symmetric_attention: # set all atoms to CLS attention
107 | distance[1:, 0] = 1
108 | else: # disable atom to CLS attention
109 | distance[1:, 0] = 0
110 |
111 | # unpack atom block
112 | for i in range(add_cls, atoms_count):
113 | distance[i, i] = 0 # set diagonal to zero
114 | a, b = data[shift], data[shift + 1]
115 | n = a << 4 | b >> 4
116 | mapping[n] = i
117 | connections[i] = neighbors_count = b & 0x0f
118 | bonds_count += neighbors_count
119 |
120 | atoms[i] = (data[shift + 3] & 0x7f) + 2
121 |
122 | hydrogens = data[shift + 8] >> 5
123 | if hydrogens != 7: # hydrogens is not None
124 | neighbors_count += hydrogens
125 | if neighbors_count > max_neighbors:
126 | neighbors_count = max_neighbors
127 | neighbors[i] = neighbors_count + 2 # neighbors + hydrogens
128 | shift += 9
129 |
130 | if bonds_count:
131 | bonds_count /= 2
132 |
133 | order_count = bonds_count * 3
134 | if order_count % 8:
135 | order_count = order_count / 8 + 1
136 | else:
137 | order_count /= 8
138 |
139 | n = add_cls
140 | for i in range(0, 2 * bonds_count, 2):
141 | a, b, c = data[shift], data[shift + 1], data[shift + 2]
142 | m = mapping[a << 4 | b >> 4]
143 | while not connections[n]:
144 | n += 1
145 | connections[n] -= 1
146 | distance[n, m] = distance[m, n] = 1
147 |
148 | m = mapping[(b & 0x0f) << 8 | c]
149 | while not connections[n]:
150 | n += 1
151 | connections[n] -= 1
152 | distance[n, m] = distance[m, n] = 1
153 | shift += 3
154 |
155 | # floyd-warshall algo
156 | for k in range(add_cls, atoms_count):
157 | for i in range(add_cls, atoms_count):
158 | if i == k or distance[i, k] == 9999:
159 | continue
160 | for j in range(add_cls, atoms_count):
161 | d = distance[i, k] + distance[k, j]
162 | if d < distance[i, j]:
163 | distance[i, j] = d
164 |
165 | # reset distances to proper values
166 | for i in range(add_cls, atoms_count):
167 | for j in range(i, atoms_count):
168 | d = distance[i, j]
169 | if d == 9999:
170 | # set attention between subgraphs
171 | distance[i, j] = distance[j, i] = components_attention
172 | elif d > max_distance:
173 | distance[i, j] = distance[j, i] = max_distance + 2
174 | else:
175 | distance[i, j] = distance[j, i] = d + 2
176 |
177 | size = shift + order_count + 4 * cis_trans_count
178 | PyMem_Free(connections)
179 | return atoms, neighbors, distance, size
180 |
--------------------------------------------------------------------------------
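The 4-byte header layout described at the top of `_unpack.pyx` can be checked in plain Python (an illustration only, not part of the compiled module; the sample bytes are made up):

    data = bytes([0x02, 0x00, 0x30, 0x01])
    assert data[0] == 2                          # format version
    atoms_count = data[1] << 4 | data[2] >> 4    # 12-bit number of atoms
    cis_trans = (data[2] & 0x0f) << 8 | data[3]  # 12-bit cis/trans block size
    print(atoms_count, cis_trans)                # 3 1
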
/chytorch/utils/data/molecule/conformer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2022-2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from chython import MoleculeContainer
24 | from functools import cached_property
25 | from numpy import empty, ndarray, sqrt, square, ones, digitize, arange, int32
26 | from numpy.random import default_rng
27 | from torch import IntTensor, Size, zeros, ones as t_ones, int32 as t_int32, eye
28 | from torch.nn.utils.rnn import pad_sequence
29 | from torch.utils.data import Dataset
30 | from torch.utils.data._utils.collate import default_collate_fn_map
31 | from torchtyping import TensorType
32 | from typing import Sequence, Tuple, Union, NamedTuple
33 |
34 |
35 | class ConformerDataPoint(NamedTuple):
36 | atoms: TensorType['atoms', int]
37 | hydrogens: TensorType['atoms', int]
38 | distances: TensorType['atoms', 'atoms', int]
39 |
40 |
41 | class ConformerDataBatch(NamedTuple):
42 | atoms: TensorType['batch', 'atoms', int]
43 | hydrogens: TensorType['batch', 'atoms', int]
44 | distances: TensorType['batch', 'atoms', 'atoms', int]
45 |
46 | def to(self, *args, **kwargs):
47 | return ConformerDataBatch(*(x.to(*args, **kwargs) for x in self))
48 |
49 | def cpu(self, *args, **kwargs):
50 | return ConformerDataBatch(*(x.cpu(*args, **kwargs) for x in self))
51 |
52 | def cuda(self, *args, **kwargs):
53 | return ConformerDataBatch(*(x.cuda(*args, **kwargs) for x in self))
54 |
55 |
56 | def collate_conformers(batch, *, collate_fn_map=None) -> ConformerDataBatch:
57 | """
58 | Prepares batches of conformers.
59 |
60 | :return: atoms, hydrogens, distances.
61 | """
62 | atoms, hydrogens, distances = [], [], []
63 |
64 | for a, h, d in batch:
65 | atoms.append(a)
66 | hydrogens.append(h)
67 | distances.append(d)
68 |
69 | pa = pad_sequence(atoms, True)
70 | b, s = pa.shape
71 | tmp = eye(s, dtype=t_int32).repeat(b, 1, 1) # prevent nan in MHA softmax on padding
72 | for i, d in enumerate(distances):
73 | s = d.size(0)
74 | tmp[i, :s, :s] = d
75 | return ConformerDataBatch(pa, pad_sequence(hydrogens, True), tmp)
76 |
77 |
78 | default_collate_fn_map[ConformerDataPoint] = collate_conformers # add auto_collation to the DataLoader
79 |
80 |
81 | class ConformerDataset(Dataset):
82 | def __init__(self, molecules: Sequence[Union[MoleculeContainer, Tuple[ndarray, ndarray, ndarray]]], *,
83 | short_cutoff: float = .9, long_cutoff: float = 5., precision: float = .05,
84 | add_cls: bool = True, unpack: bool = True, xyz: bool = True, noisy_distance: bool = False):
85 | """
86 | convert molecules to tuple of:
87 | atoms vector with atomic numbers + 2,
88 | vector with implicit hydrogens count shifted by 2,
89 | matrix with the discretized Euclidean distances between atoms shifted by 3.
90 |
91 | Note: atoms are shifted to differentiate them from padding equal to zero, the special cls token equal to 1,
92 | and the reserved task-specific token equal to 2.
93 | hydrogens are shifted to differentiate them from padding equal to zero and the reserved task-specific token equal to 1.
94 | distances are shifted to differentiate them from padding equal to zero, from the special distance equal to 1
95 | that codes unreachable atoms/tokens, and from the self-attention of atoms equal to 2.
96 |
97 | :param molecules: map-like molecules collection or tuples of atomic numbers array,
98 | hydrogens array, and coordinates/distances array.
99 | :param short_cutoff: shortest possible distance between atoms
100 | :param long_cutoff: radius of visible neighbors sphere
101 | :param precision: discretized segment size
102 | :param add_cls: add special token at first position
103 | :param unpack: unpack coordinates from `chython.MoleculeContainer` (True) or use prepared data (False).
104 | predefined data structure: (vector of atomic numbers, vector of neighbors,
105 | matrix of coordinates or distances).
106 | :param xyz: provided xyz or distance matrix if unpack=False
107 | :param noisy_distance: add noise in [-1, 1] range into binarized distance
108 | """
109 | if unpack:
110 | assert xyz, 'xyz should be True if unpack is True'
111 | assert precision > .01 and short_cutoff > .1 and long_cutoff > 1, 'invalid cutoff and precision'
112 | assert long_cutoff - short_cutoff > precision, 'precision should be less than cutoff interval'
113 |
114 | self.molecules = molecules
115 | self.short_cutoff = short_cutoff
116 | self.long_cutoff = long_cutoff
117 | self.precision = precision
118 | self.add_cls = add_cls
119 | self.unpack = unpack
120 | self.xyz = xyz
121 | self.noisy_distance = noisy_distance
122 |
123 | # discrete bins intervals. first 3 bins reserved for shifted coding
124 | self._bins = arange(short_cutoff - 3 * precision, long_cutoff, precision)
125 | self._bins[:3] = [-1, 0, .01] # trick for self-loop coding
126 | self.max_distance = len(self._bins) - 2 # param for MoleculeEncoder
127 |
128 | def __getitem__(self, item: int) -> ConformerDataPoint:
129 | mol = self.molecules[item]
130 | if self.unpack:
131 | if self.add_cls:
132 | atoms = t_ones(len(mol) + 1, dtype=t_int32) # cls token = 1
133 | hydrogens = zeros(len(mol) + 1, dtype=t_int32) # cls centrality-encoder disabled by padding trick
134 | else:
135 | atoms = IntTensor(len(mol)) # uninitialized tensors of size len(mol), filled in the loop below
136 | hydrogens = IntTensor(len(mol))
137 |
138 | for i, (n, a) in enumerate(mol.atoms(), self.add_cls):
139 | atoms[i] = a.atomic_number + 2
140 | hydrogens[i] = (a.implicit_hydrogens or 0) + 2
141 |
142 | xyz = empty((len(mol), 3))
143 | conformer = mol._conformers[0] # noqa
144 | for i, n in enumerate(mol):
145 | xyz[i] = conformer[n]
146 | else:
147 | a, hgs, xyz = mol
148 | if self.add_cls:
149 | atoms = t_ones(len(a) + 1, dtype=t_int32)
150 | hydrogens = zeros(len(a) + 1, dtype=t_int32)
151 |
152 | atoms[1:] = IntTensor(a + 2)
153 | hydrogens[1:] = IntTensor(hgs + 2)
154 | else:
155 | atoms = IntTensor(a + 2)
156 | hydrogens = IntTensor(hgs + 2)
157 |
158 | if self.xyz:
159 | diff = xyz[None, :, :] - xyz[:, None, :] # NxNx3
160 | dist = sqrt(square(diff).sum(axis=-1)) # NxN
161 | else:
162 | dist = xyz
163 |
164 | dist = digitize(dist, self._bins)
165 | if self.noisy_distance:
166 | dist += (dist > 2) * self.generator.integers(-1, 1, size=dist.shape, endpoint=True)
167 | if self.add_cls: # set cls to atoms distance equal to 0
168 | tmp = ones((len(atoms), len(atoms)), dtype=int32)
169 | tmp[1:, 1:] = dist
170 | dist = tmp
171 | return ConformerDataPoint(atoms, hydrogens, IntTensor(dist))
172 |
173 | def __len__(self):
174 | return len(self.molecules)
175 |
176 | def size(self, dim):
177 | if dim == 0:
178 | return len(self.molecules)
179 | elif dim is None:
180 | return Size((len(self),))
181 | raise IndexError
182 |
183 | @cached_property
184 | def generator(self):
185 | return default_rng()
186 |
187 |
188 | __all__ = ['ConformerDataset', 'ConformerDataPoint', 'ConformerDataBatch', 'collate_conformers']
189 |
--------------------------------------------------------------------------------
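A sketch of the `ConformerDataset` above on prepared numpy arrays (unpack=False, xyz=True); the two-atom "molecule" is made up for illustration:

    import numpy as np
    from chytorch.utils.data import ConformerDataset

    numbers = np.array([6, 8])                     # atomic numbers
    hydrogens = np.array([3, 1])                   # implicit hydrogens counts
    xyz = np.array([[0., 0., 0.], [1.4, 0., 0.]])  # coordinates
    ds = ConformerDataset([(numbers, hydrogens, xyz)], unpack=False, xyz=True)
    atoms, hgs, dist = ds[0]                       # ConformerDataPoint
    print(atoms)  # tensor([1, 8, 10], dtype=torch.int32), cls token first
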
/chytorch/utils/data/molecule/dummy.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from chython import smiles
24 | from .encoder import MoleculeDataset
25 |
26 |
27 | def thiacalix_n_arene_dataset(n=4, size=10_000, **kwargs):
28 | """
29 | Create a dummy dataset for testing purposes with thiacalix[n]arenes.
30 |
31 | :param n: number of macrocycle fragments. Each fragment contains 12 atoms.
32 | :param size: dataset size
33 | :param kwargs: other params of MoleculeDataset
34 | """
35 | assert n >= 3, 'n must be at least 3'
36 | prefix = 'C12=CC(C(C)(C)C)=CC(=C2O)S'
37 | postfix = 'C2=CC(C(C)(C)C)=CC(=C2O)S1'
38 | chain = ''.join('C2=CC(C(C)(C)C)=CC(=C2O)S' for _ in range(n - 2))
39 |
40 | return MoleculeDataset([smiles(prefix + chain + postfix)] * size, **kwargs)
41 |
42 |
43 | __all__ = ['thiacalix_n_arene_dataset']
44 |
--------------------------------------------------------------------------------
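A throwaway-dataset sketch with `thiacalix_n_arene_dataset` above, handy for smoke-testing a pipeline (sizes are arbitrary):

    from chytorch.utils.data import thiacalix_n_arene_dataset

    ds = thiacalix_n_arene_dataset(n=4, size=128)
    atoms, neighbors, distances = ds[0]  # MoleculeDataPoint tensors
    print(len(ds))                       # 128
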
/chytorch/utils/data/molecule/encoder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2021-2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from chython import MoleculeContainer
24 | from functools import partial
25 | from numpy import minimum, nan_to_num
26 | from scipy.sparse.csgraph import shortest_path
27 | from torch import IntTensor, Size, int32, ones, zeros, eye, empty, triu, Tensor, cat
28 | from torch.nn.utils.rnn import pad_sequence
29 | from torch.utils.data import Dataset
30 | from torch.utils.data._utils.collate import default_collate_fn_map
31 | from torchtyping import TensorType
32 | from typing import Sequence, Union, NamedTuple, Tuple
33 | from zlib import decompress
34 |
35 |
36 | class MoleculeDataPoint(NamedTuple):
37 | atoms: TensorType['atoms', int]
38 | neighbors: TensorType['atoms', int]
39 | distances: TensorType['atoms', 'atoms', int]
40 |
41 |
42 | class MoleculeDataBatch(NamedTuple):
43 | atoms: TensorType['batch', 'atoms', int]
44 | neighbors: TensorType['batch', 'atoms', int]
45 | distances: TensorType['batch', 'atoms', 'atoms', int]
46 |
47 | def to(self, *args, **kwargs):
48 | return MoleculeDataBatch(*(x.to(*args, **kwargs) for x in self))
49 |
50 | def cpu(self, *args, **kwargs):
51 | return MoleculeDataBatch(*(x.cpu(*args, **kwargs) for x in self))
52 |
53 | def cuda(self, *args, **kwargs):
54 | return MoleculeDataBatch(*(x.cuda(*args, **kwargs) for x in self))
55 |
56 |
57 | def collate_molecules(batch, *, padding_left: bool = False, collate_fn_map=None) -> MoleculeDataBatch:
58 | """
59 | Prepares batches of molecules.
60 |
61 | :return: atoms, neighbors, distances.
62 | """
63 | atoms, neighbors, distances = [], [], []
64 |
65 | for a, n, d in batch:
66 | if padding_left:
67 | atoms.append(a.flipud())
68 | neighbors.append(n.flipud())
69 | else:
70 | atoms.append(a)
71 | neighbors.append(n)
72 | distances.append(d)
73 |
74 | pa = pad_sequence(atoms, True)
75 | b, s = pa.shape
76 | tmp = eye(s, dtype=int32).repeat(b, 1, 1) # prevent nan in MHA softmax on padding
77 | for i, d in enumerate(distances):
78 | s = d.size(0)
79 | if padding_left:
80 | tmp[i, -s:, -s:] = d
81 | else:
82 | tmp[i, :s, :s] = d
83 | if padding_left:
84 | return MoleculeDataBatch(pa.fliplr(), pad_sequence(neighbors, True).fliplr(), tmp)
85 | return MoleculeDataBatch(pa, pad_sequence(neighbors, True), tmp)
86 |
87 |
88 | left_padded_collate_molecules = partial(collate_molecules, padding_left=True)
89 | default_collate_fn_map[MoleculeDataPoint] = collate_molecules # add auto_collation to the DataLoader
90 |
91 |
92 | class MoleculeDataset(Dataset):
93 | def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
94 | add_cls: bool = True, cls_token: Union[int, Tuple[int, ...], Sequence[int], Sequence[Tuple[int, ...]],
95 | TensorType['cls', int], TensorType['dataset', 1, int], TensorType['dataset', 'cls', int]] = 1,
96 | max_distance: int = 10, max_neighbors: int = 14,
97 | attention_schema: str = 'bert', components_attention: bool = True,
98 | unpack: bool = False, compressed: bool = True, distance_cutoff=None):
99 | """
100 |         Convert molecules to a tuple of:
101 |         atoms vector with atomic numbers shifted by 2,
102 |         neighbors vector with the count of connected atoms (including implicit hydrogens) shifted by 2,
103 |         distance matrix with the shortest paths between atoms shifted by 2.
104 |
105 |         Note: atoms are shifted to differentiate them from padding (0), the special cls token (1), and the
106 |         reserved MLM task token (2).
107 |         neighbors are shifted to differentiate them from padding (0) and the reserved MLM task token (1).
108 |         distances are shifted to differentiate them from padding (0) and from the special distance (1)
109 |         that codes unreachable atoms (e.g. salts).
110 |
111 | :param molecules: molecules collection
112 | :param max_distance: set distances greater than cutoff to cutoff value
113 | :param add_cls: add special token at first position
114 | :param max_neighbors: set neighbors count greater than cutoff to cutoff value
115 | :param attention_schema: attention between CLS and atoms:
116 | bert - symmetrical without masks;
117 | causal - masked from atoms to cls, causal between cls multi-prompts (triangle mask) and full to atoms;
118 | directed - masked from atoms to cls and between cls, but full to atoms (self-attention of cls is kept);
119 | :param components_attention: enable or disable attention between subgraphs
120 | :param unpack: unpack molecules
121 | :param compressed: packed molecules are compressed
122 | :param cls_token: idx of cls token (int), or multiple tokens for multi-prompt (tuple, 1d-tensor),
123 | or individual token per sample (Sequence, column-vector) or
124 | individual multi-prompt per sample (Sequence of tuples, 2d-tensor).
125 | """
126 | if not isinstance(cls_token, int) or cls_token != 1:
127 | assert add_cls, 'non-default value of cls_token requires add_cls=True'
128 | assert attention_schema in ('bert', 'causal', 'directed'), 'Invalid attention schema'
129 |
130 | self.molecules = molecules
131 | # distance_cutoff is deprecated
132 | self.max_distance = distance_cutoff if distance_cutoff is not None else max_distance
133 | self.add_cls = add_cls
134 | self.max_neighbors = max_neighbors
135 | self.unpack = unpack
136 | self.compressed = compressed
137 | self.cls_token = cls_token
138 | self.attention_schema = attention_schema
139 | self.components_attention = components_attention
140 |
141 | def __getitem__(self, item: int) -> MoleculeDataPoint:
142 | mol = self.molecules[item]
143 |
144 | # cls setup lookup
145 | cls_token = self.cls_token
146 | if not self.add_cls:
147 | cls_cnt = 0
148 | elif isinstance(cls_token, int):
149 | cls_cnt = 1
150 | elif isinstance(cls_token, tuple):
151 | cls_cnt = len(cls_token)
152 | assert cls_cnt > 1, 'wrong multi-prompt setup'
153 | cls_token = IntTensor(cls_token)
154 | elif isinstance(cls_token, Sequence):
155 | if isinstance(ct := cls_token[item], int):
156 | assert isinstance(cls_token[0], int), 'inconsistent cls_token data'
157 | cls_cnt = 1
158 | cls_token = ct
159 | elif isinstance(ct, Sequence):
160 | cls_cnt = len(ct)
161 | assert cls_cnt > 1, 'wrong multi-prompt setup'
162 | assert isinstance(cls_token[0], Sequence) and cls_cnt == len(cls_token[0]), 'inconsistent cls_token data'
163 | cls_token = IntTensor(ct)
164 | else:
165 | raise TypeError('cls_token must be int, tuple of ints, sequence of ints or tuples of ints or 1,2-d tensor')
166 | elif isinstance(cls_token, Tensor):
167 | if cls_token.dim() == 1:
168 | cls_cnt = cls_token.size(0)
169 | assert cls_cnt > 1, 'wrong multi-prompt setup'
170 | elif cls_token.dim() == 2:
171 | cls_cnt = cls_token.size(1)
172 | cls_token = cls_token[item]
173 | else:
174 | raise TypeError('cls_token must be int, tuple of ints, sequence of ints or tuples of ints or 1,2-d tensor')
175 | else:
176 | raise TypeError('cls_token must be int, tuple of ints, sequence of ints or tuples of ints or 1,2-d tensor')
177 |
178 | if self.unpack:
179 | try:
180 | from ._unpack import unpack
181 | except ImportError: # windows?
182 | mol = MoleculeContainer.unpack(mol, compressed=self.compressed)
183 | else:
184 | if self.compressed:
185 | mol = decompress(mol)
186 | atoms, neighbors, distances, _ = unpack(mol, cls_cnt == 1, # only single cls token supported by cython ext
187 | # causal and directed have the same mask for 1 cls token case
188 | self.attention_schema == 'bert',
189 | self.components_attention, self.max_neighbors,
190 | self.max_distance)
191 | atoms = IntTensor(atoms)
192 | neighbors = IntTensor(neighbors)
193 | distances = IntTensor(distances)
194 | if cls_cnt == 1:
195 | # token already pre-allocated
196 | if isinstance(cls_token, Tensor) or cls_token != 1:
197 | # change default value (1)
198 | atoms[0] = cls_token
199 | elif cls_cnt: # expand atoms with cls tokens
200 | atoms = cat([cls_token, atoms])
201 | neighbors = cat([zeros(cls_cnt, dtype=int32), neighbors])
202 | distances = self._add_cls_to_distances(distances, cls_cnt)
203 | return MoleculeDataPoint(atoms, neighbors, distances)
204 |
205 | token_cnt = len(mol) + cls_cnt
206 | atoms = empty(token_cnt, dtype=int32)
207 | neighbors = zeros(token_cnt, dtype=int32) # cls centrality-encoder disabled by padding trick
208 |
209 | nc = self.max_neighbors
210 | ngb = mol._bonds # noqa speedup
211 | for i, (n, a) in enumerate(mol.atoms(), cls_cnt):
212 | atoms[i] = a.atomic_number + 2
213 | nb = len(ngb[n]) + (a.implicit_hydrogens or 0) # treat bad valence as 0-hydrogen
214 | if nb > nc:
215 | nb = nc
216 | neighbors[i] = nb + 2
217 |
218 | distances = shortest_path(mol.adjacency_matrix(), method='FW', directed=False, unweighted=True) + 2
219 | nan_to_num(distances, copy=False, posinf=self.components_attention)
220 | minimum(distances, self.max_distance + 2, out=distances)
221 | distances = IntTensor(distances)
222 |
223 | if cls_cnt:
224 | atoms[:cls_cnt] = cls_token
225 | distances = self._add_cls_to_distances(distances, cls_cnt)
226 | return MoleculeDataPoint(atoms, neighbors, distances)
227 |
228 | def __len__(self):
229 | return len(self.molecules)
230 |
231 | def size(self, dim):
232 | if dim == 0:
233 | return len(self)
234 | elif dim is None:
235 | return Size((len(self),))
236 | raise IndexError
237 |
238 | def _add_cls_to_distances(self, distances, cls_cnt):
239 | total = distances.size(0) + cls_cnt
240 | if self.attention_schema == 'bert': # everything to everything
241 | tmp = ones(total, total, dtype=int32)
242 | elif self.attention_schema == 'causal':
243 | tmp = triu(ones(total, total, dtype=int32))
244 | else: # CLS to atoms but not back
245 | tmp = eye(total, dtype=int32) # self attention of cls tokens
246 | tmp[:cls_cnt, cls_cnt:] = 1 # cls to atom attention
247 | tmp[cls_cnt:, cls_cnt:] = distances
248 | return tmp
249 |
250 |
251 | __all__ = ['MoleculeDataset', 'MoleculeDataPoint', 'MoleculeDataBatch',
252 | 'collate_molecules', 'left_padded_collate_molecules']
253 |
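A minimal wiring sketch for the dataset and collate function above (assumes chython is installed). Since MoleculeDataPoint is registered in default_collate_fn_map, a plain DataLoader collates these batches automatically; collate_molecules is passed explicitly here only for clarity:

    from chython import smiles
    from torch.utils.data import DataLoader
    from chytorch.utils.data.molecule.encoder import MoleculeDataset, collate_molecules

    mols = [smiles('CCO'), smiles('C1=CC=CC=C1')]  # kekulized benzene SMILES avoids aromatic forms
    ds = MoleculeDataset(mols)                     # add_cls=True by default
    dl = DataLoader(ds, batch_size=2, collate_fn=collate_molecules)
    atoms, neighbors, distances = next(iter(dl))   # a MoleculeDataBatch
    print(atoms.shape, distances.shape)            # (2, 7) and (2, 7, 7): benzene + cls, smaller molecule padded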
--------------------------------------------------------------------------------
/chytorch/utils/data/molecule/rdkit.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from numpy import array
24 | from random import choice
25 | from torch import Size
26 | from torch.utils.data import Dataset
27 | from typing import Sequence, Optional, Dict
28 |
29 |
30 | def generate(smiles, num_conf=10, max_attempts=100, prune=.2):
31 | from rdkit.Chem import MolFromSmiles, AddHs, RemoveAllHs
32 | from rdkit.Chem.AllChem import EmbedMultipleConfs
33 |
34 | m = MolFromSmiles(smiles)
35 | m = AddHs(m)
36 | if not EmbedMultipleConfs(m, numConfs=num_conf, maxAttempts=max_attempts, pruneRmsThresh=prune):
37 | # try again ignoring chirality
38 | EmbedMultipleConfs(m, numConfs=num_conf, maxAttempts=max_attempts, pruneRmsThresh=prune, enforceChirality=False)
39 | m = RemoveAllHs(m)
40 | a = array([x.GetAtomicNum() for x in m.GetAtoms()])
41 | h = array([x.GetNumExplicitHs() for x in m.GetAtoms()])
42 | return [(a, h, c.GetPositions()) for c in m.GetConformers()]
43 |
44 |
45 | class RDKitConformerDataset(Dataset):
46 | """
47 | Random conformer generator dataset
48 | """
49 | def __init__(self, molecules: Sequence[str], num_conf=10, max_attempts=100, prune=.2,
50 | cache: Optional[Dict[int, Sequence]] = None):
51 | """
52 | :param molecules: list of molecules' SMILES strings
53 | :param num_conf: numConfs parameter in EmbedMultipleConfs
54 | :param max_attempts: maxAttempts parameter in EmbedMultipleConfs
55 | :param prune: pruneRmsThresh parameter in EmbedMultipleConfs
56 | :param cache: cache for generated conformers
57 | """
58 | self.molecules = molecules
59 | self.num_conf = num_conf
60 | self.max_attempts = max_attempts
61 | self.prune = prune
62 | self.cache = cache
63 |
64 | def __getitem__(self, item: int):
65 | if self.cache is not None and item in self.cache:
66 | return choice(self.cache[item])
67 |
68 | confs = generate(self.molecules[item], self.num_conf, self.max_attempts, self.prune)
69 | if not confs:
70 | raise ValueError('conformer generation failed')
71 |
72 | if self.cache is not None:
73 | self.cache[item] = confs
74 | return choice(confs)
75 |
76 | def __len__(self):
77 | return len(self.molecules)
78 |
79 | def size(self, dim):
80 | if dim == 0:
81 | return len(self.molecules)
82 | elif dim is None:
83 | return Size((len(self),))
84 | raise IndexError
85 |
86 |
87 | __all__ = ['RDKitConformerDataset']
88 |
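A minimal usage sketch for the conformer dataset above (assumes the optional rdkit dependency is installed). A shared dict caches the generated ensembles, so repeated access only re-samples a random conformer instead of re-embedding:

    from chytorch.utils.data.molecule.rdkit import RDKitConformerDataset

    cache = {}                                 # any dict-like object works
    ds = RDKitConformerDataset(['CCO', 'CC(=O)O'], num_conf=5, cache=cache)
    atomic_numbers, hydrogens, xyz = ds[0]     # one randomly chosen ethanol conformer
    print(xyz.shape)                           # (3, 3): three heavy atoms, xyz coordinates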
--------------------------------------------------------------------------------
/chytorch/utils/data/product.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 |
24 | from torch import Size
25 | from torch.utils.data import Dataset
26 | from typing import Sequence, List, TypeVar
27 |
28 |
29 | element = TypeVar('element')
30 |
31 |
32 | class ProductDataset(Dataset):
33 | """
34 | Lazy product enumeration dataset for combinatorial libraries.
35 | """
36 | def __init__(self, *sets: Sequence[element]):
37 | self.sets = sets
38 |
39 | # calculate lazy product metadata
40 | self._divs = divs = []
41 | self._mods = mods = []
42 |
43 | factor = 1
44 | for x in reversed(sets):
45 | s = len(x)
46 | divs.insert(0, factor)
47 | mods.insert(0, s)
48 | factor *= s
49 | self._size = factor
50 |
51 | def __getitem__(self, item: int) -> List[element]:
52 | if item < 0:
53 | item += self._size
54 | if item < 0 or item >= self._size:
55 | raise IndexError
56 |
57 |         return [s[(item // d) % m] for s, d, m in zip(self.sets, self._divs, self._mods)]  # exact mixed-radix decode
58 |
59 | def __len__(self):
60 | return self._size
61 |
62 | def size(self, dim):
63 | if dim == 0:
64 | return len(self)
65 | elif dim is None:
66 | return Size((len(self),))
67 | raise IndexError
68 |
69 |
70 | __all__ = ['ProductDataset']
71 |
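A minimal sketch of the lazy enumeration above: item i is decoded positionally from the per-set sizes (a mixed-radix representation), so the full Cartesian product is never materialized:

    from chytorch.utils.data.product import ProductDataset

    ds = ProductDataset('ab', 'xyz')  # 2 * 3 = 6 combinations, nothing precomputed
    print(len(ds))                    # 6
    print(ds[0], ds[5])               # ['a', 'x'] ['b', 'z']
    print(ds[-1])                     # negative indices wrap around as usual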
--------------------------------------------------------------------------------
/chytorch/utils/data/reaction/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2022-2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from .encoder import *
24 |
25 |
26 | # reverse compatibility
27 | ReactionDataset = ReactionEncoderDataset
28 | collate_reactions = collate_encoded_reactions
29 |
30 |
31 | __all__ = ['ReactionEncoderDataset', 'ReactionEncoderDataPoint', 'ReactionEncoderDataBatch',
32 | 'collate_encoded_reactions',
33 | # reverse compatibility
34 | 'ReactionDataset', 'collate_reactions']
35 |
--------------------------------------------------------------------------------
/chytorch/utils/data/reaction/encoder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2021-2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from chython import ReactionContainer
24 | from itertools import chain, repeat
25 | from torch import IntTensor, cat, zeros, int32, Size, eye
26 | from torch.nn.utils.rnn import pad_sequence
27 | from torch.utils.data import Dataset
28 | from torch.utils.data._utils.collate import default_collate_fn_map
29 | from torchtyping import TensorType
30 | from typing import Sequence, Union, NamedTuple
31 | from ..molecule import MoleculeDataset
32 |
33 |
34 | class ReactionEncoderDataPoint(NamedTuple):
35 | atoms: TensorType['atoms', int]
36 | neighbors: TensorType['atoms', int]
37 | distances: TensorType['atoms', 'atoms', int]
38 | roles: TensorType['atoms', int]
39 |
40 |
41 | class ReactionEncoderDataBatch(NamedTuple):
42 | atoms: TensorType['batch', 'atoms', int]
43 | neighbors: TensorType['batch', 'atoms', int]
44 | distances: TensorType['batch', 'atoms', 'atoms', int]
45 | roles: TensorType['batch', 'atoms', int]
46 |
47 | def to(self, *args, **kwargs):
48 | return ReactionEncoderDataBatch(*(x.to(*args, **kwargs) for x in self))
49 |
50 | def cpu(self, *args, **kwargs):
51 | return ReactionEncoderDataBatch(*(x.cpu(*args, **kwargs) for x in self))
52 |
53 | def cuda(self, *args, **kwargs):
54 | return ReactionEncoderDataBatch(*(x.cuda(*args, **kwargs) for x in self))
55 |
56 |
57 | def collate_encoded_reactions(batch, *, collate_fn_map=None) -> ReactionEncoderDataBatch:
58 | """
59 | Prepares batches of reactions.
60 |
61 | :return: atoms, neighbors, distances, atoms roles.
62 | """
63 | atoms, neighbors, distances, roles = [], [], [], []
64 | for a, n, d, r in batch:
65 | atoms.append(a)
66 | neighbors.append(n)
67 | roles.append(r)
68 | distances.append(d)
69 |
70 | pa = pad_sequence(atoms, True)
71 | b, s = pa.shape
72 | tmp = eye(s, dtype=int32).repeat(b, 1, 1) # prevent nan in MHA softmax on padding
73 | for n, d in enumerate(distances):
74 | s = d.size(0)
75 | tmp[n, :s, :s] = d
76 | return ReactionEncoderDataBatch(pa, pad_sequence(neighbors, True), tmp, pad_sequence(roles, True))
77 |
78 |
79 | default_collate_fn_map[ReactionEncoderDataPoint] = collate_encoded_reactions # add auto_collation to the DataLoader
80 |
81 |
82 | class ReactionEncoderDataset(Dataset):
83 | def __init__(self, reactions: Sequence[Union[ReactionContainer, bytes]], *, max_distance: int = 10,
84 | max_neighbors: int = 14, add_cls: bool = True, add_molecule_cls: bool = True,
85 | hide_molecule_cls: bool = True, unpack: bool = False, distance_cutoff=None, compressed: bool = True):
86 | """
87 |         Convert reactions to a tuple of:
88 |         atoms, neighbors and distances tensors similar to the molecule dataset.
89 |         distances - merged molecular distance matrices, zero-filled between molecules to isolate attention.
90 |         roles: 2 reactants, 3 products, 0 padding, 1 cls token.
91 |
92 | :param reactions: reactions collection
93 | :param max_distance: set distances greater than cutoff to cutoff value
94 | :param add_cls: add special token at first position
95 | :param add_molecule_cls: add special token at first position of each molecule
96 |         :param hide_molecule_cls: disable molecule cls on the reaction level (mark it as padding)
97 | :param max_neighbors: set neighbors count greater than cutoff to cutoff value
98 | :param unpack: unpack reactions
99 | :param compressed: packed reactions are compressed
100 | """
101 | if not add_molecule_cls:
102 | assert not hide_molecule_cls, 'add_molecule_cls should be True if hide_molecule_cls is True'
103 | self.reactions = reactions
104 | # distance_cutoff is deprecated
105 | self.max_distance = distance_cutoff if distance_cutoff is not None else max_distance
106 | self.add_cls = add_cls
107 | self.add_molecule_cls = add_molecule_cls
108 | self.hide_molecule_cls = hide_molecule_cls
109 | self.max_neighbors = max_neighbors
110 | self.unpack = unpack
111 | self.compressed = compressed
112 |
113 | def __getitem__(self, item: int) -> ReactionEncoderDataPoint:
114 | rxn = self.reactions[item]
115 | if self.unpack:
116 | rxn = ReactionContainer.unpack(rxn, compressed=self.compressed)
117 | molecules = MoleculeDataset(rxn.reactants + rxn.products, max_distance=self.max_distance,
118 | max_neighbors=self.max_neighbors, add_cls=self.add_molecule_cls)
119 |
120 | if self.add_cls:
121 | # disable rxn cls in molecules encoder
122 | atoms, neighbors, roles = [IntTensor([0])], [IntTensor([0])], [1]
123 | else:
124 | atoms, neighbors, roles = [], [], []
125 | distances = []
126 | for i, (m, r) in enumerate(chain(zip(rxn.reactants, repeat(2)), zip(rxn.products, repeat(3)))):
127 | a, n, d = molecules[i]
128 | atoms.append(a)
129 | neighbors.append(n)
130 | distances.append(d)
131 | if self.add_molecule_cls:
132 | # (dis|en)able molecule cls in reaction encoder
133 | roles.append(0 if self.hide_molecule_cls else r)
134 | roles.extend(repeat(r, len(m)))
135 |
136 | tmp = zeros(len(roles), len(roles), dtype=int32)
137 | if self.add_cls:
138 | tmp[0, 0] = 1 # prevent nan in MHA softmax.
139 | i = 1
140 | else:
141 | i = 0
142 | for d in distances:
143 | j = i + d.size(0)
144 | tmp[i:j, i:j] = d
145 | i = j
146 | return ReactionEncoderDataPoint(cat(atoms), cat(neighbors), tmp, IntTensor(roles))
147 |
148 | def __len__(self):
149 | return len(self.reactions)
150 |
151 | def size(self, dim):
152 | if dim == 0:
153 | return len(self)
154 | elif dim is None:
155 | return Size((len(self),))
156 | raise IndexError
157 |
158 |
159 | __all__ = ['ReactionEncoderDataset', 'ReactionEncoderDataPoint', 'ReactionEncoderDataBatch',
160 | 'collate_encoded_reactions']
161 |
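A minimal usage sketch for the reaction encoder above (assumes chython, whose smiles helper also parses reaction SMILES). Roles follow the docstring: 1 for the reaction cls, 2 for reactant atoms, 3 for product atoms, 0 for padding and the hidden per-molecule cls tokens:

    from chython import smiles
    from chytorch.utils.data.reaction.encoder import ReactionEncoderDataset

    rxn = smiles('CCO>>C=C')                    # one reactant, one product
    ds = ReactionEncoderDataset([rxn])
    atoms, neighbors, distances, roles = ds[0]
    print(roles.tolist())                       # [1, 0, 2, 2, 2, 0, 3, 3]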
--------------------------------------------------------------------------------
/chytorch/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2022-2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | import torch.distributed as dist
24 | from chython import MoleculeContainer
25 | from collections import defaultdict
26 | from itertools import chain, islice
27 | from math import ceil
28 | from torch import Generator, randperm
29 | from torch.utils.data import Sampler
30 | from typing import Optional, List, Iterator
31 | from .molecule import MoleculeDataset
32 |
33 |
34 | def _build_index(dataset, sizes):
35 | if not isinstance(dataset, MoleculeDataset):
36 | raise TypeError('Unsupported Dataset')
37 | if sizes is not None:
38 | return sizes
39 | elif dataset.unpack:
40 | return [MoleculeContainer.pack_len(m) for m in dataset.molecules]
41 | else:
42 | return [len(m) for m in dataset.molecules]
43 |
44 |
45 | def _indices(order, sizes, batch_size):
46 |     groups = defaultdict(list)  # bucket molecule indices by size
47 | for n in order:
48 | groups[sizes[n]].append(n)
49 |
50 | iterable_groups = {k: iter(v) for k, v in groups.items()}
51 | sorted_groups = sorted(groups)
52 |     chained_groups = {}  # fallback order per size k: same size first, then smaller sizes, then larger
53 | for k in groups:
54 | chained_groups[k] = chain(*(iterable_groups[x] for x in sorted_groups[sorted_groups.index(k)::-1]),
55 | *(iterable_groups[x] for x in sorted_groups[sorted_groups.index(k) + 1:]))
56 |
57 | indices = []
58 | seen = set()
59 | for n in order:
60 | if n not in seen:
61 | for x in islice(chained_groups[sizes[n]], batch_size):
62 | indices.append(x)
63 | if x != n:
64 | seen.add(x)
65 | if len(indices) == len(sizes):
66 | break
67 | else:
68 | seen.discard(n)
69 | return indices
70 |
71 |
72 | class StructureSampler(Sampler[List[int]]):
73 | def __init__(self, dataset: MoleculeDataset, batch_size: int, shuffle: bool = True, seed: int = 0, *,
74 | sizes: Optional[List[int]] = None):
75 | """
76 | Sample molecules locally grouped by size to reduce idle calculations on paddings.
77 |
78 | Example:
79 | [3, 4, 3, 3, 4, 5, 4] - sizes of molecules in dataset
80 | [[0, 2, 3], [1, 4, 6], [5]] - output indices for batch_size=3
81 |
82 | :param batch_size: batch size
83 | :param sizes: precalculated sizes of molecules.
84 | """
85 | self.dataset = dataset
86 | self.batch_size = batch_size
87 | self.shuffle = shuffle
88 | self.seed = seed
89 | self.sizes = _build_index(dataset, sizes)
90 |
91 | def __iter__(self) -> Iterator[List[int]]:
92 | if self.shuffle:
93 | generator = Generator()
94 | generator.manual_seed(self.seed)
95 | index = _indices(randperm(len(self.sizes), generator=generator).tolist(), self.sizes, self.batch_size)
96 | else:
97 | index = _indices(range(len(self.sizes)), self.sizes, self.batch_size)
98 |
99 | index = iter(index)
100 | while batch := list(islice(index, self.batch_size)):
101 | yield batch
102 |
103 | def __len__(self):
104 | return ceil(len(self.sizes) / self.batch_size)
105 |
106 |
107 | class DistributedStructureSampler(Sampler[List[int]]):
108 | def __init__(self, dataset: MoleculeDataset, batch_size: int, num_replicas: Optional[int] = None,
109 | rank: Optional[int] = None, shuffle: bool = True, seed: int = 0, *,
110 | sizes: Optional[List[int]] = None):
111 | """
112 | Sample molecules locally grouped by size to reduce idle calculations on paddings.
113 |
114 | :param batch_size: expected batch size
115 | :param sizes: precalculated sizes of molecules.
116 | :param num_replicas, rank, shuffle, seed: see torch.utils.data.DistributedSampler for details.
117 | """
118 | self.dataset = dataset
119 | self.batch_size = batch_size
120 | self.shuffle = shuffle
121 | self.seed = seed
122 | self.epoch = 0
123 | self.sizes = _build_index(dataset, sizes)
124 |
125 | # adapted from torch/utils/data/distributed.py
126 | if num_replicas is None:
127 | if not dist.is_available():
128 | raise RuntimeError('Requires distributed package to be available')
129 | num_replicas = dist.get_world_size()
130 | if rank is None:
131 | if not dist.is_available():
132 | raise RuntimeError('Requires distributed package to be available')
133 | rank = dist.get_rank()
134 | if rank >= num_replicas or rank < 0:
135 | raise ValueError(f'Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]')
136 |
137 | self.num_replicas = num_replicas
138 | self.rank = rank
139 | self.num_samples = ceil(len(self.sizes) / num_replicas)
140 | self.total_size = self.num_samples * num_replicas
141 |
142 | def __len__(self) -> int:
143 | return ceil(self.num_samples / self.batch_size)
144 |
145 | def __iter__(self) -> Iterator[List[int]]:
146 | if self.shuffle:
147 | generator = Generator()
148 | generator.manual_seed(self.seed + self.epoch)
149 | indices = _indices(randperm(len(self.sizes), generator=generator).tolist(), self.sizes, self.batch_size)
150 | else:
151 | indices = _indices(range(len(self.sizes)), self.sizes, self.batch_size)
152 |
153 | # adapted from torch/utils/data/distributed.py
154 | padding_size = self.total_size - len(indices)
155 | if padding_size <= len(indices):
156 | indices += indices[:padding_size]
157 | else:
158 | indices += (indices * ceil(padding_size / len(indices)))[:padding_size]
159 | assert len(indices) == self.total_size
160 |
161 | # subsample
162 | indices = indices[self.rank:self.total_size:self.num_replicas]
163 | assert len(indices) == self.num_samples
164 |
165 | indices = iter(indices)
166 | while batch := list(islice(indices, self.batch_size)):
167 | yield batch
168 |
169 | def set_epoch(self, epoch: int):
170 | self.epoch = epoch
171 |
172 |
173 | __all__ = ['StructureSampler', 'DistributedStructureSampler']
174 |
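A minimal wiring sketch for the samplers above (assumes chython): StructureSampler yields lists of indices, so it plugs into DataLoader through batch_sampler, in which case the loader's own batch_size and shuffle must stay unset:

    from chython import smiles
    from torch.utils.data import DataLoader
    from chytorch.utils.data.molecule.encoder import MoleculeDataset, collate_molecules
    from chytorch.utils.data.sampler import StructureSampler

    mols = [smiles(s) for s in ('CCO', 'CCC', 'CCCO', 'OCC')]
    ds = MoleculeDataset(mols)
    sampler = StructureSampler(ds, batch_size=2, shuffle=True, seed=0)
    dl = DataLoader(ds, batch_sampler=sampler, collate_fn=collate_molecules)
    for batch in dl:                  # batches group equally sized molecules to minimise padding
        print(batch.atoms.shape)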
--------------------------------------------------------------------------------
/chytorch/utils/data/smiles.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from chython import MoleculeContainer, ReactionContainer, smiles
24 | from torch import Size
25 | from torch.utils.data import Dataset
26 | from typing import Dict, Union, Sequence, Optional, Type
27 |
28 |
29 | class SMILESDataset(Dataset):
30 | def __init__(self, data: Sequence[str], *, canonicalize: bool = False, cache: Optional[Dict[int, bytes]] = None,
31 | dtype: Union[Type[MoleculeContainer], Type[ReactionContainer]] = MoleculeContainer,
32 | unpack: bool = True, ignore_stereo: bool = True, ignore_bad_isotopes: bool = False,
33 | keep_implicit: bool = False, ignore_carbon_radicals: bool = False):
34 | """
35 |         On-the-fly parser dataset converting SMILES strings to chython containers.
36 |         Note: invalid SMILES strings or packed structures raise exceptions on access.
37 |         Make sure the input is validated in advance.
38 |
39 | :param data: smiles dataset
40 | :param canonicalize: do standardization (slow, better to prepare data in advance and keep in kekule form)
41 | :param cache: dict-like object for caching processed data. caching disabled by default.
42 | :param dtype: expected type of smiles (reaction or molecule)
43 | :param unpack: return unpacked structure or chython pack
44 | :param ignore_stereo: Ignore stereo data.
45 | :param keep_implicit: keep given in smiles implicit hydrogen count, otherwise ignore on valence error.
46 | :param ignore_bad_isotopes: reset invalid isotope mark to non-isotopic.
47 | :param ignore_carbon_radicals: fill carbon radicals with hydrogen (X[C](X)X case).
48 | """
49 | self.data = data
50 | self.canonicalize = canonicalize
51 | self.cache = cache
52 | self.dtype = dtype
53 | self.unpack = unpack
54 | self.ignore_stereo = ignore_stereo
55 | self.ignore_bad_isotopes = ignore_bad_isotopes
56 | self.keep_implicit = keep_implicit
57 | self.ignore_carbon_radicals = ignore_carbon_radicals
58 |
59 | def __getitem__(self, item: int) -> Union[MoleculeContainer, ReactionContainer, bytes]:
60 | if self.cache is not None and item in self.cache:
61 | s = self.cache[item]
62 | if self.unpack:
63 | return self.dtype.unpack(s)
64 | return s
65 |
66 | s = smiles(self.data[item], ignore_stereo=self.ignore_stereo, ignore_bad_isotopes=self.ignore_bad_isotopes,
67 | keep_implicit=self.keep_implicit, ignore_carbon_radicals=self.ignore_carbon_radicals)
68 | if not isinstance(s, self.dtype):
69 | raise TypeError(f'invalid SMILES: {self.dtype} expected, but {type(s)} given')
70 | if self.canonicalize:
71 | s.canonicalize()
72 |
73 | if self.cache is not None:
74 | p = s.pack()
75 | self.cache[item] = p
76 | if self.unpack:
77 | return s
78 | return p
79 | if self.unpack:
80 | return s
81 | return s.pack()
82 |
83 | def __len__(self):
84 | return len(self.data)
85 |
86 | def size(self, dim):
87 | if dim == 0:
88 | return len(self)
89 | elif dim is None:
90 | return Size((len(self),))
91 | raise IndexError
92 |
93 |
94 | __all__ = ['SMILESDataset']
95 |
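A minimal usage sketch for the dataset above (assumes chython). With a cache, the first access stores the chython pack and later accesses unpack it instead of re-parsing the SMILES string:

    from chytorch.utils.data.smiles import SMILESDataset

    cache = {}                         # a plain dict; shelve or lmdb wrappers also fit
    ds = SMILESDataset(['CCO', 'CC(=O)O'], cache=cache)
    mol = ds[0]                        # parsed MoleculeContainer; cache[0] now holds its pack
    mol = ds[0]                        # served by unpacking cache[0]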
--------------------------------------------------------------------------------
/chytorch/utils/data/unpack.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright 2023, 2024 Ramil Nugmanov
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the “Software”), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is furnished
10 | # to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | from functools import cached_property, partial
24 | from json import loads as json_loads
25 | from pickle import loads
26 | from struct import Struct
27 | from torch import Tensor, tensor, float32, Size, frombuffer
28 | from torch.utils.data import Dataset
29 | from typing import List, Tuple
30 | from zlib import decompress
31 |
32 |
33 | class StructUnpack(Dataset):
34 | def __init__(self, data: List[bytes], format_spec: str, dtype=float32, shape: Tuple[int, ...] = None):
35 | """
36 |         Unpack python struct-packed records into 1d-tensors.
37 | Useful in case of highly compressed data.
38 |
39 | :param data: packed data
40 |         :param format_spec: python struct format for unpacking data
41 |             (e.g. '>bbl' - big-endian: two one-byte ints followed by one four-byte int)
42 | :param dtype: output tensor dtype
43 | :param shape: reshape unpacked 1-D tensor
44 | """
45 | self.data = data
46 | self.format_spec = format_spec
47 | self.dtype = dtype
48 | self.shape = shape
49 | self._struct = Struct(format_spec)
50 |
51 | def __getitem__(self, item: int) -> Tensor:
52 | x = tensor(self._struct.unpack(self.data[item]), dtype=self.dtype)
53 | if self.shape is not None:
54 | return x.reshape(self.shape)
55 | return x
56 |
57 | def __len__(self):
58 | return len(self.data)
59 |
60 | def size(self, dim):
61 | if dim == 0:
62 | return len(self)
63 | elif dim is None:
64 | return Size((len(self),))
65 | raise IndexError
66 |
67 |
68 | class TensorUnpack(Dataset):
69 | def __init__(self, data: List[bytes], dtype=float32, shape: Tuple[int, ...] = None):
70 | """
71 | Unpack raw tensor byte buffers to 1d-tensors.
72 |
73 | :param data: packed data
74 | :param dtype: dtype of buffer
75 | :param shape: reshape unpacked 1-D tensor
76 | """
77 | self.data = data
78 | self.dtype = dtype
79 | self.shape = shape
80 |
81 | def __getitem__(self, item: int) -> Tensor:
82 | x = frombuffer(self.data[item], dtype=self.dtype)
83 | if self.shape is not None:
84 | return x.reshape(self.shape)
85 | return x
86 |
87 | def __len__(self):
88 | return len(self.data)
89 |
90 | def size(self, dim):
91 | if dim == 0:
92 | return len(self)
93 | elif dim is None:
94 | return Size((len(self),))
95 | raise IndexError
96 |
97 |
98 | class PickleUnpack(Dataset):
99 | def __init__(self, data: List[bytes]):
100 | """
101 | Unpack python-pickled data.
102 |
103 | :param data: packed data
104 | """
105 | self.data = data
106 |
107 | def __getitem__(self, item: int):
108 | return loads(self.data[item])
109 |
110 | def __len__(self):
111 | return len(self.data)
112 |
113 | def size(self, dim):
114 | if dim == 0:
115 | return len(self)
116 | elif dim is None:
117 | return Size((len(self),))
118 | raise IndexError
119 |
120 |
121 | class JsonUnpack(Dataset):
122 | def __init__(self, data: List[str]):
123 | """
124 | Unpack Json data.
125 |
126 | :param data: json strings
127 | """
128 | self.data = data
129 |
130 | def __getitem__(self, item: int):
131 | return json_loads(self.data[item])
132 |
133 | def __len__(self):
134 | return len(self.data)
135 |
136 | def size(self, dim):
137 | if dim == 0:
138 | return len(self)
139 | elif dim is None:
140 | return Size((len(self),))
141 | raise IndexError
142 |
143 |
144 | class Decompress(Dataset):
145 | def __init__(self, data: List[bytes], method: str = 'zlib', zdict: bytes = None):
146 | """
147 | Decompress zipped data.
148 |
149 | :param data: compressed data
150 | :param method: zlib or zstd
151 | :param zdict: zstd decompression dictionary
152 | """
153 | assert method in ('zlib', 'zstd')
154 | self.data = data
155 | self.method = method
156 | self.zdict = zdict
157 |
158 | def __getitem__(self, item: int) -> bytes:
159 | return self.decompress(self.data[item])
160 |
161 | @cached_property
162 | def decompress(self):
163 | if self.method == 'zlib':
164 | return decompress
165 | # zstd
166 | from pyzstd import decompress as dc, ZstdDict
167 |
168 | if self.zdict is not None:
169 | return partial(dc, zstd_dict=ZstdDict(self.zdict))
170 | return dc
171 |
172 | def __len__(self):
173 | return len(self.data)
174 |
175 | def size(self, dim):
176 | if dim == 0:
177 | return len(self)
178 | elif dim is None:
179 | return Size((len(self),))
180 | raise IndexError
181 |
182 |
183 | class Decode(Dataset):
184 | def __init__(self, data: List[bytes], encoding: str = 'utf8'):
185 | """
186 | Bytes to string decoder dataset
187 |
188 | :param data: byte-coded strings
189 | :param encoding: string encoding
190 | """
191 | self.data = data
192 | self.encoding = encoding
193 |
194 | def __getitem__(self, item: int) -> str:
195 | return self.data[item].decode(encoding=self.encoding)
196 |
197 | def __len__(self):
198 | return len(self.data)
199 |
200 | def size(self, dim):
201 | if dim == 0:
202 | return len(self)
203 | elif dim is None:
204 | return Size((len(self),))
205 | raise IndexError
206 |
207 |
208 | __all__ = ['TensorUnpack', 'StructUnpack', 'PickleUnpack', 'JsonUnpack', 'Decompress', 'Decode']
209 |
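A minimal round-trip sketch for StructUnpack using the standard-library struct module; the same pattern applies to TensorUnpack with raw byte buffers:

    from struct import pack
    from torch import int32
    from chytorch.utils.data.unpack import StructUnpack

    data = [pack('>bbl', 1, 2, 300), pack('>bbl', 4, 5, 600)]
    ds = StructUnpack(data, '>bbl', dtype=int32)
    print(ds[0])                       # tensor([  1,   2, 300], dtype=torch.int32)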
--------------------------------------------------------------------------------
/chytorch/zoo/README.md:
--------------------------------------------------------------------------------
1 | Trained models
2 | --------------
3 |
4 | Namespace package for trained models collection.
5 |
--------------------------------------------------------------------------------
/examples/reference_start.py:
--------------------------------------------------------------------------------
1 | import click
2 | import pandas as pd
3 | import pytorch_lightning as pl
4 | import torch
5 | import torch.nn as nn
6 | from chython import smiles
7 | from chython.exceptions import InvalidAromaticRing
8 | from torch.optim import Adam
9 | from torch.utils.data import DataLoader, TensorDataset
10 | from typing import Optional, Union
11 |
12 | from chytorch.nn import MoleculeEncoder, Slicer
13 | from chytorch.utils.data import MoleculeDataset, chained_collate, collate_molecules
14 |
15 | torch.manual_seed(1)
16 |
17 | # check GPU
18 | device = "cuda" if torch.cuda.is_available() else "cpu"
19 | print(f"Using {device} device")
20 |
21 |
22 | class PandasData(pl.LightningDataModule):
23 | def __init__(
24 | self,
25 | csv: str,
26 | structure: str,
27 | property: str,
28 | dataset_type: str,
29 | prepared_df_path: str,
30 | batch_size: int = 32,
31 | ):
32 | super().__init__()
33 | self.train_x = None
34 | self.train_y = None
35 | self.test_x = None
36 | self.test_y = None
37 | self.validation_x = None
38 | self.validation_y = None
39 | self.prepared_df_path = prepared_df_path
40 | self.csv = csv
41 | self.structure = structure
42 | self.property = property
43 | self.dataset_type = dataset_type
44 | self.batch_size = batch_size
45 |
46 | @staticmethod
47 | def prepare_mol(mol_smi):
48 | try:
49 | mol = smiles(mol_smi)
50 | try:
51 | mol.kekule()
52 | except InvalidAromaticRing:
53 | mol = None
54 | except Exception:
55 | mol = None
56 | return mol
57 |
58 | def prepare_data(self):
59 | df = pd.read_csv(self.csv)
60 | df = df[[self.structure, self.property, self.dataset_type]]
61 | df[self.structure] = df[self.structure].apply(self.prepare_mol)
62 | df.dropna(inplace=True)
63 | df.to_pickle(self.prepared_df_path)
64 |
65 | def setup(self, stage: Optional[str] = None):
66 | df = pd.read_pickle(self.prepared_df_path)
67 | if stage == "fit" or stage is None:
68 |             df_train = df[df[self.dataset_type] == "train"]
69 | mols = df_train[self.structure].to_list()
70 | self.train_x = MoleculeDataset(mols)
71 | self.train_y = torch.Tensor(df_train[self.property].to_numpy())
72 |
73 | if stage == "validation" or stage is None:
74 |             df_validation = df[df[self.dataset_type] == "validation"]
75 | mols = df_validation[self.structure].to_list()
76 | self.validation_x = MoleculeDataset(mols)
77 | self.validation_y = torch.Tensor(df_validation[self.property].to_numpy())
78 |
79 | if stage == "test" or stage is None:
80 |             df_test = df[df[self.dataset_type] == "test"]
81 | mols = df_test[self.structure].to_list()
82 | self.test_x = MoleculeDataset(mols)
83 | self.test_y = torch.Tensor(df_test[self.property].to_numpy())
84 |
85 | def train_dataloader(self):
86 | return DataLoader(
87 | dataset=TensorDataset(self.train_x, self.train_y),
88 | collate_fn=chained_collate(collate_molecules, torch.stack),
89 | batch_size=self.batch_size,
90 | shuffle=True,
91 | )
92 |
93 |     def validation_dataloader(self):
94 |         return DataLoader(
95 |             dataset=TensorDataset(self.validation_x, self.validation_y),
96 |             collate_fn=chained_collate(collate_molecules, torch.stack),
97 |             batch_size=self.batch_size,
98 |             shuffle=False,  # validation data, served in a fixed order
99 |         )
100 |
101 | def test_dataloader(self):
102 | return DataLoader(
103 | dataset=TensorDataset(self.test_x, self.test_y),
104 | collate_fn=chained_collate(collate_molecules, torch.stack),
105 | batch_size=self.batch_size,
106 | )
107 |
108 |
109 | class Modeler:
110 | def __init__(
111 | self,
112 | loss_function,
113 | epochs: int,
114 | learning_rate: Union[float, int],
115 | model_path: Optional[str] = None,
116 | ):
117 | self.network = None
118 | self.optimizer = None
119 | self.loss_function = loss_function
120 | self.learning_rate = learning_rate
121 | self.epochs = epochs
122 | self.model_path = model_path
123 |
124 | def train_loop(self, dataset_loader: DataLoader):
125 |         size = len(dataset_loader.dataset)
126 |         self.network.train()  # validation_loop leaves the network in eval mode; switch back
127 |         for batch, (X, y) in enumerate(dataset_loader):
127 | # compute prediction and loss
128 | predictions = self.network(X)
129 | loss = self.loss_function(predictions.squeeze(-1), y)
130 | # backpropagation
131 | self.optimizer.zero_grad()
132 | loss.backward()
133 | self.optimizer.step()
134 | if batch % 10 == 0:
135 | loss_v, current = loss.item(), batch * len(X[0])
136 | print(f"loss: {loss_v:>3f} [{current:>5d}/{size:>5d}]")
137 |
138 | def validation_loop(self, dataset_loader: DataLoader):
139 | size = len(dataset_loader.dataset)
140 | num_batches = len(dataset_loader)
141 | self.network.eval()
142 | test_loss, correct = 0, 0
143 | with torch.no_grad():
144 | for X, y in dataset_loader:
145 | predictions = self.network(X)
146 | test_loss += self.loss_function(predictions.squeeze(-1), y).item()
147 |                 correct += ((predictions.squeeze(-1) > 0.5) == y).type(torch.float).sum().item()  # binary accuracy at 0.5
148 | test_loss = test_loss / num_batches
149 | correct = correct / size
150 | print(
151 | f"Validation Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
152 | )
153 |
154 | def save(self):
155 | torch.save(self.network.state_dict(), self.model_path)
156 |
157 | def fit(self, dataset):
158 | """
159 | Run model training
160 | """
161 | self.network = nn.Sequential(
162 | MoleculeEncoder(),
163 | Slicer(slice(None), 0),
164 | nn.Linear(in_features=1024, out_features=512),
165 | nn.ReLU(),
166 | nn.Dropout(0.1),
167 | nn.Linear(in_features=512, out_features=1),
168 | nn.Sigmoid(),
169 | )
170 | self.optimizer = Adam(self.network.parameters(), lr=self.learning_rate)
171 | for epoch in range(self.epochs):
172 | print(f"Epoch {epoch + 1}\n------")
173 | self.train_loop(dataset.train_dataloader())
174 | self.validation_loop(dataset.validation_dataloader())
175 | if self.model_path:
176 | self.save()
177 |
178 |
179 | @click.command()
180 | @click.option(
181 | "-d", "--path_to_csv", type=click.Path(), help="Path to csv file with data."
182 | )
183 | @click.option(
184 | "-i",
185 | "--path_to_interm_dataset",
186 | type=click.Path(),
187 | help="Path to pickle with intermediate data.",
188 | )
189 | @click.option("-m", "--path_to_model", type=click.Path(), help="Path to model.pt.")
190 | def train(path_to_csv, path_to_interm_dataset, path_to_model):
191 | dataset = PandasData(
192 | csv=path_to_csv,
193 | structure="std_smiles",
194 | property="activity",
195 | dataset_type="dataset",
196 | prepared_df_path=path_to_interm_dataset,
197 | batch_size=10,
198 | )
199 | dataset.prepare_data()
200 | dataset.setup()
201 | modeler = Modeler(
202 | loss_function=nn.BCELoss(),
203 | epochs=3,
204 | learning_rate=2e-5,
205 | model_path=path_to_model,
206 | )
207 | modeler.fit(dataset)
208 |
209 |
210 | if __name__ == "__main__":
211 | train()
212 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = 'chytorch'
3 | version = '1.66'
4 | description = 'Library for modeling molecules and reactions in torch way'
5 | authors = ['Ramil Nugmanov ']
6 | license = 'MIT'
7 | readme = 'README.md'
8 | homepage = 'https://github.com/chython/chytorch'
9 | classifiers=[
10 | 'Environment :: Plugins',
11 | 'Intended Audience :: Science/Research',
12 | 'License :: OSI Approved :: MIT License',
13 | 'Operating System :: OS Independent',
14 | 'Programming Language :: Python',
15 | 'Programming Language :: Python :: 3 :: Only',
16 | 'Programming Language :: Python :: 3.8',
17 | 'Topic :: Scientific/Engineering',
18 | 'Topic :: Scientific/Engineering :: Chemistry',
19 | 'Topic :: Scientific/Engineering :: Information Analysis',
20 | 'Topic :: Software Development',
21 | 'Topic :: Software Development :: Libraries',
22 | 'Topic :: Software Development :: Libraries :: Python Modules'
23 | ]
24 |
25 | include = [
26 | {path = 'chytorch/utils/data/molecule/*.pyd', format = 'wheel'},
27 | {path = 'chytorch/utils/data/molecule/*.so', format = 'wheel'}
28 | ]
29 |
30 | [tool.poetry.dependencies]
31 | python = '>=3.8,<3.12'
32 | torchtyping = '^0.1.4'
33 | chython = '^1.70'
34 | scipy = '^1.10'
35 | torch = '>=1.8'
36 | lmdb = {version='^1.4.1', optional = true}
37 | psycopg2-binary = {version='^2.9', optional = true}
38 | rdkit = {version = '^2023.9.1', optional = true}
39 | pyzstd = {version = '^0.15.9', optional = true}
40 |
41 | [tool.poetry.extras]
42 | lmdb = ['lmdb']
43 | postgres = ['psycopg2-binary']
44 | rdkit = ['rdkit']
45 | zstd = ['pyzstd']
46 |
47 | [build-system]
48 | requires = ['poetry-core', 'setuptools', 'cython>=3.0.5', 'numpy>=1.23.3']
49 | build-backend = 'poetry.core.masonry.api'
50 |
51 | [tool.poetry.build]
52 | script = 'build.py'
53 | generate-setup-file = false
54 |
--------------------------------------------------------------------------------