├── .github
│   └── workflows
│       └── pythonpackage.yml
├── .gitignore
├── LICENSE
├── README.md
├── examples
│   ├── __init__.py
│   ├── augmented_cycle_gan.py
│   ├── cartpole.py
│   ├── cifar_contrastive.py
│   ├── cifar_multiscale.py
│   ├── cifar_off_ebm.py
│   ├── cifar_vae.py
│   ├── coinrun.py
│   ├── conditional_cifar_classifier.py
│   ├── conditional_mnist_ebm.py
│   ├── conditional_mnist_lsd.py
│   ├── conditional_mnist_neural_conditioner.py
│   ├── conditional_mnist_score.py
│   ├── conditional_mnist_score_classifier.py
│   ├── cycle_gan.py
│   ├── flowers_consistent_gan.py
│   ├── mnist_ddim.py
│   ├── mnist_neural_conditioner.py
│   ├── mnist_off_ebm.py
│   ├── mnist_vae.py
│   ├── pix2pix.py
│   ├── set_mnist_ebm.py
│   ├── set_mnist_gan.py
│   └── set_yeast_ebm.py
├── setup.py
└── torchsupport
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── asap_xml.py
    │   ├── chem
    │   │   ├── __init__.py
    │   │   └── qm9.py
    │   ├── collate.py
    │   ├── episodic.py
    │   ├── io.py
    │   ├── match.py
    │   ├── namedtuple.py
    │   ├── roi_image.py
    │   ├── slides.py
    │   ├── structured.py
    │   ├── tensor_provider.py
    │   └── transforms.py
    ├── deprecated
    │   └── rl
    │       ├── __init__.py
    │       ├── agent.py
    │       ├── bdpi.py
    │       ├── data.py
    │       ├── dqn.py
    │       ├── environment.py
    │       ├── imitation.py
    │       ├── memory.py
    │       ├── off_policy.py
    │       ├── sampler.py
    │       ├── task.py
    │       └── trajectory.py
    ├── distributions
    │   ├── __init__.py
    │   ├── kl_divergence.py
    │   ├── mixture.py
    │   ├── mixture_of_logits.py
    │   ├── modifiers.py
    │   ├── standard.py
    │   ├── structured.py
    │   ├── vae_distribution.py
    │   └── von_mises.py
    ├── experimental
    │   ├── __init__.py
    │   ├── apps
    │   │   ├── __init__.py
    │   │   └── fewshot.py
    │   ├── autonomous.py
    │   ├── enas
    │   │   ├── __init__.py
    │   │   └── controller.py
    │   ├── gan.py
    │   ├── losses
    │   │   ├── __init__.py
    │   │   ├── instance_segmentation.py
    │   │   ├── segmentation.py
    │   │   └── vae.py
    │   ├── modules
    │   │   ├── __init__.py
    │   │   ├── compact.py
    │   │   ├── fast_gaussian.py
    │   │   └── structured
    │   │       ├── __init__.py
    │   │       ├── graphnn.py
    │   │       ├── nodegraph.py
    │   │       ├── pooling.py
    │   │       └── spectral.py
    │   ├── rl
    │   │   ├── __init__.py
    │   │   ├── policy_gradient.py
    │   │   └── trajectory.py
    │   └── trainables
    │       └── unsupervised_segmentation
    │           ├── main.py
    │           ├── model.py
    │           ├── parser.py
    │           └── training.py
    ├── flex
    │   ├── __init__.py
    │   ├── checkpointing
    │   │   ├── __init__.py
    │   │   ├── checkpoint.py
    │   │   ├── savable.py
    │   │   └── savable_extensions.py
    │   ├── context
    │   │   ├── __init__.py
    │   │   ├── context.py
    │   │   └── context_module.py
    │   ├── data_distributions
    │   │   ├── __init__.py
    │   │   ├── data_distribution.py
    │   │   └── metadata_distribution.py
    │   ├── examples
    │   │   ├── __init__.py
    │   │   ├── cifar_cnce.py
    │   │   ├── cifar_conditional_maximum_likelihood.py
    │   │   ├── cifar_drl.py
    │   │   ├── cifar_rsm.py
    │   │   ├── cifar_supervised.py
    │   │   ├── cifar_tdre.py
    │   │   ├── cifar_tnce.py
    │   │   └── cifar_tnce_unet.py
    │   ├── log
    │   │   ├── __init__.py
    │   │   ├── log_types.py
    │   │   ├── logger.py
    │   │   └── tensorboard_logger.py
    │   ├── step
    │   │   ├── __init__.py
    │   │   ├── loop.py
    │   │   ├── step.py
    │   │   └── training_loop.py
    │   ├── tasks
    │   │   ├── __init__.py
    │   │   ├── energy
    │   │   │   ├── __init__.py
    │   │   │   ├── density_ratio.py
    │   │   │   ├── diffusion.py
    │   │   │   ├── diffusion_recovery_likelihood.py
    │   │   │   ├── relaxed_score_matching.py
    │   │   │   └── score_matching.py
    │   │   ├── gan
    │   │   │   ├── __init__.py
    │   │   │   ├── losses.py
    │   │   │   └── tasks.py
    │   │   ├── likelihood
    │   │   │   ├── __init__.py
    │   │   │   └── maximum_likelihood.py
    │   │   ├── regularization
    │   │   │   └── __init__.py
    │   │   ├── task.py
    │   │   └── utils.py
    │   ├── test.py
    │   ├── training
    │   │   ├── __init__.py
    │   │   ├── conditional_maximum_likelihood.py
    │   │   ├── density_ratio.py
    │   │   ├── score_matching.py
    │   │   └── supervised.py
    │   ├── update
    │   │   ├── __init__.py
    │   │   └── update.py
    │   └── utils.py
    ├── interacting
    │   ├── __init__.py
    │   ├── awac.py
    │   ├── awr.py
    │   ├── bdpi.py
    │   ├── buffer.py
    │   ├── collector_task.py
    │   ├── control.py
    │   ├── crr.py
    │   ├── data_collector.py
    │   ├── distributor_task.py
    │   ├── energies
    │   │   ├── __init__.py
    │   │   └── energy.py
    │   ├── environments
    │   │   ├── __init__.py
    │   │   ├── cartpole.py
    │   │   ├── coinrun.py
    │   │   └── environment.py
    │   ├── interacting_training.py
    │   ├── off_ebm.py
    │   ├── off_energy_training.py
    │   ├── off_policy_training.py
    │   ├── policies
    │   │   ├── __init__.py
    │   │   ├── basic.py
    │   │   ├── mcts.py
    │   │   └── policy.py
    │   ├── shared_data.py
    │   └── stats.py
    ├── modules
    │   ├── __init__.py
    │   ├── activations
    │   │   ├── __init__.py
    │   │   └── geometry.py
    │   ├── attention.py
    │   ├── backbones
    │   │   ├── __init__.py
    │   │   ├── diffusion
    │   │   │   ├── __init__.py
    │   │   │   └── unet.py
    │   │   ├── gan
    │   │   │   ├── __init__.py
    │   │   │   ├── biggan.py
    │   │   │   ├── dcgan.py
    │   │   │   ├── discriminator_components.py
    │   │   │   ├── independent_dcgan.py
    │   │   │   ├── independent_stylegan2.py
    │   │   │   ├── resgan.py
    │   │   │   └── stylegan2.py
    │   │   ├── multiscale
    │   │   │   ├── __init__.py
    │   │   │   └── multiscale_resnet.py
    │   │   └── vae
    │   │       ├── __init__.py
    │   │       └── deep_vae.py
    │   ├── basic.py
    │   ├── combination.py
    │   ├── dynamic.py
    │   ├── dynet.py
    │   ├── generative.py
    │   ├── geometric_vector_perceptron.py
    │   ├── gradient.py
    │   ├── invertible.py
    │   ├── losses
    │   │   ├── __init__.py
    │   │   ├── clustering.py
    │   │   ├── generative.py
    │   │   └── vae.py
    │   ├── masked.py
    │   ├── multiscale.py
    │   ├── normalization.py
    │   ├── polynomial.py
    │   ├── recurrent.py
    │   ├── reduction.py
    │   ├── refine.py
    │   ├── residual.py
    │   ├── rezero.py
    │   ├── routing.py
    │   ├── separable.py
    │   ├── unet.py
    │   ├── vector_quantisation.py
    │   ├── weights.py
    │   └── zoom.py
    ├── networks
    │   ├── __init__.py
    │   ├── fewshot.py
    │   └── unet.py
    ├── nn
    │   └── __init__.py
    ├── ops
    │   ├── __init__.py
    │   └── shape.py
    ├── optim
    │   ├── __init__.py
    │   └── radam.py
    ├── reporting
    │   ├── __init__.py
    │   └── reporting.py
    ├── structured
    │   ├── __init__.py
    │   ├── chunkable.py
    │   ├── modules
    │   │   ├── __init__.py
    │   │   ├── attention.py
    │   │   ├── basic.py
    │   │   ├── materialized_transformer.py
    │   │   ├── parallel.py
    │   │   ├── pooling.py
    │   │   ├── rezero_transformer.py
    │   │   ├── sequence_transformer.py
    │   │   ├── spectral.py
    │   │   └── transformer.py
    │   ├── packedtensor.py
    │   ├── scatter.py
    │   └── structures
    │       ├── __init__.py
    │       ├── basic.py
    │       └── connection.py
    ├── test
    │   ├── __init__.py
    │   ├── test_argparse.py
    │   ├── test_gradient.py
    │   ├── test_mlp.py
    │   ├── test_onehot.py
    │   └── test_scatter.py
    ├── training
    │   ├── __init__.py
    │   ├── clustering.py
    │   ├── consistent_gan.py
    │   ├── contrastive.py
    │   ├── contrastive_multiscale.py
    │   ├── denoising_diffusion.py
    │   ├── distributed.py
    │   ├── energy.py
    │   ├── energy_sampler.py
    │   ├── energy_supervised.py
    │   ├── few_shot_gan.py
    │   ├── gan.py
    │   ├── hybrid_gan_training.py
    │   ├── log
    │   │   ├── __init__.py
    │   │   ├── log_types.py
    │   │   └── tensorboard_logger.py
    │   ├── lsd.py
    │   ├── multiscale_training.py
    │   ├── multistep_training.py
    │   ├── neural_conditioner.py
    │   ├── neural_processes.py
    │   ├── samplers.py
    │   ├── score_supervised.py
    │   ├── state.py
    │   ├── tilted_supervised_training.py
    │   ├── training.py
    │   ├── translation.py
    │   ├── vae.py
    │   ├── vera.py
    │   └── vqgan.py
    └── utils
        ├── __init__.py
        ├── argparse.py
        └── memory.py
--------------------------------------------------------------------------------
/.github/workflows/pythonpackage.yml:
--------------------------------------------------------------------------------
name: Python package

on:
  pull_request:
    branches:
      - master

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      max-parallel: 4
      matrix:
        python-version: [3.6, 3.7]

    steps:
    - uses: actions/checkout@v1
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v1
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install --editable .
    - name: Test with pytest
      run: |
        pip install pytest
        pytest
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.egg-info
*.pyc
build
dist
scratch
.vscode
XDG_CACHE_HOME
*.pyi
.pytest_cache
.idea
*.torch
*.ipynb
__pycache__
*.tfevents.*
examples/MNIST
examples/raw
examples/processed
examples/flowers
examples/cifar-10-batches-py
examples/off-energy
off-energy
*.tar.gz
awr-test*
docs
examples/*
*.torch.old
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright 2018 Michael Jendrusch

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Torchsupport

This package contains helpers for advanced usage of PyTorch.
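For example, the `MLP` helper replaces hand-rolled fully connected stacks. A minimal sketch, mirroring its use in `examples/cartpole.py` (tensor shapes here are illustrative):

```python
import torch
from torchsupport.modules.basic import MLP

# Four input features, two outputs, three hidden layers of width 128.
net = MLP(4, 2, hidden_size=128, depth=3, batch_norm=False)
logits = net(torch.randn(8, 4))  # a batch of 8 observations
print(logits.shape)              # torch.Size([8, 2])
```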
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/examples/__init__.py
--------------------------------------------------------------------------------
/examples/cartpole.py:
--------------------------------------------------------------------------------
import sys
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as func

from torchsupport.interacting.awr import AWRTraining
from torchsupport.interacting.awac import AWACTraining
from torchsupport.interacting.bdpi import BDPITraining
from torchsupport.interacting.shared_data import SharedModule
from torchsupport.interacting.policies.basic import RandomPolicy, CategoricalGreedyPolicy, CategoricalPolicy, EpsilonGreedyPolicy
from torchsupport.interacting.environments.cartpole import CartPole

from torchsupport.modules.basic import MLP

class Policy(MLP):
  data_type = namedtuple("Data", ["logits", "outputs"])
  def __init__(self, hidden_size=128, depth=3, value=False):
    super().__init__(
      4, 2,
      hidden_size=hidden_size,
      depth=depth,
      batch_norm=False
    )
    self.value = value

  def schema(self):
    return self.data_type(
      logits=torch.zeros(2, dtype=torch.float),
      outputs=None
    )

  def forward(self, inputs, **kwargs):
    result = super().forward(inputs)
    if self.value:
      result = 1 * result.sigmoid()
    return result

mode = sys.argv[1]
index = sys.argv[2]
training = ...
if mode == "bdpi":
  policy = Policy()
  value = MLP(4, 2, batch_norm=False, depth=3)
  agent = CategoricalPolicy(SharedModule(policy))
  env = CartPole()

  training = BDPITraining(
    policy, value, agent, env,
    network_name=f"awr-test/bdpi-cartpole-{index}",
    discount=0.99,
    clones=4,
    critic_updates=4,
    gradient_updates=20,
    batch_size=1024,
    buffer_size=100_000,
    device="cuda:0",
    verbose=True
  )

elif mode == "awr":
  policy = Policy()
  value = MLP(4, 1, batch_norm=False, depth=3)
  agent = CategoricalPolicy(SharedModule(policy))
  env = CartPole()
  training = AWRTraining(
    policy, value, agent, env,
    network_name=f"awr-test/awr-cartpole-{index}",
    verbose=True, beta=0.05,
    auxiliary_steps=1,
    discount=0.990,
    clip=20,
    device="cuda:0",
    batch_size=1024,
    policy_steps=5
  )

elif mode == "awac":
  policy = Policy(hidden_size=64, depth=2)
  value = Policy(hidden_size=64, depth=2)
  agent = CategoricalPolicy(SharedModule(policy))
  env = CartPole(scale=False)
  training = AWACTraining(
    policy, value, agent, env,
    network_name=f"awr-test/awac-cartpole-{index}",
    verbose=True, beta=0.05,
    auxiliary_steps=5,
    discount=0.99,
    clip=20,
    device="cuda:0",
    batch_size=1024,
    policy_steps=5
  )

training.train()
--------------------------------------------------------------------------------
/examples/cifar_contrastive.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as func
from torch.utils.data import Dataset

from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose, RandomAffine, ColorJitter

from torchsupport.training.contrastive import SimSiamTraining
from torchsupport.modules import MLP
from torchsupport.modules.rezero import ReZero

class ContrastiveDataset(Dataset):
  def __init__(self, data, transform, variants=2):
    self.data = data
    self.transform = transform
    self.variants = variants

  def __getitem__(self, index):
    data, _ = self.data[index]
    variants = [
      self.transform(data)
      for idx in range(self.variants)
    ]
    return variants

  def __len__(self):
    return len(self.data)

class ResBlock(nn.Module):
  def __init__(self, in_size, out_size, kernel_size, depth=1):
    super().__init__()
    self.project_in = nn.Conv2d(in_size, in_size // 4, 1, bias=False)
    self.project_out = nn.Conv2d(in_size // 4, out_size, 1, bias=False)
    self.blocks = nn.ModuleList([
      nn.Conv2d(in_size // 4, in_size // 4, kernel_size, padding=kernel_size // 2)
      for idx in range(depth)
    ])
    self.zero = ReZero(out_size, initial_value=0.1)

  def forward(self, inputs):
    out = self.project_in(inputs)
    for block in self.blocks:
      out = func.gelu(block(out))
    return self.zero(inputs, self.project_out(out))

class SimpleResNet(nn.Module):
  def __init__(self, features=128, depth=4, level_repeat=2, base=32):
    super().__init__()
    self.project = nn.Conv2d(3, base, 1)
    self.blocks = nn.ModuleList([
      ResBlock(base, base, 3, depth=1)
      for idx in range(depth * level_repeat)
    ])
    self.last = MLP(base, features, features)
    self.level_repeat = level_repeat

  def forward(self, inputs):
    out = self.project(inputs)
    for idx, block in enumerate(self.blocks):
      out = block(out)
      if (idx + 1) % self.level_repeat == 0:
        out = func.avg_pool2d(out, 2)
    out = func.adaptive_avg_pool2d(out, 1).view(out.size(0), -1)
    out = self.last(out)
    return out

if __name__ == "__main__":
  cifar = CIFAR10("examples/", download=True)
  data = ContrastiveDataset(cifar, Compose([
    ColorJitter(1.0, 1.0, 1.0, 0.5),
    RandomAffine(60, (0.5, 0.5), (0.5, 2.0), 60),
    ToTensor()
  ]), variants=2)

  base = 16
  features = 64
  net = SimpleResNet(features=features, base=base, level_repeat=4)
  predictor = MLP(features, features, hidden_size=32)

  training = SimSiamTraining(
    net, predictor, data,
    network_name="cifar-contrastive/siam-5",
    device="cuda:0",
    batch_size=32,
    max_epochs=1000,
    verbose=True
  ).load()

  training.train()
--------------------------------------------------------------------------------
/examples/cifar_multiscale.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as func
from torch.distributions import Normal
from torch.utils.data import Dataset

from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose, RandomAffine, ColorJitter

from torchsupport.training.multiscale_training import MultiscaleClassifierTraining, MultiscaleNet
from torchsupport.modules import MLP
from torchsupport.modules.rezero import ReZero
from torchsupport.data.namedtuple import NamedTuple

class CIFARMultiscaleDataset(Dataset):
  def __init__(self, data, k=4, separate=False):
    self.data = data
    self.k = k
    self.separate = separate
    if self.separate:
      self.k = 1

  def __getitem__(self, index):
    data, label = self.data[index]
    low = func.interpolate(data[None], scale_factor=1 / 4, mode="bilinear")
    high = torch.cat(torch.cat(data[None].chunk(4, dim=3), dim=0)[None].chunk(4, dim=3), dim=0)
    perm = torch.randperm(4 * 4)
    x = (perm // 4)[:self.k]
    y = (perm % 4)[:self.k]
    high = high[x, y]
    mask = (x, y)
    masks = [None, mask]
    inputs = [high, low]

    if self.separate:
      return list(zip(inputs, [label, label]))
    return (inputs, masks), (label, label)

  def __len__(self):
    return len(self.data)

class ScaleBlock(nn.Module):
  def __init__(self):
    super().__init__()
    self.spatial = MLP(8 * 8 * 3, 4 * 4 * 32, batch_norm=False)
    self.prior = MLP(4 * 4 * 32, 4 * 4 * 32, batch_norm=False)
    self.posterior = nn.Linear(4 * 4 * (32 + 16), 16)
    self.policy = nn.Linear(4 * 4 * 32, 4 * 4 * 1)
    self.task = nn.Linear(16, 10)

  def forward(self, inputs, mask=None, sample=None):
    shape = inputs.shape
    ind = torch.arange(shape[0], dtype=torch.long, device=inputs.device)[:, None]
    out = inputs.view(-1, 8 * 8 * 3)
    out = self.spatial(out)
    policy = self.policy(out)
    policy = policy.view(-1, 4, 4)
    mu, logvar = self.prior(out).view(-1, 4 * 4, 32).chunk(2, dim=-1)
    prior = Normal(mu, 1.0)
    prior_sample = 1.0 * mu  # prior.sample()
    prior_sample = prior_sample.view(-1, 4, 4, 16)
    if mask:
      prior_sample[ind, mask[0], mask[1]] = sample
      res = mu.view(-1, 4, 4, 16)[ind, mask[0], mask[1]]
      logvar = logvar.view(-1, 4, 4, 16)[ind, mask[0], mask[1]]
      prior = Normal(res, 1.0)
      policy = policy[ind, mask[0], mask[1]]
    else:
      policy = None
    out = out.view(-1, 4, 4, 32)
    out = torch.cat((out, prior_sample), dim=-1)
    sample = self.posterior(out.view(-1, 4 * 4 * (32 + 16)))
    task = self.task(sample)
    return task, NamedTuple(prior=prior, posterior=sample, policy=policy)

if __name__ == "__main__":
  cifar = CIFAR10("examples/", transform=ToTensor(), download=True)
  path_data = CIFARMultiscaleDataset(cifar, k=4)
  separate_data = CIFARMultiscaleDataset(cifar, separate=True)

  net = MultiscaleNet([
    ScaleBlock(),
    ScaleBlock()
  ])

  training = MultiscaleClassifierTraining(
    net, separate_data,
    stack_data=path_data,
    path_data=path_data,
    network_name="cifar-multiscale/17",
    device="cuda:0",
    batch_size=32,
    max_epochs=1000,
    verbose=True,
    n_path=10
  )

  training.train()
--------------------------------------------------------------------------------
/examples/cifar_off_ebm.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as func
from torch.nn.utils import spectral_norm
from torch.utils.data import Dataset

from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor

from torchsupport.modules.basic import MLP
from torchsupport.modules.residual import ResNetBlock2d
from torchsupport.modules.normalization import FilterResponseNorm, NotNorm, AdaNorm, SemiNorm
from torchsupport.training.samplers import Langevin
from torchsupport.interacting.off_ebm import OffEBMTraining
from torchsupport.interacting.energies.energy import Energy
from torchsupport.interacting.shared_data import SharedModule

def normalize(image):
  return (image - image.min()) / (image.max() - image.min())

class EnergyDataset(Dataset):
  def __init__(self, data):
    self.data = data

  def __getitem__(self, index):
    data, label_index = self.data[index]
    # data = data + 0.05 * torch.rand_like(data)
    label = torch.zeros(10)
    label[label_index] = 1
    return (data,)  # , label

  def __len__(self):
    return len(self.data)

class Convolutional(nn.Module):
  def __init__(self, depth=4):
    super(Convolutional, self).__init__()
    self.preprocess = spectral_norm(nn.Conv2d(3, 32, 1))
    self.blocks = nn.ModuleList([
      spectral_norm(nn.Conv2d(32, 32, 3, padding=1))
      for idx in range(depth)
    ])
    self.postprocess = spectral_norm(nn.Linear(32, 1))

  def forward(self, inputs):
    out = self.preprocess(inputs)
    for block in self.blocks:
      out = func.relu(out + block(out))
      out = func.avg_pool2d(out, 2)
    out = func.adaptive_avg_pool2d(out, 1).view(-1, 32)
    out = self.postprocess(out)
    return out

class CIFAR10Energy(Energy):
  def prepare(self, batch_size):
    return self.sample_type(
      data=torch.rand(batch_size, 3, 32, 32),
      args=None
    )

class CIFAR10EnergyTraining(OffEBMTraining):
  def each_generate(self, batch):
    data = batch.final_state
    samples = [torch.clamp(sample, 0, 1) for sample in data[0:10]]
    samples = torch.cat(samples, dim=-1)
    self.writer.add_image("samples", samples, self.step_id)

if __name__ == "__main__":
  import torch.multiprocessing as mp
  mp.set_start_method("spawn")

  cifar = CIFAR10("examples/", download=True, transform=ToTensor())
  data = EnergyDataset(cifar)

  score = Convolutional(depth=4)
  energy = CIFAR10Energy(SharedModule(score, dynamic=True), keep_rate=0.95)
  integrator = Langevin(rate=50, steps=20, max_norm=None, clamp=(0, 1))

  training = CIFAR10EnergyTraining(
    score, energy, data,
    network_name="off-energy/cifar10-off-energy-2",
    device="cuda:0",
    integrator=integrator,
    off_energy_weight=5,
    batch_size=64,
    off_energy_decay=1,
    decay=1.0,
    n_workers=8,
    double=True,
    buffer_size=10_000,
    max_steps=int(1e6),
    report_interval=10,
    verbose=True
  )

  training.train()
--------------------------------------------------------------------------------
/examples/coinrun.py:
--------------------------------------------------------------------------------
import sys
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as func

from torchsupport.interacting.awr import AWRTraining
from torchsupport.interacting.awac import AWACTraining
from torchsupport.interacting.bdpi import BDPITraining
from torchsupport.interacting.shared_data import SharedModule
from torchsupport.interacting.policies.basic import RandomPolicy, CategoricalGreedyPolicy, CategoricalPolicy, EpsilonGreedyPolicy
from torchsupport.interacting.environments.coinrun import CoinRun

from torchsupport.modules.basic import MLP

class Policy(nn.Module):
  data_type = namedtuple("Data", ["logits", "outputs"])
  def __init__(self, in_size=3, out_size=15):
    super().__init__()
    self.blocks = nn.Sequential(
      nn.Conv2d(in_size, 32, 3),
      nn.ReLU(),
      nn.InstanceNorm2d(32),
      nn.MaxPool2d(2),
      nn.Conv2d(32, 32, 3),
      nn.InstanceNorm2d(32),
      nn.ReLU(),
      nn.MaxPool2d(2),
      nn.Conv2d(32, 32, 3),
      nn.InstanceNorm2d(32),
      nn.ReLU(),
      nn.MaxPool2d(2),
      nn.Conv2d(32, 32, 3),
      nn.InstanceNorm2d(32),
      nn.ReLU()
    )
    self.postprocess = nn.Linear(32, out_size)

  def schema(self):
    return self.data_type(
      logits=torch.zeros(15, dtype=torch.float),
      outputs=None
    )

  def forward(self, inputs, **kwargs):
    result = self.blocks(inputs)
    result = func.adaptive_avg_pool2d(result, 1).view(result.size(0), -1)
    result = self.postprocess(result)
    return result

mode = sys.argv[1]
index = sys.argv[2]
training = ...
if mode == "bdpi":
  policy = Policy()
  value = MLP(4, 2, batch_norm=False, depth=3)
  agent = CategoricalPolicy(SharedModule(policy))
  env = CoinRun()

  training = BDPITraining(
    policy, value, agent, env,
    network_name=f"awr-test/bdpi-coinrun-{index}",
    discount=0.99,
    clones=4,
    critic_updates=4,
    gradient_updates=20,
    batch_size=128,
    buffer_size=100_000,
    device="cuda:0",
    verbose=True
  )

elif mode == "awr":
  policy = Policy(in_size=3)
  value = Policy(in_size=3, out_size=1)
  agent = CategoricalPolicy(SharedModule(policy))
  env = CoinRun(history=1)
  training = AWRTraining(
    policy, value, agent, env,
    network_name=f"awr-test/awr-coinrun-{index}",
    verbose=True, beta=0.05,
    auxiliary_steps=1,
    discount=0.990,
    clip=20,
    n_workers=8,
    device="cuda:0",
    batch_size=128,
    buffer_size=50_000,
    policy_steps=1
  )

elif mode == "awac":
  policy = Policy(in_size=3)
  value = Policy(in_size=3)
  agent = CategoricalPolicy(SharedModule(policy))
  env = CoinRun(history=1)
  training = AWACTraining(
    policy, value, agent, env,
    network_name=f"awr-test/awac-coinrun-{index}",
    verbose=True, beta=1.0,
    auxiliary_steps=5,
    discount=0.990,
    clip=20,
    n_workers=8,
    device="cuda:0",
    batch_size=128,
    buffer_size=100_000,
    policy_steps=5
  )

training.train()
--------------------------------------------------------------------------------
/examples/mnist_ddim.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as func
from torch.nn.utils import spectral_norm
from torch.utils.data import Dataset

from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import ToTensor

from torchsupport.modules.basic import MLP
from torchsupport.training.samplers import Langevin
from torchsupport.modules.unet import UNetBackbone, IgnoreArgs
from torchsupport.training.denoising_diffusion import DenoisingDiffusionTraining
from torchsupport.modules.attention import NonLocal

def normalize(image):
  return (image - image.min()) / (image.max() - image.min())

class EnergyDataset(Dataset):
  def __init__(self, data):
    self.data = data

  def __getitem__(self, index):
    data, label_index = self.data[index]
    data = data + (torch.rand_like(data) - 0.5) / 256
    data = data.clamp(0.001, 0.999)
    data = 2 * data - 1
    label = torch.zeros(10)
    label[label_index] = 1
    return (data,)

  def __len__(self):
    return len(self.data)

class Denoiser(nn.Module):
  def __init__(self):
    super().__init__()
    self.input = MLP(
      28 * 28 + 100, 28 * 28,
      hidden_size=128, depth=3,
      batch_norm=False
    )

  def time_embedding(self, time):
    time = time.float()[:, None]
    return torch.cat([
      (time / (idx + 1)).sin()
      for idx in range(100)
    ], dim=1)

  def forward(self, inputs, time):
    inputs = inputs.view(-1, 28 * 28)
    time = self.time_embedding(time)
    inputs = torch.cat((inputs, time), dim=1)
    out = self.input(inputs)
    return out.view(-1, 1, 28, 28)

class UDenoiser(nn.Module):
  def __init__(self, in_size=3):
    super().__init__()
    self.project = nn.Conv2d(in_size, 64, 7, padding=3)
    self.bb = UNetBackbone(
      size_factors=[1, 1, 2, 2], activation=swish,
      base_size=64, kernel_size=5, depth=2,
      cond_size=100, hidden_size=64,
      norm=nn.InstanceNorm2d,
      hole=IgnoreArgs(NonLocal(2 * 64))
    )
    self.predict_bn = nn.InstanceNorm2d(64 * 2)
    self.predict = nn.Conv2d(64, in_size, 1)

  def time_embedding(self, time):
    time = time.float()[:, None]
    return torch.cat([
      (time / (1000 ** (idx / 100))).sin()
      for idx in range(50)
    ] + [
      (time / (1000 ** (idx / 100))).cos()
      for idx in range(50)
    ], dim=1)

  def forward(self, inputs, time):
    time = self.time_embedding(time)
    out = self.project(inputs)
    out = self.bb(out, time)
    out = self.predict_bn(out)
    out = self.predict(out)
    return out

def swish(x):
  return x * x.sigmoid()

if __name__ == "__main__":
  cifar = CIFAR10("examples/", download=False, transform=ToTensor())
  data = EnergyDataset(cifar)

  denoiser = UDenoiser()

  training = DenoisingDiffusionTraining(
    denoiser, data,
    network_name="mnist-ddim/CIFAR-1",
    timesteps=1000,
    skipsteps=1,
    optimizer_kwargs=dict(lr=2e-4),
    device="cuda:0",
    batch_size=2,
    max_epochs=1000,
    report_interval=1000,
    checkpoint_interval=5000,
    verbose=True
  )

  training.train()
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools

with open("README.md", "r") as fh:
  long_description = fh.read()

setuptools.setup(
  name="torchsupport",
  version="0.0.1",
  author="Michael Jendrusch",
  author_email="jendrusch@stud.uni-heidelberg.de",
  description="Support for advanced pytorch usage.",
  long_description=long_description,
  long_description_content_type="text/markdown",
  url="https://github.com/mjendrusch/torchsupport/",
  packages=setuptools.find_packages(),
  classifiers=(
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
  ),
  install_requires=['numpy',
                    'pandas',
                    'torch',
                    'tensorboardX',
                    'networkx',
                    'scikit-image',
                    ]
)
--------------------------------------------------------------------------------
/torchsupport/__init__.py:
--------------------------------------------------------------------------------
name = "torchsupport"
--------------------------------------------------------------------------------
/torchsupport/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/data/__init__.py
--------------------------------------------------------------------------------
/torchsupport/data/asap_xml.py:
--------------------------------------------------------------------------------
import torch

class PolygonAnnotation(object):
  def __init__(self, points):
    self.points = points

  def tile_at(self, position, size=(224, 224), origin=(0.5, 0.5)):
    pass  ## TODO

class SplineAnnotation(object):
  def __init__(self, points):
    self.points = points

  def tile_at(self, position, size=(224, 224), origin=(0.5, 0.5)):
    pass  ## TODO

class PointSetAnnotation(object):
  def __init__(self, points):
    self.points = points

  def tile_at(self, position, size=(224, 224), origin=(0.5, 0.5)):
    pass  ## TODO

class CoordinateAnnotation(object):
  def __init__(self, surface_dict):
    self.surface_dict = surface_dict
    self.n_classes = len(self.surface_dict)

  def tile_at(self, position, size=(224, 224), origin=(0.5, 0.5)):
    to_cat = []
    for key in self.surface_dict:
      label_pixels = self.surface_dict[key].tile_at(position, size, origin)
      to_cat.append(label_pixels.unsqueeze(0))
    return torch.cat(to_cat, dim=0)
--------------------------------------------------------------------------------
/torchsupport/data/chem/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/data/chem/__init__.py
--------------------------------------------------------------------------------
/torchsupport/data/chem/qm9.py:
--------------------------------------------------------------------------------
import torch
import torch.nn
import torch.nn.functional as func
from torch.utils.data import Dataset

import torchsupport.modules.nodegraph as ng
import torchsupport.data.graph as gdata

class QM9(Dataset):
  def __init__(self):
    # self.data = ...
    pass  # TODO: dataset loading is unimplemented
--------------------------------------------------------------------------------
/torchsupport/data/namedtuple.py:
--------------------------------------------------------------------------------
"""Pickle-safe faux namedtuple action."""

from copy import copy

from collections import OrderedDict

class CheckArgs:
  def __init__(self, name, args):
    self.name = name
    self.fields = args
    self.sfields = sorted(args)

  def __call__(self, **kwargs):
    skeys = sorted(list(kwargs.keys()))
    if skeys != self.sfields:
      raise ValueError
    result = NamedTuple(**kwargs)
    return result

class NamedTuple:
  def __init__(self, **kwargs):
    self.dict = OrderedDict(**kwargs)
    self.fields = list(self.dict.keys())

  def asdict(self):
    return self.dict

  def __repr__(self):
    keyvals = [
      f"{key}={self.dict[key]}"
      for key in self.dict
    ]
    keyvals = ", ".join(keyvals)
    result = f"NamedTuple({keyvals})"
    return result

  def __getattr__(self, name):
    result = ...
    if name in super().__getattribute__("fields"):
      result = super().__getattribute__("dict")[name]
    else:
      raise AttributeError
    return result

  def replace(self, **kwargs):
    result = copy(self)
    result.dict = OrderedDict(self.dict)  # copy storage, so the original is not mutated
    for key in kwargs:
      result.dict[key] = kwargs[key]
    return result

  def __getitem__(self, index):
    if index < len(self):
      return self.dict[self.fields[index]]
    else:
      raise IndexError

  def __len__(self):
    return len(self.dict)

  def __iter__(self):
    return (
      self.dict[key]
      for key in self.dict
    )

namespace = NamedTuple

def namedtuple(name, fields):
  return CheckArgs(name, fields)
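A quick usage sketch for the module above (values are illustrative): `namedtuple` returns a constructor that insists on exactly the declared keyword fields, and the resulting `NamedTuple` supports attribute access, indexing, iteration and `replace`.

    from torchsupport.data.namedtuple import namedtuple

    Data = namedtuple("Data", ["logits", "outputs"])
    item = Data(logits=[0.1, 0.9], outputs=None)  # wrong or missing fields raise ValueError
    print(item.logits)                  # attribute access -> [0.1, 0.9]
    updated = item.replace(outputs=42)  # copy with one field swapped out
    print(len(updated), updated[1])     # 2 42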
--------------------------------------------------------------------------------
/torchsupport/data/roi_image.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import openslide
import random
import numpy as np
from read_roi import read_roi_zip
from skimage.draw import polygon
from PIL import Image, ImageSequence

class RoiImage(Dataset):
  def __init__(self, path, size=(226, 226), transform=lambda x: x):
    self.transform = transform
    with Image.open(path + ".tif") as stack:
      frames = []
      for img in ImageSequence.Iterator(stack):
        frame = torch.tensor(np.array(img).astype(float))
        frames.append(frame.unsqueeze(0))
      self.raw_image = torch.cat(frames, dim=0)
    rois = read_roi_zip(path + ".roi.zip")
    self.rois = [
      # read_roi_zip returns a dict, so iterate (name, roi) pairs;
      # torch.tensor cannot consume a zip object directly, hence list(...).
      torch.tensor(list(zip(*polygon(
        roi[1]["x"], roi[1]["y"],
        shape=(self.raw_image.size(1),
               self.raw_image.size(2))
      ))), dtype=torch.long)
      for roi in rois.items()
    ]

  def __getitem__(self, idx):
    pass  # TODO
--------------------------------------------------------------------------------
/torchsupport/data/slides.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset, DataLoader, Sampler, ConcatDataset
import openslide
import random
import numpy as np

class SlideImage(object):
  def __init__(self, path):
    self.slide = openslide.OpenSlide(path)

  def _tile_at_impl(self, position, level, size, origin):
    to_cat = []
    for lv in level:
      size_off = (
        size[0] * self.slide.level_downsamples[lv],
        size[1] * self.slide.level_downsamples[lv]
      )
      start = (
        int(position[0] - origin[0] * size_off[0]),
        int(position[1] - origin[1] * size_off[1])
      )
      image = self.slide.read_region(start, lv, size)
      image = np.array(image).astype(float)
      image = np.transpose(image, (2, 0, 1))  # HWC -> CHW for torch
      image = torch.from_numpy(image)
      to_cat.append(image.unsqueeze(0))
    result = torch.cat(to_cat, dim=0)
    return result

  def tile_at(self, position, level=0, size=(224, 224), origin=(0.5, 0.5)):
    if isinstance(level, list) or isinstance(level, tuple):
      return self._tile_at_impl(position, level, size, origin)
    else:
      return self._tile_at_impl(position, [level], size, origin)

  def regular_tiling(self, level=0, size=(224, 224)):
    dimensions = self.slide.dimensions
    n_tiles = (dimensions[0] // size[0], dimensions[1] // size[1])
    for idx in range(n_tiles[0]):
      x_pos = idx * size[0]
      for idy in range(n_tiles[1]):
        y_pos = idy * size[1]
        yield self.tile_at((x_pos, y_pos), level=level, size=size, origin=(0, 0))

  def random_tiling(self, count, level=0, size=(224, 224)):
    dimensions = self.slide.dimensions
    lower_x, lower_y = 0, 0
    upper_x, upper_y = dimensions[0] - size[0], dimensions[1] - size[1]
    for idx in range(count):
      rand_x, rand_y = random.randint(lower_x, upper_x), random.randint(lower_y, upper_y)
      yield self.tile_at((rand_x, rand_y), level=level, size=size, origin=(0, 0))

class SingleSlideData(Dataset):
  def __init__(self, path, size=(224, 224), level=0, transform=lambda x: x):
    self.transform = transform
    self.slide = SlideImage(path)
    dims = self.slide.slide.dimensions
    self.dims = (dims[0] - size[0], dims[1] - size[1])
    self.size = size
    self.level = level

  def __len__(self):
    return self.dims[0] * self.dims[1]

  def __getitem__(self, index):
    x_pos = index // self.dims[0]
    y_pos = index % self.dims[0]
    tile = self.slide.tile_at((x_pos, y_pos), level=self.level, size=self.size, origin=(0, 0))
    if self.transform is not None:
      tile = self.transform(tile)
    return tile

def MultiSlideData(paths, size=(224, 224), level=0, transform=lambda x: x):
  datasets = []
  for path in paths:
    datasets.append(SingleSlideData(path, size=size, level=level, transform=transform))
  return ConcatDataset(datasets)
--------------------------------------------------------------------------------
/torchsupport/data/structured.py:
--------------------------------------------------------------------------------
import os
import random

import torch
from torch.utils.data import Dataset

from torchsupport.modules.structured import connected_entities as ce
from torchsupport.data.graphio import LazyNodes, LazyAdjacency

class LazySubgraphDataset(Dataset):
  def __init__(self, path, node_name, edge_names, depth=3):
    self.depth = depth
    self.path = path
    self.node_name = node_name
    self.nodes = []
    self.adjacencies = []
    for root, _, names in os.walk(path):
      for name in names:
        if name.endswith(f"{node_name}.node"):
          base = ".".join(name.split(".")[:-2])
          self.nodes.append(LazyNodes(os.path.join(root, f"{base}.{node_name}.node")))
          self.adjacencies.append({
            name: LazyAdjacency(os.path.join(root, f"{base}.{name}.struct"))
            for name in edge_names
          })

  def __len__(self):
    return len(self.nodes)

  def __getitem__(self, idx):
    # randint is inclusive on both ends, so subtract one for a valid index.
    start_node = random.randint(0, len(self.nodes[idx]) - 1)
    adjacencies = [
      self.adjacencies[idx][name].materialize()
      for name in self.adjacencies[idx]
    ]
    reachable = ce.ConnectionStructure.reachable_nodes(
      {self.node_name: set([start_node])},
      adjacencies,
      depth=self.depth
    )
    reachable = list(reachable[self.node_name])
    node_tensor = self.nodes[idx].materialize(reachable)
    adjacencies = [
      adj.select(reachable)
      for adj in adjacencies
    ]
    return node_tensor, adjacencies
--------------------------------------------------------------------------------
/torchsupport/data/tensor_provider.py:
--------------------------------------------------------------------------------
import torch

class TensorProvider():
  def tensors(self):
    raise NotImplementedError("Abstract.")
--------------------------------------------------------------------------------
/torchsupport/deprecated/rl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/deprecated/rl/__init__.py
--------------------------------------------------------------------------------
/torchsupport/deprecated/rl/agent.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as func

class Agent(nn.Module):
  def sample(self, logits):
    raise NotImplementedError

  def forward(self, state, inputs=None):
    raise NotImplementedError

class MultiAgent(Agent):
  def __init__(self, agents):
    super().__init__()
    self.agents = agents

  def sample(self, logits):
    actions = []
    for idx, agent in enumerate(self.agents):
      if logits[idx] is None:
        actions.append(None)
      else:
        actions.append(agent.sample(logits[idx]))
    return actions

  def forward(self, state, inputs=None):
    inputs = inputs or [None for _ in self.agents]
    logits = []
    outputs = []
    for idx, agent in enumerate(self.agents):
      if state[idx] is None:
        logits.append(None)
        outputs.append(inputs[idx])
      else:
        agent_logits, agent_outputs = agent(state[idx], inputs=inputs[idx])
        logits.append(agent_logits)
        outputs.append(agent_outputs)
    return logits, outputs
--------------------------------------------------------------------------------
/torchsupport/deprecated/rl/bdpi.py:
--------------------------------------------------------------------------------
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as func

from torchsupport.rl.trajectory import Experience
from torchsupport.rl.agent import Agent
from torchsupport.rl.off_policy import OffPolicyTraining

class ABCDQNAgent(Agent):
  def __init__(self, actor, critic, n_critic=10, eps=1.0, tau=0.1):
    super().__init__()
    self.value = critic
    self.target = deepcopy(critic)  # target network starts as a copy of the critic
    # n_critic, eps and tau were referenced but never initialized in the
    # original; the defaults here are assumptions.
    self.n_critic = n_critic
    self.eps = eps
    self.tau = tau

  def sample(self, logits):
    condition = bool(torch.rand(1)[0] < self.eps)
    if condition:
      logits = torch.rand_like(logits)
    return self.value.sample(logits)

  def forward(self, data, inputs=None):
    return self.value(data)

  def update(self):
    with torch.no_grad():
      tp = self.target.parameters()
      ap = self.value.parameters()
      for t, a in zip(tp, ap):
        t *= (1 - self.tau)
        t += self.tau * a

class BDPITraining(OffPolicyTraining):
  def __init__(self, actor, critic, environment, n_critic=10, discount=0.99, **kwargs):
    agent = ABCDQNAgent(actor, critic, n_critic=n_critic)
    self.discount = discount
    super().__init__(agent, environment, **kwargs)

  def update(self, *data):
    super().update(*data)
    self.agent.update()
    self.anneal_epsilon()

  def target(self, experience):
    with torch.no_grad():
      observation = experience.final_state
      reward = experience.reward
      terminal = experience.terminal

      # maximum value:
      value, _ = self.agent.target(observation)
      prediction, _ = self.agent.value(observation)
      values = value[
        torch.arange(0, prediction.size(0)),
        prediction.argmax(dim=1)
      ]

      return reward + (1 - terminal.float()) * self.discount * values

  def anneal_epsilon(self):
    # floor the exploration rate at 0.1; the original used min(...), which
    # would have let epsilon decay without bound.
    self.agent.eps = max(torch.tensor(0.1), self.agent.eps - self.step_id * 0.01)

  def run_networks(self, experience):
    observation = experience.initial_state
    action = experience.action
    value, _ = self.agent.value(observation)
    value = value[torch.arange(0, value.size(0)), action]

    target = self.target(experience)
    return value, target

  def loss(self, value, target):
    result = func.mse_loss(value, target)
    return result
--------------------------------------------------------------------------------
/torchsupport/deprecated/rl/environment.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as func

from torchsupport.rl.trajectory import Experience

class Environment:
  def reset(self):
    raise NotImplementedError

  def action_space(self):
    raise NotImplementedError

  def observation_space(self):
    raise NotImplementedError

  def is_done(self):
    raise NotImplementedError

  def observe(self):
    raise NotImplementedError

  def act(self, action):
    raise NotImplementedError
--------------------------------------------------------------------------------
/torchsupport/deprecated/rl/memory.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/deprecated/rl/memory.py
--------------------------------------------------------------------------------
/torchsupport/deprecated/rl/sampler.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as func

from torchsupport.rl.trajectory import Trajectory, Experience

class Sampler:
  def __init__(self, agent, environment):
    self.agent = agent
    self.trajectories = []
    self.environment = environment

  def step(self, inputs=None):
    initial_state = self.environment.observe()
    logits, output = self.agent(
      initial_state.unsqueeze(0),
      inputs
    )
    logits = logits[0]
    output = output if output is None else output[0]
    action = self.agent.sample(logits)
    reward = self.environment.act(action)
    terminal = int(self.environment.is_done())
    final_state = self.environment.observe()

    return Experience(
      initial_state, final_state, action, reward,
      terminal=terminal, logits=logits, outputs=output
    )

  def sample_episode(self, kind=Trajectory):
    with torch.no_grad():
      self.environment.reset()
      trajectory = kind()
      inputs = None
      while not self.environment.is_done():
        experience = self.step(inputs=inputs)
        inputs = experience.outputs
        trajectory.append(experience)
      trajectory.complete()
      return trajectory

class TaskSampler:
  def __init__(self, policies, environment):
    self.policies = policies
    self.trajectories = []
    self.environment = environment

  def sample(self, agent_output, inputs=None):
    raise NotImplementedError("Abstract.")

  def run_step(self, initial_state, inputs=None):
    raise NotImplementedError("Abstract.")

  def step(self, inputs=None):
    initial_state = self.environment.observe()

    agent_output, agent_state = self.run_step(initial_state, inputs)

    action = self.sample(agent_output, agent_state)
    reward = self.environment.act(action)
    terminal = int(self.environment.is_done())
    final_state = self.environment.observe()

    return Experience(
      initial_state, final_state, action, reward,
      terminal=terminal, logits=agent_output, outputs=agent_state
    )
--------------------------------------------------------------------------------
/torchsupport/distributions/__init__.py:
--------------------------------------------------------------------------------
import torch
import torch.distributions as dist
from torchsupport.distributions.mixture import Mixture
from torchsupport.distributions.von_mises import VonMises
from torchsupport.distributions.standard import StandardNormal
from torchsupport.distributions.structured import DistributionList
from torchsupport.distributions.vae_distribution import VAEDistribution
# from torchsupport.distributions.modifiers import fixed, hardened
from torchsupport.distributions.kl_divergence import kl_relaxed_one_hot_categorical

def _harden_one_hot(self, inputs):
  hard = torch.zeros_like(inputs)
  hard_index = inputs.argmax(dim=-1)
  hard[torch.arange(0, hard.size(0)), hard_index] = 1.0
  return hard.detach()

def _harden_bernoulli(self, inputs):
  logits = torch.log(inputs / (1 - inputs + 1e-16) + 1e-16)
  return (logits > 0).float().detach()

def _hard_categorical(self):
  # `self` is the relaxed distribution; `dist` is the distributions module.
  return dist.OneHotCategorical(logits=self.logits)

def _hard_bernoulli(self):
  return dist.Bernoulli(logits=self.logits)

def _conditional_categorical(self, hard):
  noise = -torch.log(torch.rand_like(self.logits) + 1e-16)
  on_condition = noise * hard
  off_condition = noise * (1 - hard)
  offset = on_condition.view(-1, hard.size(-1)).sum(dim=-1).view(*hard.shape[:-1], 1)
  off_condition = off_condition / (self.probs + 1e-16) - offset
  soft_conditional = -torch.log(on_condition + off_condition + 1e-16)
  return soft_conditional

def _conditional_bernoulli(self, hard):
  noise = torch.rand_like(hard)
  on_condition = noise * hard
  off_condition = noise * (1 - hard)
  on_condition = on_condition * self.probs + (1 - self.probs)
  off_condition = off_condition * (1 - self.probs)
  total = on_condition + off_condition
  soft_conditional = torch.log(self.probs / (1 - self.probs + 1e-16) + 1e-16)
  soft_conditional += torch.log(total / (1 - total + 1e-16) + 1e-16)
  return soft_conditional

setattr(dist.RelaxedOneHotCategorical, "harden", _harden_one_hot)
setattr(dist.RelaxedOneHotCategorical, "hard_distribution", _hard_categorical)
setattr(dist.RelaxedOneHotCategorical, "conditional_rsample", _conditional_categorical)

setattr(dist.RelaxedBernoulli, "harden", _harden_bernoulli)
setattr(dist.RelaxedBernoulli, "hard_distribution", _hard_bernoulli)
setattr(dist.RelaxedBernoulli, "conditional_rsample", _conditional_bernoulli)
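The monkey patches above give relaxed distributions a hard projection. A minimal sketch of how they compose (shapes illustrative; importing the package applies the patches):

    import torch
    from torch.distributions import RelaxedOneHotCategorical
    import torchsupport.distributions  # installs harden/hard_distribution/conditional_rsample

    d = RelaxedOneHotCategorical(torch.tensor(0.5), logits=torch.randn(8, 4))
    soft = d.rsample()                  # differentiable relaxed sample, shape (8, 4)
    hard = d.harden(soft)               # one-hot projection of the relaxed sample
    exact = d.hard_distribution()       # the matching OneHotCategorical
    cond = d.conditional_rsample(hard)  # relaxed noise consistent with `hard`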
--------------------------------------------------------------------------------
/torchsupport/distributions/kl_divergence.py:
--------------------------------------------------------------------------------
import torch
from torch.distributions import RelaxedOneHotCategorical, Categorical, kl_divergence, register_kl

@register_kl(RelaxedOneHotCategorical, RelaxedOneHotCategorical)
def kl_relaxed_one_hot_categorical(p, q):
  p = Categorical(probs=p.probs)
  q = Categorical(probs=q.probs)
  return kl_divergence(p, q)
--------------------------------------------------------------------------------
/torchsupport/distributions/mixture.py:
--------------------------------------------------------------------------------
import torch
import torch.jit
from torch.distributions import constraints, Categorical, RelaxedOneHotCategorical
from torch.distributions.distribution import Distribution

class Mixture(Distribution):
  has_rsample = False
  def __init__(self, distributions, weights):
    self.distributions = distributions
    if all(map(lambda x: x.has_rsample, self.distributions)):
      self.has_rsample = True
    self.weights = weights

  def log_prob(self, value):
    result_exp = 0.0
    for weight, distribution in zip(self.weights.permute(1, 0), self.distributions):
      prob = distribution.log_prob(value).exp()
      result_exp += weight[:, None] * prob
    result = torch.log(result_exp + 1e-6)
    return result

  def sample(self):
    samples = []
    for distribution in self.distributions:
      sample = distribution.sample().unsqueeze(0)
      samples.append(sample)
    samples = torch.cat(samples, dim=0)
    choice = Categorical(probs=self.weights)
    choice = choice.sample()
    result = samples[choice, torch.arange(samples.size(1))]
    return result

  def rsample(self):
    if not self.has_rsample:
      raise NotImplementedError("Mixture does not support rsample.")
    samples = []
    for distribution in self.distributions:
      sample = distribution.rsample().unsqueeze(0)
      samples.append(sample)
    samples = torch.cat(samples, dim=0)
    expand = samples.dim() - 2
    choice = RelaxedOneHotCategorical(probs=self.weights, temperature=0.1)
    choice = choice.rsample().permute(1, 0)
    # broadcast the mixture weights over the remaining event dimensions
    choice = choice.view(choice.size(0), choice.size(1), *([1] * expand))
    result = (samples * choice).sum(dim=0)
    return result
--------------------------------------------------------------------------------
/torchsupport/distributions/modifiers.py:
--------------------------------------------------------------------------------
import types
import torch
from torchsupport.modules.gradient import replace_gradient

def fixed(distribution, sample):
  def return_sample(dist, sample_shape=torch.Size()):
    return sample
  distribution.rsample = types.MethodType(return_sample, distribution)
  distribution.sample = types.MethodType(return_sample, distribution)
  return distribution

def hardened(distribution):
  def harden_sample(dist, sample_shape=torch.Size()):
    result = dist._original_rsample(sample_shape=sample_shape)
    return replace_gradient(dist.harden(result), result)
  distribution._original_rsample = distribution.rsample
  distribution.rsample = types.MethodType(harden_sample, distribution)
  distribution.sample = types.MethodType(harden_sample, distribution)
  return distribution
--------------------------------------------------------------------------------
/torchsupport/distributions/standard.py:
--------------------------------------------------------------------------------
import torch
from torch.distributions import Distribution, Normal, RelaxedBernoulli, RelaxedOneHotCategorical

def StandardNormal(size):
  return Normal(
    torch.zeros(size),
    torch.ones(size)
  )
--------------------------------------------------------------------------------
/torchsupport/distributions/structured.py:
--------------------------------------------------------------------------------
import torch
from torch.distributions.distribution import Distribution
from torch.distributions.kl import register_kl

from torchsupport.data.match import Matchable, match

class DistributionList(Distribution):
  has_rsample = True
  def __init__(self, items):
    self.items = items

  def match(self, other):
    result = 0.0
    for s, o in zip(self.items, other.items):
      match_result = match(s, o)
      result = result + match_result
    return result

  def log_prob(self, value):
    log_prob = 0.0
    for dist, val in zip(self.items, value):
      current = dist.log_prob(val)
      current = current.view(current.size(0), -1).sum(dim=1)
      log_prob = log_prob + current
    return log_prob

  def sample(self, sample_shape=torch.Size()):
    return [
      dist.sample(sample_shape=sample_shape)
      for dist in self.items
    ]

  def rsample(self, sample_shape=torch.Size()):
    return [
      dist.rsample(sample_shape=sample_shape)
      for dist in self.items
    ]
--------------------------------------------------------------------------------
/torchsupport/distributions/vae_distribution.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.distributions import Distribution

class VAEDistribution(nn.Module, Distribution):
  def __init__(self, encoder, decoder, prior=None):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.prior = prior

  def forward(self, inputs):
    q = self.encoder(inputs)
    sample = q.rsample()
    d = self.decoder(sample)
    rec = d.log_prob(inputs).view(inputs.size(0), -1).sum(dim=1, keepdim=True)
    log_p = self.prior.log_prob(sample).view(inputs.size(0), -1)
    log_q = q.log_prob(sample).view(inputs.size(0), -1)
    kl = -log_p.sum(dim=1, keepdim=True) + log_q.sum(dim=1, keepdim=True)
    return rec - kl

  def log_prob(self, x):
    return self(x)

  def rsample(self, sample_shape=torch.Size()):
    prior = self.prior.rsample(sample_shape=sample_shape)
    decoded = self.decoder(prior)
    return decoded.rsample()

  def sample(self, sample_shape=torch.Size()):
    with torch.no_grad():
      return self.rsample(sample_shape=sample_shape)
--------------------------------------------------------------------------------
/torchsupport/experimental/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/experimental/__init__.py
--------------------------------------------------------------------------------
/torchsupport/experimental/apps/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/experimental/apps/__init__.py
--------------------------------------------------------------------------------
/torchsupport/experimental/enas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/experimental/enas/__init__.py
-------------------------------------------------------------------------------- /torchsupport/experimental/gan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | class Adversarial(nn.Module): 6 | def __init__(self, generator, discriminator): 7 | super().__init__()  # register submodules with nn.Module 8 | self.generator = generator 9 | self.discriminator = discriminator 10 | 11 | def generate(self, *input): 12 | return self.generator(*input) 13 | 14 | def check(self, *check_input): 15 | return self.discriminator(*check_input) 16 | 17 | def train(net, loss, optimizer, gen_inputs, real_inputs, setsize=10, epochs=1000, device="cpu"): 18 | for epoch in range(epochs): 19 | optimizer.zero_grad() 20 | batch_results = [] 21 | batch_labels = [] 22 | for idx in range(setsize): 23 | gen_input = next(gen_inputs) 24 | real_input = next(real_inputs) 25 | batch_results.append(net.generator(*gen_input).unsqueeze(0)) 26 | batch_results.append(real_input) 27 | batch_labels.append(torch.tensor([[[0]]])) 28 | batch_labels.append(torch.tensor([[[1]]])) 29 | batch_tensor = torch.cat(batch_results, dim=0).to(device) 30 | label_tensor = torch.cat(batch_labels, dim=0).to(device) 31 | discriminator_output = net.discriminator(batch_tensor) 32 | loss_val = loss(batch_results, discriminator_output, label_tensor) 33 | loss_val.backward() 34 | optimizer.step() 35 | return loss_val.item() 36 | -------------------------------------------------------------------------------- /torchsupport/experimental/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/experimental/losses/__init__.py -------------------------------------------------------------------------------- /torchsupport/experimental/losses/instance_segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class InstanceClusterLoss(nn.Module): 5 | def __init__(self, var=1.0, dist=1.0, reg=1.0, 6 | distance_to_center=0.1, 7 | distance_between_centers=0.2): 8 | super(InstanceClusterLoss, self).__init__() 9 | self.var = var 10 | self.dist = dist 11 | self.reg = reg 12 | self.dtc = distance_to_center 13 | self.dbc = distance_between_centers 14 | 15 | def _var_loss(self, prediction, target): 16 | loss_val = torch.tensor(0.0) 17 | mean_embeddings = [] 18 | for polygon in target: 19 | polygon_values = prediction[polygon] 20 | mean_embedding = polygon_values.mean() 21 | polygon_loss = torch.clamp(torch.norm(polygon_values - mean_embedding) - self.dtc, min=0.0) ** 2 22 | polygon_loss /= polygon.sum() 23 | loss_val += polygon_loss 24 | mean_embeddings.append(mean_embedding) 25 | loss_val /= len(target) 26 | return loss_val, mean_embeddings 27 | 28 | def _dist_loss(self, mean_embeddings): 29 | losses = [ 30 | torch.clamp(2 * self.dbc - torch.norm(e_a - e_b), min=0.0) ** 2 31 | for ida, e_a in enumerate(mean_embeddings) 32 | for idb, e_b in enumerate(mean_embeddings) 33 | if ida != idb 34 | ] 35 | N = len(mean_embeddings) 36 | return sum(losses) / (N * (N - 1)) 37 | 38 | def _reg_loss(self, mean_embeddings): 39 | return sum(map(torch.norm, mean_embeddings)) / len(mean_embeddings) 40 | 41 | def forward(self, prediction, target): 42 | var_loss, mean_embeddings = self._var_loss(prediction, target) 43 | reg_loss = self._reg_loss(mean_embeddings) 44 | dist_loss = self._dist_loss(mean_embeddings) 45 | return self.var *
var_loss + self.dist * dist_loss + self.reg * reg_loss 46 | -------------------------------------------------------------------------------- /torchsupport/experimental/losses/vae.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/experimental/losses/vae.py -------------------------------------------------------------------------------- /torchsupport/experimental/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/experimental/modules/__init__.py -------------------------------------------------------------------------------- /torchsupport/experimental/modules/fast_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | import numpy as np 5 | 6 | def kmeans(input, n_clusters=16, tol=1e-6): 7 | """ 8 | TODO: check correctness 9 | """ 10 | indices = torch.from_numpy(np.random.choice(input.size(-1), n_clusters, replace=False)).long() 11 | values = input[:, :, indices] 12 | 13 | while True: 14 | # squared distance of every point to every cluster center: (batch, points, clusters) 15 | dist = ((input.unsqueeze(-1) - values.unsqueeze(-2)) ** 2).sum(dim=1) 16 | choice_cluster = torch.argmin(dist, dim=-1) 17 | old_values = values 18 | # recompute each cluster center as the mean of its assigned points 19 | values = torch.stack([ 20 | torch.stack([ 21 | input[idx, :, choice_cluster[idx] == cluster].mean(dim=-1) 22 | for cluster in range(n_clusters) 23 | ], dim=-1) 24 | for idx in range(input.size(0)) 25 | ], dim=0) 26 | shift = (old_values - values).norm(dim=1) 27 | if shift.max() ** 2 < tol: 28 | break 29 | 30 | return values 31 | 32 | 33 | def gaussian_kernel(x, sigma=4): 34 | """Gaussian distance kernel. 35 | 36 | Args: 37 | x (Tensor): difference between two input features. 38 | sigma (float): bandwidth of the gaussian kernel. 39 | 40 | Returns: 41 | The Gaussian kernel for features `x` and bandwidth `sigma`. 42 | """ 43 | return torch.exp(- (x ** 2).sum(dim=-1) / sigma) 44 | 45 | class FHDF2d(nn.Module): 46 | def __init__(self, spatial_kernel=gaussian_kernel, feature_kernel=gaussian_kernel, clustering=kmeans): 47 | """Performs fast high-dimensional filtering using clustering. 48 | 49 | Args: 50 | spatial_kernel (callable): spatial distance kernel used for filtering. 51 | Defaults to `gaussian_kernel`. 52 | feature_kernel (callable): feature distance kernel used for filtering. 53 | Defaults to `gaussian_kernel`. 54 | clustering (callable): clustering algorithm used for filter approximation. 55 | Defaults to `kmeans`. 56 | 57 | Note: 58 | This computes a _dense_ filter over a 2D image. For a sparse filter, methods based 59 | on sparse matrix-matrix multiplication could provide better performance.
60 | """ 61 | super(FHDF2d, self).__init__() 62 | self.spatial_kernel = spatial_kernel 63 | self.feature_kernel = feature_kernel 64 | self.clustering = clustering 65 | 66 | def forward(self, input, guide): 67 | """ 68 | TODO: check correctness 69 | """ 70 | padding = (input.size(-2) // 2, input.size(-1) // 2) 71 | clusters = self.clustering(guide.reshape(guide.size(0), guide.size(1), -1)) 72 | A = torch.Tensor([[self.feature_kernel(k - l) for l in clusters] for k in clusters]) 73 | Ad = A.pinverse() 74 | bk = self.feature_kernel(clusters - guide) 75 | ck = Ad.mv(bk) 76 | omega = self.spatial_kernel(input.size(-2), input.size(-1)) 77 | phi = self.feature_kernel(guide - clusters) 78 | phi_f = phi * input 79 | rk = func.conv2d(omega, phi, padding=padding) 80 | vk = func.conv2d(omega, phi_f, padding=padding) 81 | eta = (ck * rk).sum(dim=1) 82 | result = (ck * vk).sum(dim=1) / eta 83 | return result 84 | -------------------------------------------------------------------------------- /torchsupport/experimental/modules/structured/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/experimental/modules/structured/__init__.py -------------------------------------------------------------------------------- /torchsupport/experimental/rl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/experimental/rl/__init__.py -------------------------------------------------------------------------------- /torchsupport/experimental/rl/trajectory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | 4 | class Trajectory(object): 5 | def __init__(self, states=None, probabilities=None, choices=None): 6 | """Reinforcement learning agent trajectory. 7 | 8 | Args: 9 | states (list): trajectory states. 10 | probabilities (list): trajectory probabilities. 11 | choices (list): trajectory choices. 12 | """ 13 | self._states = states if states is not None else []  # avoid shared mutable default lists 14 | self._probabilities = probabilities if probabilities is not None else [] 15 | self._choices = choices if choices is not None else [] 16 | 17 | def __len__(self): 18 | return len(self._states) 19 | 20 | def __getitem__(self, idx): 21 | if isinstance(idx, slice): 22 | return Trajectory( 23 | states=self._states[idx], 24 | probabilities=self._probabilities[idx], 25 | choices=self._choices[idx] 26 | ) 27 | else: 28 | return ( 29 | self._states[idx], 30 | self._probabilities[idx], 31 | self._choices[idx] 32 | ) 33 | 34 | @property 35 | def states(self): 36 | return self._states 37 | 38 | @property 39 | def probabilities(self): 40 | prob = [] 41 | for idx, choice in enumerate(self._choices): 42 | prob.append(self._probabilities[idx][choice]) 43 | return torch.cat(prob) 44 | 45 | @property 46 | def choices(self): 47 | return torch.cat(self._choices) 48 | 49 | def append(self, state, probability, choice): 50 | self._states.append(state) 51 | self._probabilities.append(probability) 52 | self._choices.append(choice) 53 | -------------------------------------------------------------------------------- /torchsupport/experimental/trainables/unsupervised_segmentation/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | import parser 6 | import model 7 | 8 | opt = parser.parse() 9 | 10 | if opt.train: 11 | pass  # TODO: train 12 | elif opt.eval: 13 | pass  # TODO: evaluate 14 | -------------------------------------------------------------------------------- /torchsupport/experimental/trainables/unsupervised_segmentation/parser.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | def parse(): 4 | parser = ArgumentParser(description="unsupervised semantic segmentation.") 5 | 6 | # general training settings: 7 | parser.add_argument('--path', type=str, required=True, help="path to input images.") 8 | parser.add_argument('--data_type', type=str, default="ndpi", help="input image type.") 9 | parser.add_argument('--train', action="store_true", help="train on input images?") 10 | parser.add_argument('--eval', action="store_true", help="evaluate network on input images?") 11 | parser.add_argument('--cuda', action="store_true", help="run using cuda?") 12 | parser.add_argument('--on_gpu', type=int, default=0, help="bind to given GPU.") 13 | parser.add_argument('--batch', type=int, default=64, help="training batch size.") 14 | parser.add_argument('--epochs', type=int, default=50, help="number of training epochs.") 15 | parser.add_argument('--lr', type=float, default=0.001, help="learning rate.") 16 | parser.add_argument('--threads', type=int, default=16, help="data loader threads.") 17 | parser.add_argument('--seed', type=int, default=42, help="random seed, defaults to 42.") 18 | 19 | # general arch settings: 20 | parser.add_argument('--unsupervision', type=str, default="SegmenterDecoder", 21 | choices=["SegmenterDecoder", "ResidualDecoder", "MultiDecoder"], 22 | help="type of unsupervised architecture. Defaults to a WNet-style segmenter-decoder.") 23 | parser.add_argument('--arch', type=str, default="UNet", choices=["UNet", "AutofocusNet", "SplitNet"], 24 | help="architecture to be used.
Defaults to UNet.") 25 | parser.add_argument('--regularization', type=str, default="all", 26 | help="regularization for more natural segmentation.") 27 | parser.add_argument('--max_classes', type=int, default=4, help="maximum number of different classes.") 28 | 29 | # UNet settings: 30 | parser.add_argument('--depth', type=int, default=4, help="UNet depth.") 31 | parser.add_argument('--multiscale', action="store_true", help="use multi-scale model?") 32 | parser.add_argument('--dilate', action="store_true", help="use dilated convolutions?") 33 | parser.add_argument('--attention', action="store_true", help="use attention gate?") 34 | 35 | opt = parser.parse_args() 36 | return opt 37 | -------------------------------------------------------------------------------- /torchsupport/experimental/trainables/unsupervised_segmentation/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as func 5 | 6 | def training_loop_wnet(data, valid_data, encoder, decoder, optimizer, 7 | regularization, reconstruction_loss, num_epochs=50, 8 | validate_every=10, checkpoint_every=10, 9 | report_every=10, board_prefix=os.getcwd()): 10 | valid_regularization_value = None 11 | valid_loss_value = None 12 | for epoch in range(num_epochs): 13 | for idx, batch in enumerate(data): 14 | # step 1 15 | optimizer.zero_grad() 16 | encoder_out = encoder(batch) 17 | loss = regularization(encoder_out) 18 | loss.backward() 19 | optimizer.step() 20 | 21 | regularization_value = loss.item() 22 | 23 | # step 2 24 | optimizer.zero_grad() 25 | decoder_out = decoder(encoder(batch))  # recompute the encoding: step 1's backward() freed its graph 26 | loss = reconstruction_loss(decoder_out, batch) 27 | loss.backward() 28 | optimizer.step() 29 | 30 | loss_value = loss.item() 31 | 32 | if idx % validate_every == 0: 33 | valid_batch = next(valid_data) 34 | out = encoder(valid_batch) 35 | valid_regularization_value = regularization(out).item() 36 | out = decoder(out) 37 | valid_loss_value = reconstruction_loss(out, valid_batch).item() 38 | 39 | if idx % checkpoint_every == 0: 40 | ...  # TODO 41 | 42 | if idx % report_every == 0: 43 | ...  # TODO 44 | 45 | return encoder, decoder 46 | 47 | def training_loop_residual_reconstruction(data, encoder, optimizer, board_prefix=os.getcwd()): 48 | pass  # TODO 49 | 50 | def training_loop_multi_decoder(data, encoder, decoder, optimizer, board_prefix=os.getcwd()): 51 | pass  # TODO -------------------------------------------------------------------------------- /torchsupport/flex/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/flex/__init__.py -------------------------------------------------------------------------------- /torchsupport/flex/checkpointing/__init__.py: -------------------------------------------------------------------------------- 1 | import torchsupport.flex.checkpointing.savable 2 | import torchsupport.flex.checkpointing.savable_extensions 3 | import torchsupport.flex.checkpointing.checkpoint 4 | -------------------------------------------------------------------------------- /torchsupport/flex/checkpointing/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | from torchsupport.data.io import netread, netwrite 6 | from torchsupport.flex.checkpointing.savable import Savable, SaveStateError 7 | 8 | class Checkpoint: 9 | def __init__(self, ctx): 10 | self.ctx = ctx 11 | self.checkpoint_names = {} 12 | self.save_names = {"context": ctx} 13 | 14 | def add_checkpoint(self, **kwargs): 15 | self.checkpoint_names.update(kwargs) 16 | self.save_names.update(kwargs) 17 | 18 | def remove_checkpoint(self, *names): 19 | for name in names: 20 | if name in self.checkpoint_names: 21 | del self.checkpoint_names[name] 22 | del self.save_names[name] 23 | 24 | def add_save(self, **kwargs): 25 | self.save_names.update(kwargs) 26 | 27 | def save_path(self): 28 | return f"{self.ctx.path}-save.torch" 29 | 30 | def checkpoint(self): 31 | for name, the_net in self.checkpoint_names.items(): 32 | if isinstance(the_net, torch.nn.DataParallel): 33 | the_net = the_net.module 34 | netwrite( 35 | the_net, 36 | f"{self.ctx.path}-{name}-step-{self.ctx.step_id}.torch" 37 | ) 38 | 39 | def emergency_read_checkpoint(self): 40 | import glob 41 | for name, the_net in self.checkpoint_names.items(): 42 | if isinstance(the_net, torch.nn.DataParallel): 43 | the_net = the_net.module 44 | files = glob.glob(f"{self.ctx.path}-{name}-step-*.torch")  # match the names written by checkpoint() 45 | files = sorted(files, key=lambda x: int(x.split("-")[-1].split(".")[0])) 46 | target = files[-1] 47 | netread( 48 | the_net, 49 | target 50 | ) 51 | 52 | def write(self, path): 53 | data = {} 54 | data["_torch_rng_state"] = torch.random.get_rng_state() 55 | data["_np_rng_state"] = np.random.get_state() 56 | data["_random_rng_state"] = random.getstate() 57 | for name, param in self.save_names.items(): 58 | param = Savable.wrap(param) 59 | param.write(data, name) 60 | torch.save(data, path + ".tmp") 61 | if os.path.isfile(path): 62 | os.rename(path, path + ".old") 63 | os.rename(path + ".tmp", path) 64 | 65 | def read(self, path): 66 | data = torch.load(path) 67 | torch.random.set_rng_state(data["_torch_rng_state"]) 68 | np.random.set_state(data["_np_rng_state"]) 69 | random.setstate(data["_random_rng_state"]) 70 | for name, param in self.save_names.items(): 71 | param = Savable.wrap(param) 72 | param.read(data, name) 73 | 74 | def save(self, path=None): 75 | path = path or self.save_path() 76 | try: 77 |
self.write(path) 78 | except SaveStateError: 79 | torch_rng_state = torch.random.get_rng_state() 80 | np_rng_state = np.random.get_state() 81 | random_rng_state = random.getstate() 82 | self.load() 83 | torch.random.set_rng_state(torch_rng_state) 84 | np.random.set_state(np_rng_state) 85 | random.setstate(random_rng_state) 86 | 87 | def load(self, path=None): 88 | try: 89 | path = path or self.save_path() 90 | if os.path.isfile(path): 91 | self.read(path) 92 | except Exception: 93 | print("Something went wrong! Trying to read latest network checkpoints...") 94 | self.emergency_read_checkpoint() 95 | return self 96 | -------------------------------------------------------------------------------- /torchsupport/flex/checkpointing/savable.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | SAVABLE_EXTENSION = {} 4 | 5 | def _set_savable(target, data_type=None): 6 | SAVABLE_EXTENSION[data_type] = target 7 | return target 8 | 9 | class _CompareType: 10 | def __init__(self, data_type): 11 | self.data_type = data_type 12 | 13 | def __lt__(self, other): 14 | return issubclass(self.data_type, other.data_type) 15 | 16 | def _resolve_savable(data_type): 17 | candidates = [] 18 | for key in SAVABLE_EXTENSION: 19 | if issubclass(data_type, key): 20 | candidates.append(_CompareType(key)) 21 | return SAVABLE_EXTENSION[min(candidates).data_type] 22 | 23 | def savable_of(data_type): 24 | return partial(_set_savable, data_type=data_type) 25 | 26 | class Savable: 27 | @staticmethod 28 | def wrap(data): 29 | if isinstance(data, Savable): 30 | return data 31 | return _resolve_savable(type(data))(data) 32 | 33 | def write(self, data, name): 34 | pass 35 | 36 | def read(self, data, name): 37 | pass 38 | 39 | class SaveStateError(Exception): 40 | pass 41 | 42 | def is_savable(x): 43 | return isinstance(x, Savable) or (type(x) in SAVABLE_EXTENSION) 44 | -------------------------------------------------------------------------------- /torchsupport/flex/checkpointing/savable_extensions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchsupport.flex.checkpointing.savable import ( 3 | savable_of, Savable, SaveStateError 4 | ) 5 | 6 | @savable_of(torch.nn.Module) 7 | class SaveModule(Savable): 8 | def __init__(self, module): 9 | if isinstance(module, torch.nn.DataParallel): 10 | module = module.module 11 | self.module = module 12 | 13 | def write(self, data, name): 14 | for param in self.module.parameters(): 15 | if torch.isnan(param).any(): 16 | raise SaveStateError("Encountered NaN weights!") 17 | data[name] = self.module.state_dict() 18 | 19 | def read(self, data, name): 20 | self.module.load_state_dict(data[name]) 21 | 22 | @savable_of(torch.Tensor) 23 | class SaveTensor(Savable): 24 | def __init__(self, tensor): 25 | self.tensor = tensor 26 | 27 | def write(self, data, name): 28 | data[name] = self.tensor 29 | 30 | def read(self, data, name): 31 | self.tensor[:] = data[name] 32 | -------------------------------------------------------------------------------- /torchsupport/flex/context/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/flex/context/__init__.py -------------------------------------------------------------------------------- /torchsupport/flex/context/context_module.py:
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class ContextModule(nn.Module): 4 | def __init__(self, ctx=None): 5 | super().__init__() 6 | self._ctx = ctx 7 | 8 | @property 9 | def ctx(self): 10 | return self._ctx 11 | 12 | @ctx.setter 13 | def ctx(self, ctx): 14 | for module in self.modules(): 15 | if isinstance(module, ContextModule): 16 | module._ctx = ctx 17 | -------------------------------------------------------------------------------- /torchsupport/flex/data_distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/flex/data_distributions/__init__.py -------------------------------------------------------------------------------- /torchsupport/flex/data_distributions/data_distribution.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import torch 4 | from torch.distributions import Distribution 5 | from torchsupport.data.collate import DataLoader 6 | from torchsupport.data.io import DeviceMovable, to_device 7 | 8 | class InfiniteSampler: 9 | def __init__(self, data_set): 10 | self.size = len(data_set) 11 | 12 | def __iter__(self): 13 | yield from itertools.islice(self.permutation(), 0, None, 1) 14 | 15 | def permutation(self): 16 | while True: 17 | yield from torch.randperm(self.size) 18 | 19 | class DataDistribution(Distribution, DeviceMovable): 20 | r"""Data distribution based on the PyTorch DataLoader 21 | combined with a standard dataset. Allows for loading 22 | multiple batches in parallel. 23 | """ 24 | def __init__(self, data_set, batch_size=1, device="cpu", **kwargs): 25 | self.data = data_set 26 | self.device = device 27 | self.loader = DataLoader( 28 | data_set, batch_size=batch_size, drop_last=True, 29 | sampler=InfiniteSampler(data_set), **kwargs 30 | ) 31 | self.iter = iter(self.loader) 32 | 33 | def move_to(self, device): 34 | self.device = device 35 | return self 36 | 37 | def sample(self, sample_shape=torch.Size()): 38 | return to_device(next(self.iter), self.device) 39 | -------------------------------------------------------------------------------- /torchsupport/flex/data_distributions/metadata_distribution.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.distributions import Distribution, Categorical 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | from torchsupport.data.io import to_device, DeviceMovable 9 | from torchsupport.flex.data_distributions.data_distribution import InfiniteSampler 10 | 11 | class MetaDataset(Dataset): 12 | def __init__(self, data): 13 | super().__init__() 14 | self.data = data 15 | 16 | def __getitem__(self, index): 17 | return self.data[index], index  # index into the wrapped dataset; the Dataset base class has no __getitem__ 18 | 19 | def __len__(self): 20 | return len(self.data) 21 | 22 | class MetaDataDistribution(nn.Module, Distribution, DeviceMovable): 23 | @staticmethod 24 | def init_meta(meta): 25 | meta.zero_() 26 | 27 | def move_to(self, device): 28 | self.device = device 29 | return self 30 | 31 | def __init__(self, data_set, meta_type=None, 32 | batch_size=1, device="cpu", init_meta=None, 33 | **kwargs): 34 | super().__init__() 35 | self.device = device 36 | self.batch_size = batch_size 37 | init_meta = init_meta or MetaDataDistribution.init_meta 38 | self.data_set = MetaDataset(data_set) 39 | self.meta_type = meta_type or torch.Size([1]) 40 | self.meta_data = nn.Parameter(torch.zeros( 41 | len(self.data_set), *self.meta_type, 42 | requires_grad=True 43 | )) 44 | with torch.no_grad(): 45 | init_meta(self.meta_data) 46 | self.loader = DataLoader(  # load the index-augmented dataset, so sample() receives indices 47 | self.data_set, batch_size=batch_size, drop_last=True, 48 | sampler=InfiniteSampler(self.data_set), **kwargs 49 | ) 50 | self.iter = iter(self.loader) 51 | 52 | def sample(self, sample_shape=torch.Size()): 53 | data, indices = next(self.iter) 54 | meta_data = self.meta_data[indices] 55 | return to_device((data, meta_data, indices), self.device) 56 | 57 | class WeightedInfiniteSampler: 58 | def __init__(self, data_set, weights, batch_size=1): 59 | self.size = len(data_set) 60 | self.weights = weights 61 | self.batch_size = batch_size 62 | def __iter__(self): 63 | yield from itertools.islice(self.permutation(), 0, None, 1) 64 | 65 | def permutation(self): 66 | while True: 67 | with torch.no_grad(): 68 | dist = Categorical(logits=self.weights) 69 | sample = dist.sample((self.batch_size,)).view(-1) 70 | yield from sample 71 | 72 | class WeightedDataDistribution(nn.Module, Distribution, DeviceMovable): 73 | @staticmethod 74 | def init_meta(meta): 75 | meta.normal_() 76 | 77 | def move_to(self, device): 78 | self.device = device 79 | return self 80 | 81 | def __init__(self, data_set, batch_size=1, device="cpu", 82 | init_meta=None, **kwargs): 83 | super().__init__() 84 | self.device = device 85 | init_meta = init_meta or WeightedDataDistribution.init_meta 86 | self.data_set = MetaDataset(data_set) 87 | self.weight = nn.Parameter(torch.zeros( 88 | len(self.data_set), requires_grad=True 89 | )) 90 | self.weight.share_memory_() 91 | with torch.no_grad(): 92 | init_meta(self.weight) 93 | self.loader = DataLoader( 94 | self.data_set, batch_size=batch_size, drop_last=True, 95 | sampler=WeightedInfiniteSampler(data_set, self.weight, batch_size=batch_size), **kwargs 96 | ) 97 | self.iter = iter(self.loader) 98 | 99 | def sample(self, sample_shape=torch.Size()): 100 | data, indices = next(self.iter) 101 | weight = self.weight.log_softmax(dim=0) 102 | weight = weight[indices] 103 | return to_device((data, weight, indices), self.device) 104 | -------------------------------------------------------------------------------- /torchsupport/flex/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/flex/examples/__init__.py -------------------------------------------------------------------------------- /torchsupport/flex/examples/cifar_conditional_maximum_likelihood.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | from torch.utils.data import Dataset 5 | 6 | from torchvision.datasets import CIFAR10 7 | from torchvision.transforms import ToTensor 8 | 9 | from torchsupport.utils.argparse import parse_options 10 | from torchsupport.flex.log.log_types import LogImage 11 | from torchsupport.flex.context.context import TrainingContext 12 | from torchsupport.flex.data_distributions.data_distribution import DataDistribution 13 | from torchsupport.flex.tasks.likelihood.maximum_likelihood import SupervisedLikelihood 14 | from torchsupport.flex.training.conditional_maximum_likelihood import conditional_mle_training 15 | 16 | def valid_callback(args, ctx: TrainingContext=None): 17 | ctx.log(images=LogImage(args.condition)) 18 | labels =
args.distribution.logits.argmax(dim=1) 19 | for idx in range(10): 20 | positive = args.condition[labels == idx] 21 | if positive.size(0) != 0: 22 | ctx.log(**{f"classified {idx}": LogImage(positive)}) 23 | 24 | class CIFAR10Dataset(Dataset): 25 | def __init__(self, data): 26 | self.data = data 27 | 28 | def __getitem__(self, index): 29 | data, label = self.data[index] 30 | return label, data 31 | 32 | def __len__(self): 33 | return len(self.data) 34 | 35 | class Classifier(nn.Module): 36 | def __init__(self): 37 | super().__init__() 38 | self.conv = nn.Sequential( 39 | nn.Conv2d(3, 16, 3), 40 | nn.MaxPool2d(2), 41 | nn.Conv2d(16, 32, 3), 42 | nn.MaxPool2d(2), 43 | nn.Conv2d(32, 64, 3) 44 | ) 45 | self.out = nn.Linear(64, 10) 46 | 47 | def forward(self, inputs): 48 | features = self.conv(inputs) 49 | features = func.adaptive_avg_pool2d(features, 1) 50 | logits = self.out(features.view(features.size(0), -1)) 51 | cat = torch.distributions.Categorical(logits=logits) 52 | return cat 53 | 54 | if __name__ == "__main__": 55 | opt = parse_options( 56 | "CIFAR10 classifier training using flex.", 57 | path="flexamples/cifar10-cmle-1", 58 | device="cpu", 59 | batch_size=64, 60 | max_epochs=1000, 61 | report_interval=10 62 | ) 63 | 64 | cifar10 = CIFAR10("examples/", download=False, transform=ToTensor()) 65 | data = CIFAR10Dataset(cifar10) 66 | data = DataDistribution( 67 | data, batch_size=opt.batch_size, 68 | device=opt.device 69 | ) 70 | 71 | net = Classifier().to(opt.device) 72 | model = SupervisedLikelihood(net) 73 | 74 | training = conditional_mle_training( 75 | model, data, valid_data=data, 76 | path=opt.path, 77 | device=opt.device, 78 | batch_size=opt.batch_size, 79 | max_epochs=opt.max_epochs, 80 | report_interval=opt.report_interval 81 | ) 82 | training.get_step("valid_step").extend(valid_callback) 83 | training.register(net=net) 84 | training.checkpoint.remove_checkpoint("model") 85 | 86 | training.train() 87 | -------------------------------------------------------------------------------- /torchsupport/flex/examples/cifar_supervised.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | from torch.utils.data import Dataset 5 | 6 | from torchvision.datasets import CIFAR10 7 | from torchvision.transforms import ToTensor 8 | 9 | from torchsupport.utils.argparse import parse_options 10 | from torchsupport.flex.log.log_types import LogImage 11 | from torchsupport.flex.context.context import TrainingContext 12 | from torchsupport.flex.data_distributions.data_distribution import DataDistribution 13 | from torchsupport.flex.tasks.likelihood.maximum_likelihood import SupervisedArgs 14 | from torchsupport.flex.training.supervised import supervised_training 15 | 16 | def valid_callback(args: SupervisedArgs, ctx: TrainingContext=None): 17 | ctx.log(images=LogImage(args.sample)) 18 | labels = args.prediction.argmax(dim=1) 19 | for idx in range(10): 20 | positive = args.sample[labels == idx] 21 | if positive.size(0) != 0: 22 | ctx.log(**{f"classified {idx}": LogImage(positive)}) 23 | 24 | class CIFAR10Dataset(Dataset): 25 | def __init__(self, data): 26 | self.data = data 27 | 28 | def __getitem__(self, index): 29 | data, label = self.data[index] 30 | return data, label 31 | 32 | def __len__(self): 33 | return len(self.data) 34 | 35 | class Classifier(nn.Module): 36 | def __init__(self): 37 | super().__init__() 38 | self.conv = nn.Sequential( 39 | nn.Conv2d(3, 16, 3), 40 | 
nn.MaxPool2d(2), 41 | nn.Conv2d(16, 32, 3), 42 | nn.MaxPool2d(2), 43 | nn.Conv2d(32, 64, 3) 44 | ) 45 | self.out = nn.Linear(64, 10) 46 | 47 | def forward(self, inputs): 48 | features = self.conv(inputs) 49 | features = func.adaptive_avg_pool2d(features, 1) 50 | return self.out(features.view(features.size(0), -1)) 51 | 52 | if __name__ == "__main__": 53 | opt = parse_options( 54 | "CIFAR10 classifier training using flex.", 55 | path="flexamples/cifar10-classifier", 56 | device="cpu", 57 | batch_size=64, 58 | max_epochs=1000, 59 | report_interval=10 60 | ) 61 | 62 | cifar10 = CIFAR10("examples/", download=False, transform=ToTensor()) 63 | data = CIFAR10Dataset(cifar10) 64 | data = DataDistribution( 65 | data, batch_size=opt.batch_size, 66 | device=opt.device 67 | ) 68 | 69 | net = Classifier() 70 | 71 | training = supervised_training( 72 | net, data, valid_data=data, 73 | losses=[nn.CrossEntropyLoss()], 74 | path=opt.path, 75 | device=opt.device, 76 | batch_size=opt.batch_size, 77 | max_epochs=opt.max_epochs, 78 | report_interval=opt.report_interval 79 | ) 80 | training.get_step("valid_step").extend(valid_callback) 81 | 82 | training.train() 83 | -------------------------------------------------------------------------------- /torchsupport/flex/log/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/flex/log/__init__.py -------------------------------------------------------------------------------- /torchsupport/flex/log/log_types.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LogType: 4 | @property 5 | def data(self): 6 | return 0 7 | 8 | def log(self, logger, name, step): 9 | raise NotImplementedError("Abstract") 10 | 11 | class LogImage(LogType): 12 | def __init__(self, img): 13 | super().__init__() 14 | if torch.is_tensor(img): 15 | img = img.detach().cpu() 16 | if img.max() > 1.0 or img.min() < 0.0: 17 | img = (img - img.min()) / (img.max() - img.min()) 18 | self.img = img 19 | 20 | def log(self, logger, name, step): 21 | if self.img.dim() > 3: 22 | logger.log_image_batch(name, self.img, step) 23 | else: 24 | logger.log_image(name, self.img, step) 25 | 26 | class LogNumber(LogType): 27 | def __init__(self, number): 28 | super().__init__() 29 | if torch.is_tensor(number): 30 | number = float(number.detach().cpu()) 31 | self.number = number 32 | 33 | def log(self, logger, name, step): 34 | logger.log_number(name, self.number, step) 35 | 36 | class LogText(LogType): 37 | def __init__(self, text): 38 | super().__init__() 39 | self.text = text 40 | 41 | def log(self, logger, name, step): 42 | logger.log_text(name, self.text, step) 43 | 44 | class LogFigure(LogType): 45 | def __init__(self, figure): 46 | super().__init__() 47 | self.figure = figure 48 | 49 | def log(self, logger, name, step): 50 | logger.log_figure(name, self.figure, step) 51 | 52 | class LogEmbedding(LogType): 53 | def __init__(self, embedding): 54 | super().__init__() 55 | if torch.is_tensor(embedding): 56 | embedding = embedding.detach().cpu() 57 | embedding = embedding.reshape(embedding.shape[0], -1) 58 | self.embedding = embedding 59 | 60 | def log(self, logger, name, step): 61 | logger.log_embedding(name, self.embedding, step) 62 | -------------------------------------------------------------------------------- /torchsupport/flex/log/logger.py: 
-------------------------------------------------------------------------------- 1 | import warnings 2 | from torchsupport.flex.log.log_types import LogType 3 | 4 | class Logger: 5 | def log_image(self, name, data, step): 6 | raise NotImplementedError("Abstract") 7 | 8 | def log_image_batch(self, name, data, step): 9 | raise NotImplementedError("Abstract") 10 | 11 | def log_number(self, name, data, step): 12 | raise NotImplementedError("Abstract") 13 | 14 | def log_text(self, name, data, step): 15 | raise NotImplementedError("Abstract") 16 | 17 | def log_figure(self, name, data, step): 18 | raise NotImplementedError("Abstract") 19 | 20 | def log_embedding(self, name, data, step): 21 | raise NotImplementedError("Abstract") 22 | 23 | def log(self, name, data, step): 24 | if isinstance(data, (float, int)): 25 | self.log_number(name, data, step) 26 | elif isinstance(data, str): 27 | self.log_text(name, data, step) 28 | elif isinstance(data, LogType): 29 | data.log(self, name, step) 30 | else: 31 | warnings.warn(f"{name} of type {type(data)} could not be logged.\n" 32 | f"Consider implementing a custom LogType.") 33 | -------------------------------------------------------------------------------- /torchsupport/flex/log/tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | from torchsupport.flex.log.logger import Logger 3 | 4 | class TensorboardLogger(Logger): 5 | def __init__(self, path): 6 | self.writer = SummaryWriter(path) 7 | 8 | def log_image(self, name, data, step): 9 | self.writer.add_image(name, data, step) 10 | 11 | def log_image_batch(self, name, data, step): 12 | self.writer.add_images(name, data, step) 13 | 14 | def log_number(self, name, data, step): 15 | self.writer.add_scalar(name, data, step) 16 | 17 | def log_text(self, name, data, step): 18 | self.writer.add_text(name, data, step) 19 | 20 | def log_figure(self, name, data, step): 21 | self.writer.add_figure(name, data, step) 22 | 23 | def log_embedding(self, name, data, step): 24 | self.writer.add_embedding(data, global_step=step, tag=name)  # add_embedding expects the matrix first 25 | -------------------------------------------------------------------------------- /torchsupport/flex/step/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/flex/step/__init__.py -------------------------------------------------------------------------------- /torchsupport/flex/step/loop.py: -------------------------------------------------------------------------------- 1 | from torchsupport.flex.step.step import Step 2 | from torchsupport.flex.checkpointing.savable import is_savable 3 | 4 | class SequentialStep(Step): 5 | def __init__(self, ctx=None): 6 | super().__init__(None, ctx=ctx) 7 | self.run = [] 8 | 9 | def add(self, **kwargs): 10 | for name, step in kwargs.items(): 11 | setattr(self, name, step) 12 | step.name = name 13 | self.run.append(step) 14 | return self 15 | 16 | def write(self, data, name): 17 | for idx, rval in enumerate(self.run): 18 | if is_savable(rval): 19 | rval.write(data, f"{name}.run_{idx}") 20 | 21 | def read(self, data, name): 22 | for idx, rval in enumerate(self.run): 23 | if is_savable(rval): 24 | rval.read(data, f"{name}.run_{idx}") 25 | 26 | def __lshift__(self, other): 27 | self.add(**{f"step_{len(self.run)}": other}) 28 | return self 29 | 30 | def step(self): 31 | for step in self.run: 32 | step() 33 | 34 | class
Loop(SequentialStep): 35 | def __init__(self, num_steps=1000, ctx=None): 36 | super().__init__(ctx=ctx) 37 | self.num_steps = num_steps 38 | 39 | def step(self): 40 | for idx in range(self.num_steps): 41 | super().step() 42 | 43 | class ConfiguredStep(Loop): 44 | def __init__(self, step, every=1, num_steps=1, ctx=None): 45 | super().__init__(num_steps=num_steps, ctx=ctx) 46 | self.every = every 47 | self.add(inner=step) 48 | 49 | def step(self): 50 | if self.ctx.step_id % self.every == 0: 51 | super().step() 52 | -------------------------------------------------------------------------------- /torchsupport/flex/step/step.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchsupport.flex.checkpointing.savable import Savable, is_savable 3 | 4 | class Step(Savable): 5 | def __init__(self, run, ctx=None): 6 | self.ctx = ctx 7 | self.name = None 8 | self.run = run 9 | 10 | def __setattr__(self, name, value): 11 | if isinstance(value, Step): 12 | value.ctx = self.ctx 13 | super().__setattr__(name, value) 14 | 15 | def extend(self, run): 16 | crun = self.run 17 | self.run = lambda ctx=None: run(crun(ctx=ctx), ctx=ctx) 18 | 19 | def step(self): 20 | return self.run(ctx=self.ctx) 21 | 22 | def write(self, data, name): 23 | if is_savable(self.run): 24 | self.run.write(data, f"{name}.run") 25 | 26 | def read(self, data, name): 27 | if is_savable(self.run): 28 | self.run.read(data, f"{name}.run") 29 | 30 | def __call__(self): 31 | with self.ctx.switch(self): 32 | return self.step() 33 | 34 | class EmptyStep(Step): 35 | @staticmethod 36 | def noop(ctx): 37 | return 38 | 39 | def __init__(self, ctx=None): 40 | super().__init__(run=EmptyStep.noop, ctx=ctx) 41 | 42 | class UpdateStep(Step): 43 | def __init__(self, run, update, ctx=None): 44 | super().__init__(run, ctx=ctx) 45 | self.update = update 46 | 47 | def extend_update(self, **kwargs): 48 | for name, value in kwargs.items(): 49 | setattr(self.update, name, value) 50 | 51 | def write(self, data, name): 52 | super().write(data, name) 53 | if is_savable(self.update): 54 | self.update.write(data, f"{name}.update") 55 | 56 | def read(self, data, name): 57 | super().read(data, name) 58 | if is_savable(self.update): 59 | self.update.read(data, f"{name}.update") 60 | 61 | def step(self): 62 | with self.update as update: 63 | result = self.run(ctx=self.ctx) 64 | update.target = self.ctx.loss 65 | 66 | class EvalStep(Step): 67 | def __init__(self, run, modules=None, 68 | no_grad=True, ctx=None): 69 | super().__init__(run, ctx=ctx) 70 | self.modules = modules 71 | self.no_grad = no_grad 72 | 73 | def step(self): 74 | for net in self.modules: 75 | net.eval() 76 | if self.no_grad: 77 | with torch.no_grad(): 78 | self.run(ctx=self.ctx) 79 | else: 80 | self.run(ctx=self.ctx) 81 | for net in self.modules: 82 | net.train() 83 | -------------------------------------------------------------------------------- /torchsupport/flex/step/training_loop.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from torchsupport.flex.step.step import Step 4 | from torchsupport.flex.step.loop import Loop, ConfiguredStep, SequentialStep 5 | 6 | class LogStep(Step): 7 | @staticmethod 8 | def run_log(ctx=None): 9 | ctx.logger.log() 10 | 11 | def __init__(self, ctx=None): 12 | super().__init__(LogStep.run_log, ctx=ctx) 13 | 14 | class CheckpointStep(Step): 15 | @staticmethod 16 | def run_checkpoint(ctx=None): 17 | ctx.checkpoint.save() 18 | 19 | class 
TrainingLoop(Loop): 20 | def __init__(self, num_steps=1000, ctx=None): 21 | super().__init__(num_steps=num_steps, ctx=ctx) 22 | self.setup = SequentialStep() 23 | self.teardown = SequentialStep() 24 | 25 | def log(self): 26 | for name, value in self.ctx.log_store.items(): 27 | self.ctx.logger.log(name, value, self.ctx.step_id) 28 | self.ctx.log_store = {} 29 | 30 | def checkpoint(self): 31 | if self.ctx.step_id % self.ctx.checkpoint_interval == 0: 32 | self.ctx.checkpoint.checkpoint() 33 | 34 | def save(self): 35 | if self.ctx.save_time is None: 36 | self.ctx.save_time = time.monotonic() 37 | time_since_last_save = time.monotonic() - self.ctx.save_time 38 | if time_since_last_save > self.ctx.save_interval: 39 | self.ctx.checkpoint.save() 40 | self.ctx.save_time = time.monotonic() 41 | 42 | def add(self, every=1, num_steps=1, **kwargs): 43 | for name, step in kwargs.items(): 44 | step.ctx = self.ctx 45 | if not isinstance(step, ConfiguredStep): 46 | step = ConfiguredStep( 47 | step, every=every, 48 | num_steps=num_steps, 49 | ctx=self.ctx 50 | ) 51 | step.name = name 52 | setattr(self, name, step) 53 | self.run.append(step) 54 | return self 55 | 56 | def step(self): 57 | self.setup() 58 | for idx in range(self.ctx.step_id, self.num_steps): 59 | self.ctx.step_id = idx 60 | for step in self.run: 61 | step() 62 | self.log() 63 | self.checkpoint() 64 | self.save() 65 | self.teardown() 66 | -------------------------------------------------------------------------------- /torchsupport/flex/tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/flex/tasks/__init__.py -------------------------------------------------------------------------------- /torchsupport/flex/tasks/energy/__init__.py: -------------------------------------------------------------------------------- 1 | from torchsupport.flex.tasks.energy.density_ratio import * 2 | from torchsupport.flex.tasks.energy.score_matching import * 3 | -------------------------------------------------------------------------------- /torchsupport/flex/tasks/energy/diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Normal 3 | 4 | class Diffusion: 5 | def mixing(self, data, noise, level): 6 | return data 7 | 8 | def conditional(self, data, noise, level): 9 | return 0.0 10 | 11 | class VEDiffusion(Diffusion): 12 | def mixing(self, data, noise, level): 13 | expand = [1] * (data.dim() - level.dim()) 14 | level = level.view(*level.shape, *expand) 15 | return data + level * noise 16 | 17 | def conditional(self, data, condition, level): 18 | numerator = -(data - condition) ** 2 19 | denominator = 2 * level ** 2 20 | result = numerator / denominator 21 | return result.view(result.size(0), -1).sum(dim=1, keepdim=True) 22 | 23 | class VPDiffusion(Diffusion): 24 | def mixing(self, data, noise, level): 25 | expand = [1] * (data.dim() - level.dim()) 26 | level = level.view(*level.shape, *expand) 27 | return (1 - level).sqrt() * data + level.sqrt() * noise 28 | 29 | def conditional(self, data, condition, level): 30 | numerator = -((1 - level).sqrt() * data - condition) ** 2 31 | denominator = 2 * level 32 | result = numerator / denominator 33 | return result.view(result.size(0), -1).sum(dim=1, keepdim=True) 34 | 35 | class DiscreteDiffusion(Diffusion): 36 | def mixing(self, data, noise, level): 37 | mask = 
(torch.rand_like(data) < level).float() 38 | return mask * noise + (1 - mask) * data 39 | 40 | def conditional(self, data, condition, level): 41 | dist = (1 - level) * condition + level * torch.ones_like(condition) 42 | result = (dist * data).sum(dim=1).log() 43 | return result.view(result.size(0), -1).sum(dim=1, keepdim=True) 44 | 45 | class RandomReplacementDiffusion(DiscreteDiffusion): 46 | def __init__(self, base, sigma=1e-3): 47 | self.base = base 48 | self.sigma = sigma 49 | 50 | def conditional(self, data, condition, level): 51 | base = self.base.log_prob(data).exp() 52 | delta = Normal(condition, self.sigma) 53 | delta = delta.log_prob(data) 54 | delta = delta.view(delta.size(0), -1).sum(dim=1, keepdim=True).exp() 55 | prob = (1 - level) * delta + level * base 56 | return prob.log() 57 | 58 | class ComposedDiffusion(Diffusion): 59 | def __init__(self, *components): 60 | self.components = components 61 | 62 | def mixing(self, data, noise, level): 63 | result = [] 64 | for d, n, l, c in zip(data, noise, level, self.components): 65 | result.append(c.mixing(d, n, l))  # delegate to each component diffusion 66 | return result 67 | 68 | def conditional(self, data, condition, level): 69 | result = 0.0 70 | for d, n, l, c in zip(data, condition, level, self.components): 71 | result = result + c.conditional(d, n, l) 72 | return result 73 | -------------------------------------------------------------------------------- /torchsupport/flex/tasks/energy/diffusion_recovery_likelihood.py: -------------------------------------------------------------------------------- 1 | import random 2 | from functools import partial 3 | from torchsupport.flex.log.log_types import LogImage 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as func 8 | 9 | from torchsupport.data.namedtuple import namespace 10 | 11 | def run_diffusion_recovery_likelihood(energy, base, data, args, 12 | integrator=None, mixing=None, 13 | conditional=None): 14 | real_data, condition, levels, _ = mixing(data, base) 15 | conditional_energy = conditional(energy, condition) 16 | fake_data = integrator.integrate(conditional_energy, condition, args) 17 | real, fake = energy(real_data, levels, args), energy(fake_data, levels, args) 18 | loss = real.mean() - fake.mean() 19 | return loss, namespace( 20 | real_data=real_data, fake_data=fake_data, condition=condition, 21 | real=real, fake=fake 22 | ) 23 | 24 | def diffusion_recovery_step(energy, base, data, integrator=None, 25 | mixing=None, conditional=None, ctx=None): 26 | data, condition = data.sample(ctx.batch_size) 27 | loss, args = run_diffusion_recovery_likelihood( 28 | energy, base, data, condition, 29 | integrator=integrator, mixing=mixing, 30 | conditional=conditional 31 | ) 32 | ctx.argmin(density_ratio_loss=loss) 33 | return args 34 | -------------------------------------------------------------------------------- /torchsupport/flex/tasks/gan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/flex/tasks/gan/__init__.py -------------------------------------------------------------------------------- /torchsupport/flex/tasks/gan/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | class CriticLoss: 6 | def __init__(self, ctx=None): 7 | self.ctx = ctx 8 | 9 | def critic(self, real, fake): 10 | return 0.0 11 | 12 | def generator(self, real): 13 | return 0.0 14 | 15 | @staticmethod 16 | def relativistic(real, fake):  # relativistic averaging uses each side's unmodified mean 17 | real_mean = real.mean(dim=0, keepdim=True) 18 | fake_mean = fake.mean(dim=0, keepdim=True) 19 | return real - fake_mean, fake - real_mean 20 | 21 | class non_saturating(CriticLoss): 22 | def __init__(self, ctx=None, smoothing=0.0): 23 | super().__init__(ctx=ctx) 24 | self.smoothing = smoothing 25 | 26 | def critic(self, real, fake): 27 | real = func.binary_cross_entropy_with_logits( 28 | real, torch.zeros_like(real) + self.smoothing 29 | ).mean() 30 | fake = func.binary_cross_entropy_with_logits( 31 | fake, torch.ones_like(fake) 32 | ).mean() 33 | return real + fake 34 | 35 | def generator(self, fake): 36 | return func.binary_cross_entropy_with_logits( 37 | fake, torch.zeros_like(fake) 38 | ) 39 | 40 | class least_squares(CriticLoss): 41 | def __init__(self, real=1.0, fake=0.0, ctx=None): 42 | super().__init__(ctx=ctx) 43 | self.real = real 44 | self.fake = fake 45 | 46 | def critic(self, real, fake): 47 | real = ((real - self.real) ** 2).mean() 48 | fake = ((fake - self.fake) ** 2).mean() 49 | return real + fake 50 | 51 | def generator(self, fake): 52 | return ((fake - self.real) ** 2).mean() 53 | 54 | class energy_based(CriticLoss): 55 | def critic(self, real, fake): 56 | return real.mean() - fake.mean() 57 | 58 | def generator(self, fake): 59 | return fake.mean() 60 | -------------------------------------------------------------------------------- /torchsupport/flex/tasks/gan/tasks.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from torchsupport.flex.tasks.utils import parallel_steps 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as func 7 | 8 | from torchsupport.data.namedtuple import namespace 9 | from torchsupport.flex.tasks.gan.losses import non_saturating 10 | 11 | def run_discriminator(discriminator, real_data, fake_data, 12 | gan_loss=non_saturating, 13 | gan_loss_kwargs=None, 14 | ctx=None): 15 | real, fake = discriminator(real_data), discriminator(fake_data) 16 | loss = gan_loss(ctx=ctx, **(gan_loss_kwargs or {})).critic(real, fake) 17 | return loss, namespace( 18 | real_data=real_data, fake_data=fake_data, 19 | real=real, fake=fake, ctx=ctx 20 | ) 21 | 22 | def discriminator_step(generator, discriminator, data, 23 | gan_loss=non_saturating, 24 | gan_loss_kwargs=None, 25 | ctx=None): 26 | real_data = data.sample(ctx.batch_size) 27 | fake_data = generator.sample(ctx.batch_size) 28 | loss, args = run_discriminator( 29 | discriminator, real_data, fake_data, 30 | gan_loss=gan_loss, gan_loss_kwargs=gan_loss_kwargs, 31 | ctx=ctx 32 | ) 33 | ctx.argmin(discriminator_loss=loss) 34 | return args 35 | 36 | def run_generator(generator, discriminator, 37 | gan_loss=non_saturating, 38 | gan_loss_kwargs=None, 39 | ctx=None): 40 | fake_data = generator.sample(ctx.batch_size) 41 | fake = discriminator(fake_data) 42 | loss = gan_loss(ctx=ctx, **(gan_loss_kwargs or {})).generator(fake) 43 | return loss, namespace( 44 | fake_data=fake_data, fake=fake, ctx=ctx 45 | ) 46 | 47 | def generator_step(generator, discriminator, 48 | gan_loss=non_saturating, 49 | gan_loss_kwargs=None, 50 | ctx=None): 51 | loss, args = run_generator( 52 | generator, discriminator, 53 | gan_loss=gan_loss, gan_loss_kwargs=gan_loss_kwargs, 54 | ctx=ctx 55 | ) 56 | ctx.argmin(generator_loss=loss) 57 | return args 58 | -------------------------------------------------------------------------------- /torchsupport/flex/tasks/likelihood/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/flex/tasks/likelihood/__init__.py -------------------------------------------------------------------------------- /torchsupport/flex/tasks/likelihood/maximum_likelihood.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, NamedTuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torchsupport.data.match import match 7 | from torchsupport.data.namedtuple import namespace 8 | from torchsupport.flex.context.context_module import ContextModule 9 | 10 | class SupervisedLikelihood(nn.Module): 11 | def __init__(self, predictor): 12 | super().__init__() 13 | self.predictor = predictor 14 | 15 | def forward(self, data, condition): 16 | distribution = self.predictor(condition) 17 | return distribution.log_prob(data), namespace(distribution=distribution) 18 | 19 | def log_likelihood(model, data, condition): 20 | log_p, args = model(data, condition) 21 | return log_p, args 22 | 23 | def maximum_likelihood_step(model, data, ctx=None): 24 | data, condition = data.sample(ctx.batch_size) 25 | log_p, args = model(data, condition) 26 | ctx.argmax(log_likelihood=log_p) 27 | return namespace( 28 | data=data, condition=condition, **args.asdict() 29 | ) 30 | 31 | def supervised_loss(prediction, ground_truth, losses): 32 | loss_value = 0.0 33 | loss_values = [] 34 | if not isinstance(prediction, (list, tuple)): 35 | prediction, ground_truth = [prediction], [ground_truth] 36 | for pred, gt, loss in zip(prediction, ground_truth, losses): 37 | lval = loss(pred, gt).mean() 38 | loss_values.append(lval) 39 | loss_value += lval 40 | return loss_value, loss_values 41 | 42 | class SupervisedArgs(NamedTuple): 43 | prediction: Union[torch.Tensor, List[torch.Tensor]] 44 | ground_truth: Union[torch.Tensor, List[torch.Tensor]] 45 | sample: Union[torch.Tensor, Any] 46 | losses: List[torch.Tensor] 47 | 48 | def run_supervised(model, sample, ground_truth, losses): 49 | prediction = model(sample) 50 | loss, losses = supervised_loss(prediction, ground_truth, losses) 51 | return loss, SupervisedArgs( 52 | prediction=prediction, ground_truth=ground_truth, 53 | sample=sample, losses=losses 54 | ) 55 | 56 | def supervised_step(model, data, losses=None, ctx=None): 57 | sample, ground_truth = data.sample(ctx.batch_size) 58 | loss, args = run_supervised(model, sample, ground_truth, losses) 59 | ctx.argmin(total_loss=loss) 60 | for idx, lval in enumerate(args.losses): 61 | ctx.log(**{f"loss_{idx}": float(lval)}) 62 | return args 63 | -------------------------------------------------------------------------------- /torchsupport/flex/tasks/regularization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/flex/tasks/regularization/__init__.py -------------------------------------------------------------------------------- /torchsupport/flex/tasks/task.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Task: 5 | def parameters(self): 6 | return [] 7 | 8 | def run(self, ctx=None): 9 | raise NotImplementedError("Abstract.") 10 | 11 | def __call__(self, ctx=None): 12 | return self.run(ctx=ctx) 13 | 
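A minimal sketch of how the pieces in maximum_likelihood.py compose, assuming a toy predictor that returns a torch distribution; the ToyPredictor class and all shapes here are illustrative, not part of the library:

import torch
import torch.nn as nn
from torchsupport.flex.tasks.likelihood.maximum_likelihood import (
  SupervisedLikelihood, run_supervised
)

# a toy predictor returning a distribution over targets given a condition:
class ToyPredictor(nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(8, 10)
  def forward(self, condition):
    return torch.distributions.Categorical(logits=self.linear(condition))

model = SupervisedLikelihood(ToyPredictor())
condition = torch.randn(16, 8)
labels = torch.randint(0, 10, (16,))
log_p, extra = model(labels, condition)  # per-sample log-likelihoods plus the predicted distribution

# the plain supervised path works on raw predictions and loss modules instead:
net = nn.Linear(8, 10)
loss, args = run_supervised(net, condition, labels, [nn.CrossEntropyLoss()])

In a training context, maximum_likelihood_step and supervised_step wrap these two paths respectively: they draw (data, condition) batches from a data distribution, then hand the resulting objective to ctx.argmax or ctx.argmin.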
-------------------------------------------------------------------------------- /torchsupport/flex/tasks/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from torchsupport.flex.context.context import RedirectContext 4 | 5 | def redirect(step, *args, **kwargs): 6 | ctx = RedirectContext() 7 | # TODO 8 | 9 | def namespace_select(f, names, knames, namespace):  # namespace is bound last, when the partial from select is applied 10 | return f( 11 | *[namespace[name] for name in names], 12 | **{name: namespace[knames[name]] for name in knames} 13 | ) 14 | 15 | def select(f, *names, **knames): 16 | return functools.partial(namespace_select, f, names, knames) 17 | 18 | def aux_chain(functions, *args, **kwargs): 19 | for f in functions: 20 | f(*args, **kwargs) 21 | 22 | def chain(*functions): 23 | return functools.partial(aux_chain, functions) 24 | 25 | def aux_with_name(ctx, name, function, *args, **kwargs): 26 | with ctx.switch(name): 27 | return function(*args, **kwargs) 28 | 29 | def with_name(ctx, name, function): 30 | return functools.partial(aux_with_name, ctx, name, function) 31 | 32 | def aux_compose(functions, **kwargs): 33 | tmp = [] 34 | for f in functions: 35 | tmp = f(*tmp, **kwargs) 36 | 37 | def compose(*functions): 38 | return functools.partial(aux_compose, functions) 39 | 40 | def parallel_steps(ctx=None, **kwargs): 41 | to_chain = [] 42 | for name, step in kwargs.items(): 43 | to_chain.append(with_name(ctx, name, step)) 44 | return chain(*to_chain) 45 | 46 | def composed_steps(ctx=None, **kwargs): 47 | to_chain = [] 48 | for name, step in kwargs.items(): 49 | to_chain.append(with_name(ctx, name, step)) 50 | return compose(*to_chain) 51 | 52 | def noop(*args, **kwargs): 53 | pass 54 | -------------------------------------------------------------------------------- /torchsupport/flex/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import partial 3 | from torchsupport.flex.update.update import Update 4 | from torchsupport.flex.step.step import Step, UpdateStep 5 | from torchsupport.flex.data_distributions.data_distribution import DataDistribution 6 | from torchsupport.flex.context.context import TrainingContext 7 | from torchsupport.data.io import to_device 8 | 9 | def canned_supervised(ctx, net, data, losses): 10 | x, y = data.sample() 11 | predictions = net(x) 12 | for idx, (p_val, y_val, loss) in enumerate(zip( 13 | predictions, y, losses 14 | )): 15 | loss_val = loss(p_val, y_val) 16 | ctx.argmin(**{f"loss {idx}": loss_val}) 17 | 18 | def SupervisedTraining(net, data, valid_data, 19 | losses=None, **kwargs): 20 | ctx = TrainingContext(kwargs["network_name"], **kwargs) 21 | net = to_device(net, ctx.device) 22 | data = DataDistribution(data, batch_size=ctx.batch_size, num_workers=ctx.num_workers) 23 | valid_data = DataDistribution(valid_data, batch_size=ctx.batch_size, num_workers=ctx.num_workers) 24 | ctx.checkpoint.add_checkpoint(net=net) 25 | ctx.loop \ 26 | .add(train=UpdateStep( 27 | partial(canned_supervised, net=net.train(), data=data, losses=losses),  # pass a callable, not its result 28 | Update([net], optimizer=torch.optim.Adam) 29 | )) \ 30 | .add(valid=Step( 31 | partial(canned_supervised, net=net.eval(), data=valid_data, losses=losses) 32 | )) 33 | return ctx 34 | -------------------------------------------------------------------------------- /torchsupport/flex/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/flex/training/__init__.py
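A short usage sketch of the select and chain combinators from tasks/utils.py, relying on the argument-order fix to namespace_select above; the report function and the namespace keys are illustrative assumptions, not library API:

from torchsupport.flex.tasks.utils import select, chain

# select adapts a function of named arguments to a single namespace dict;
# positional names are looked up directly, keyword names through an alias map:
def report(loss, step=None):
  print(f"step {step}: loss {loss:.3f}")

adapted = select(report, "loss", step="step_id")
adapted({"loss": 0.25, "step_id": 3})  # -> report(0.25, step=3)

# chain calls several side-effecting steps with the same arguments:
both = chain(adapted, adapted)
both({"loss": 0.25, "step_id": 3})

The same pattern underlies parallel_steps, which additionally wraps each step in ctx.switch(name) so that losses and logs land under per-step names.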
-------------------------------------------------------------------------------- /torchsupport/flex/training/conditional_maximum_likelihood.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from torchsupport.flex.step.loop import ConfiguredStep 3 | 4 | import torch 5 | from torchsupport.flex.tasks.likelihood.maximum_likelihood import maximum_likelihood_step 6 | 7 | from torchsupport.data.io import to_device 8 | from torchsupport.flex.step.step import EvalStep, UpdateStep 9 | from torchsupport.flex.update.update import Update 10 | from torchsupport.flex.context.context import TrainingContext 11 | from torchsupport.flex.utils import filter_kwargs 12 | 13 | def conditional_mle_training(model, data, valid_data=None, 14 | optimizer=torch.optim.Adam, 15 | optimizer_kwargs=None, 16 | eval_no_grad=True, 17 | **kwargs): 18 | opt = filter_kwargs(kwargs, ctx=TrainingContext) 19 | ctx = TrainingContext(**opt.ctx) 20 | ctx.optimizer = optimizer 21 | 22 | # networks to device 23 | ctx.register( 24 | data=to_device(data, ctx.device), 25 | model=to_device(model, ctx.device) 26 | ) 27 | 28 | ctx.add(train_step=UpdateStep( 29 | partial(maximum_likelihood_step, ctx.model, ctx.data), 30 | Update([ctx.model], optimizer=ctx.optimizer, **(optimizer_kwargs or {})), 31 | ctx=ctx 32 | )) 33 | 34 | if valid_data is not None: 35 | ctx.register(valid_data=to_device(valid_data, ctx.device)) 36 | ctx.add(valid_step=EvalStep( 37 | partial(maximum_likelihood_step, ctx.model, ctx.valid_data), 38 | modules=[ctx.model], no_grad=eval_no_grad, ctx=ctx 39 | ), every=ctx.report_interval) 40 | return ctx 41 | -------------------------------------------------------------------------------- /torchsupport/flex/training/density_ratio.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torchsupport.flex.tasks.likelihood.maximum_likelihood import maximum_likelihood_step 5 | from torchsupport.flex.tasks.energy.density_ratio import tdre_step, density_ratio_step 6 | 7 | from torchsupport.data.io import to_device 8 | from torchsupport.flex.step.step import UpdateStep 9 | from torchsupport.flex.update.update import Update 10 | from torchsupport.flex.context.context import TrainingContext 11 | from torchsupport.flex.utils import filter_kwargs 12 | 13 | def base_dre_training(energy, base, data, train_base=True, 14 | base_step=maximum_likelihood_step, 15 | optimizer=torch.optim.Adam, 16 | base_optimizer_kwargs=None, 17 | **kwargs): 18 | opt = filter_kwargs(kwargs, ctx=TrainingContext) 19 | ctx = TrainingContext(**opt.ctx) 20 | ctx.optimizer = optimizer 21 | 22 | # networks to device 23 | ctx.register( 24 | data=to_device(data, ctx.device), 25 | base=to_device(base, ctx.device), 26 | energy=to_device(energy, ctx.device) 27 | ) 28 | 29 | if train_base: 30 | ctx.add(base_step=UpdateStep( 31 | partial(base_step, ctx.base, ctx.data), 32 | Update([ctx.base], optimizer=ctx.optimizer, **(base_optimizer_kwargs or {})), 33 | ctx=ctx 34 | )) 35 | return ctx 36 | 37 | def telescoping_density_ratio_training(energy, base, data, mixing=None, 38 | optimizer_kwargs=None, 39 | telescoping_step=tdre_step, 40 | verbose=True, 41 | **kwargs): 42 | opt = filter_kwargs(kwargs, ctx=base_dre_training) 43 | ctx = base_dre_training(energy, base, data, **opt.ctx) 44 | 45 | ctx.add(tdre_step=UpdateStep( 46 | partial(telescoping_step, ctx.energy, ctx.base, ctx.data, mixing=mixing, verbose=verbose), 47 | 
Update([ctx.energy], optimizer=ctx.optimizer, **(optimizer_kwargs or {})), 48 | ctx=ctx 49 | )) 50 | 51 | return ctx 52 | 53 | def density_ratio_training(energy, base, data, optimizer_kwargs=None, 54 | density_ratio_step=density_ratio_step, 55 | **kwargs): 56 | opt = filter_kwargs(kwargs, ctx=base_dre_training) 57 | ctx = base_dre_training(energy, base, data, **opt.ctx) 58 | 59 | ctx.add(dre_step=UpdateStep( 60 | partial(density_ratio_step, ctx.energy, ctx.base, ctx.data), 61 | Update([ctx.energy], optimizer=ctx.optimizer, **(optimizer_kwargs or {})), 62 | ctx=ctx 63 | )) 64 | 65 | return ctx 66 | -------------------------------------------------------------------------------- /torchsupport/flex/training/supervised.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from torchsupport.flex.step.loop import ConfiguredStep 3 | 4 | import torch 5 | from torchsupport.flex.tasks.likelihood.maximum_likelihood import supervised_step 6 | 7 | from torchsupport.data.io import to_device 8 | from torchsupport.flex.step.step import EvalStep, UpdateStep 9 | from torchsupport.flex.update.update import Update 10 | from torchsupport.flex.context.context import TrainingContext 11 | from torchsupport.flex.utils import filter_kwargs 12 | 13 | def supervised_training(net, data, valid_data=None, losses=None, 14 | optimizer=torch.optim.Adam, 15 | optimizer_kwargs=None, 16 | eval_no_grad=True, 17 | **kwargs): 18 | opt = filter_kwargs(kwargs, ctx=TrainingContext) 19 | ctx = TrainingContext(**opt.ctx) 20 | ctx.optimizer = optimizer 21 | ctx.losses = losses 22 | 23 | # networks to device 24 | ctx.register( 25 | data=to_device(data, ctx.device), 26 | net=to_device(net, ctx.device) 27 | ) 28 | 29 | ctx.add(train_step=UpdateStep( 30 | partial(supervised_step, ctx.net, ctx.data, losses=ctx.losses), 31 | Update([ctx.net], optimizer=ctx.optimizer, **(optimizer_kwargs or {})), 32 | ctx=ctx 33 | )) 34 | 35 | if valid_data is not None: 36 | ctx.register(valid_data=to_device(valid_data, ctx.device)) 37 | ctx.add(valid_step=EvalStep( 38 | partial(supervised_step, ctx.net, ctx.valid_data, losses=ctx.losses), 39 | modules=[ctx.net], no_grad=eval_no_grad, ctx=ctx 40 | ), every=ctx.report_interval) 41 | return ctx 42 | -------------------------------------------------------------------------------- /torchsupport/flex/update/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/flex/update/__init__.py -------------------------------------------------------------------------------- /torchsupport/flex/update/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Update: 5 | def __init__(self, parameter_sources, optimizer=None, 6 | gradient_action=None, **kwargs): 7 | optimizer = optimizer or torch.optim.Adam 8 | self.gradient_action = gradient_action 9 | self.update_actions = [] 10 | self.parameters = [] 11 | for parameter_source in parameter_sources: 12 | if hasattr(parameter_source, "update_action"): 13 | self.update_actions.append(parameter_source) 14 | elif isinstance(parameter_source, nn.Module): 15 | self.parameters += parameter_source.parameters() 16 | else: 17 | self.parameters.append(parameter_source) 18 | self.optimizer = optimizer(self.parameters, **kwargs) 19 | self.target = None 20 | 21 | def 
process_gradients(self): 22 | if self.gradient_action is not None: 23 | self.gradient_action(self.parameters) 24 | 25 | def __enter__(self, *args, **kwargs): 26 | self.optimizer.zero_grad() 27 | return self 28 | 29 | def __exit__(self, *args, **kwargs): 30 | loss = self.target 31 | loss.backward() 32 | self.process_gradients() 33 | self.target = None 34 | self.optimizer.step() 35 | for update_action in self.update_actions: 36 | update_action.update_action() 37 | -------------------------------------------------------------------------------- /torchsupport/flex/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | from torchsupport.data.namedtuple import namespace 4 | 5 | def get_kwargs(f): 6 | result = [] 7 | has_kwargs = False 8 | sig = inspect.signature(f) 9 | for key, val in sig.parameters.items(): 10 | if val.default != inspect.Parameter.empty: 11 | result.append(key) 12 | if val.kind == inspect.Parameter.VAR_KEYWORD: 13 | has_kwargs = True 14 | return result, has_kwargs 15 | 16 | def filter_kwargs(kwargs, **targets): 17 | result = {} 18 | for name, target in targets.items(): 19 | result[name] = {} 20 | target_kwargs, has_kwargs = get_kwargs(target) 21 | if has_kwargs: 22 | result[name] = kwargs 23 | else: 24 | for key in target_kwargs: 25 | if key in kwargs: 26 | result[name][key] = kwargs[key] 27 | return namespace(**result) 28 | -------------------------------------------------------------------------------- /torchsupport/interacting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/interacting/__init__.py -------------------------------------------------------------------------------- /torchsupport/interacting/awac.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as func 6 | 7 | from torchsupport.interacting.off_policy_training import OffPolicyTraining 8 | 9 | class AWACTraining(OffPolicyTraining): 10 | def __init__(self, policy, value, agent, environment, 11 | beta=1.0, clip=None, tau=5e-3, **kwargs): 12 | self.value = ... 
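# NOTE (assumption): the ellipsis above is a deliberate placeholder; the actual
# value network appears to be registered by OffPolicyTraining from the
# {"value": value} auxiliary dict passed to super().__init__ below.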
13 | super().__init__( 14 | policy, agent, environment, 15 | {"value": value}, **kwargs 16 | ) 17 | self.beta = beta 18 | self.clip = clip 19 | self.tau = tau 20 | self.target = deepcopy(value) 21 | 22 | def update_target(self): 23 | with torch.no_grad(): 24 | tp = self.target.parameters() 25 | ap = self.value.parameters() 26 | for t, a in zip(tp, ap): 27 | t *= (1 - self.tau) 28 | t += self.tau * a 29 | 30 | def action_nll(self, policy, action): 31 | return func.cross_entropy(policy, action, reduction='none') 32 | 33 | def policy_loss(self, policy, action, advantage): 34 | weight = torch.exp(advantage / self.beta) 35 | if self.clip is not None: 36 | weight = weight.clamp(0, self.clip) 37 | negative_log_likelihood = self.action_nll(policy, action) 38 | weighted_loss = negative_log_likelihood * weight 39 | return weighted_loss.mean() 40 | 41 | def state_value(self, state, value=None): 42 | value = value or self.value 43 | action_value = value(state) 44 | policy = self.policy(state) 45 | expected = action_value * policy.softmax(dim=1) 46 | expected = expected.sum(dim=1) 47 | return expected 48 | 49 | def run_policy(self, sample): 50 | initial_state = sample.initial_state 51 | action = sample.action 52 | 53 | with torch.no_grad(): 54 | action_value = self.value(initial_state) 55 | inds = torch.arange(action.size(0), device=action.device) 56 | action_value = action_value[inds, action] 57 | value = self.state_value(initial_state) 58 | advantage = action_value - value 59 | 60 | self.current_losses["mean advantage"] = float(advantage.mean()) 61 | 62 | policy = self.policy(initial_state) 63 | 64 | return policy, action, advantage 65 | 66 | def auxiliary_loss(self, value, target): 67 | return func.mse_loss(value.view(-1), target.view(-1)) 68 | 69 | def run_auxiliary(self, sample): 70 | self.update_target() 71 | 72 | initial_state = sample.initial_state 73 | final_state = sample.final_state 74 | action = sample.action 75 | rewards = sample.rewards 76 | action_value = self.value(initial_state) 77 | inds = torch.arange(action.size(0), device=action.device) 78 | action_value = action_value[inds, action] 79 | 80 | with torch.no_grad(): 81 | state_value = self.state_value( 82 | final_state, value=self.target 83 | ) 84 | done_mask = 1.0 - sample.done.float() 85 | target = rewards + self.discount * done_mask * state_value 86 | 87 | self.current_losses["mean state value"] = float(state_value.mean()) 88 | self.current_losses["mean target value"] = float(target.mean()) 89 | 90 | return action_value, target 91 | -------------------------------------------------------------------------------- /torchsupport/interacting/awr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | from torchsupport.interacting.off_policy_training import OffPolicyTraining 6 | 7 | class AWRTraining(OffPolicyTraining): 8 | def __init__(self, policy, value, agent, environment, 9 | beta=1.0, clip=None, **kwargs): 10 | self.value = ... 
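# NOTE: same placeholder pattern as in AWACTraining above; self.value is
# presumably set by OffPolicyTraining from the {"value": value} dict below.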
11 | super().__init__( 12 | policy, agent, environment, 13 | {"value": value}, **kwargs 14 | ) 15 | self.beta = beta 16 | self.clip = clip 17 | 18 | def action_nll(self, policy, action): 19 | return func.cross_entropy(policy, action, reduction='none') 20 | 21 | def policy_loss(self, policy, action, advantage): 22 | weight = torch.exp(advantage / self.beta) 23 | if self.clip is not None: 24 | weight = weight.clamp(0, self.clip) 25 | negative_log_likelihood = self.action_nll(policy, action) 26 | weighted_loss = negative_log_likelihood * weight 27 | return weighted_loss.mean() 28 | 29 | def run_policy(self, sample): 30 | initial_state = sample.initial_state 31 | action = sample.action 32 | returns = sample.returns 33 | 34 | with torch.no_grad(): 35 | value = self.value(initial_state) 36 | advantage = returns - value 37 | 38 | policy = self.policy(initial_state) 39 | 40 | return policy, action, advantage 41 | 42 | def auxiliary_loss(self, value, returns): 43 | return func.mse_loss(value.view(-1), returns.view(-1)) 44 | 45 | def run_auxiliary(self, sample): 46 | initial_state = sample.initial_state 47 | returns = sample.returns 48 | value = self.value(initial_state) 49 | 50 | return value, returns 51 | -------------------------------------------------------------------------------- /torchsupport/interacting/control.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | import torch.multiprocessing as mp 3 | 4 | class ReadControl: 5 | def __init__(self, ctrl): 6 | self.ctrl = ctrl 7 | 8 | def __enter__(self): 9 | with self.ctrl.read_lock: 10 | self.ctrl.read_count.value += 1 11 | if self.ctrl.read_count.value == 1: 12 | self.ctrl.write_lock.acquire() 13 | self.ctrl.owner.pull_changes() 14 | 15 | def __exit__(self, *args): 16 | with self.ctrl.read_lock: 17 | self.ctrl.read_count.value -= 1 18 | if self.ctrl.read_count.value == 0: 19 | self.ctrl.write_lock.release() 20 | 21 | class WriteControl: 22 | def __init__(self, ctrl): 23 | self.ctrl = ctrl 24 | 25 | def __enter__(self): 26 | self.ctrl.write_lock.acquire() 27 | self.ctrl.owner.pull_changes() 28 | 29 | def __exit__(self, *args): 30 | self.ctrl.owner.push_changes() 31 | self.ctrl.write_lock.release() 32 | 33 | class ReadWriteControl: 34 | def __init__(self, owner): 35 | self.owner = owner 36 | self.read_lock = mp.Lock() 37 | self.write_lock = mp.Lock() 38 | self.read_count = mp.Value("l", 0) 39 | self.read_count.value = 0 40 | 41 | self.timestamp = mp.Value("l", 0) 42 | self.local_timestamp = 0 43 | 44 | def clone(self, owner): 45 | result = copy(self) 46 | result.owner = owner 47 | return result 48 | 49 | def change(self, toggle=True): 50 | self.timestamp.value = self.timestamp.value + 1 51 | self.local_timestamp = self.timestamp.value 52 | 53 | def advance(self): 54 | self.local_timestamp = self.timestamp.value 55 | 56 | @property 57 | def changed(self): 58 | timestamp = self.timestamp.value 59 | local_timestamp = self.local_timestamp 60 | return local_timestamp != timestamp 61 | 62 | @property 63 | def read(self): 64 | return ReadControl(self) 65 | 66 | @property 67 | def write(self): 68 | return WriteControl(self) 69 | -------------------------------------------------------------------------------- /torchsupport/interacting/data_collector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing as mp 3 | 4 | def _collector_worker(statistics, buffer, distributor, 5 | collector, done, piecewise): 6 | 
torch.set_num_threads(1)
7 |   while True:
8 |     if done.value:
9 |       break
10 |     result = collector.sample_trajectory()
11 |     trajectory_statistics = collector.compute_statistics(result)
12 |     trajectory = distributor.commit_trajectory(result)
13 | 
14 |     if piecewise:
15 |       for item in trajectory:
16 |         buffer.append(item)
17 |     else:
18 |       buffer.append(trajectory)
19 | 
20 |     statistics.update(trajectory_statistics)
21 | 
22 | class ExperienceCollector:
23 |   def __init__(self, distributor, collector,
24 |                piecewise=True, n_workers=16):
25 |     self.n_workers = n_workers
26 |     self.piecewise = piecewise
27 |     self.distributor = distributor
28 |     self.collector = collector
29 |     self.done = mp.Value("l", 0)
30 |     self.procs = []
31 | 
32 |   def start(self, statistics, buffer):
33 |     for idx in range(self.n_workers):
34 |       proc = mp.Process(
35 |         target=_collector_worker,
36 |         args=(statistics, buffer, self.distributor,
37 |               self.collector, self.done, self.piecewise)
38 |       )
39 |       self.procs.append(proc)
40 |       proc.start()
41 | 
42 |   def schema(self):
43 |     return self.distributor.schema(self.collector.schema())
44 | 
45 |   def join(self):
46 |     self.done.value = 1
47 |     for proc in self.procs:
48 |       proc.join()
49 |     self.procs = []
50 | 
-------------------------------------------------------------------------------- /torchsupport/interacting/distributor_task.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | class AbstractDistributor:
4 |   def commit_trajectory(self, results):
5 |     raise NotImplementedError("Abstract.")
6 | 
7 |   def schema(self, inputs):
8 |     return inputs
9 | 
10 | class DefaultDistributor(AbstractDistributor):
11 |   def commit_trajectory(self, results):
12 |     return results
13 | 
14 | class ChunkedDistributor(DefaultDistributor):
15 |   def __init__(self, chunk_size=10):
16 |     super().__init__()
17 |     self.chunk_size = chunk_size
18 | 
19 |   def stack(self, chunk):
20 |     tmp = {}
21 |     for item in chunk:
22 |       item = item._as_dict()
23 |       for field in item:
24 |         if field not in tmp:
25 |           tmp[field] = []
26 |         tmp[field].append(item[field].unsqueeze(0))
27 |     for field in tmp:
28 |       tmp[field] = torch.cat(tmp[field], dim=0)
29 |     return type(chunk[0])(**tmp)
30 | 
31 |   def schema(self, inputs):
32 |     chunk = [inputs] * self.chunk_size
33 |     result = self.stack(chunk)
34 |     return result
35 | 
36 |   def commit_trajectory(self, results):
37 |     if len(results) < self.chunk_size:
38 |       # pad short trajectories by repeating the last item
39 |       results += [results[-1]] * (self.chunk_size - len(results))
40 | 
41 |     chunked = []
42 |     for idx, _ in enumerate(results[:self.chunk_size + 1]):
43 |       chunk = self.stack(results[idx:idx + self.chunk_size])
44 |       chunked.append(chunk)
45 |     return chunked
46 | 
-------------------------------------------------------------------------------- /torchsupport/interacting/energies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/interacting/energies/__init__.py -------------------------------------------------------------------------------- /torchsupport/interacting/energies/energy.py: --------------------------------------------------------------------------------
1 | from copy import copy, deepcopy
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as func
6 | 
7 | from torchsupport.data.namedtuple import namedtuple
8 | 
9 | class Energy(nn.Module):
10 |   data_type = namedtuple("Data", [
11 |     "batch", "energy", "args"
12 |   ])
13 |   sample_type = namedtuple("SampleData", [
14 |     "data", "args"
15 |   ])
16 |   def __init__(self, energy, keep_rate=0.95, device="cpu"):
17 |     super().__init__()
18 |     self.energy = energy
19 |     self.device = device
20 |     self.keep_rate = keep_rate
21 | 
22 |     self.example_batch, self.example_args = self._example_batch()
23 |     self.example_energy = self._example_energy()
24 | 
25 |   def _example_batch(self):
26 |     return self.prepare(1)
27 | 
28 |   def _example_energy(self):
29 |     with torch.no_grad():
30 |       pass_batch, pass_args = self.pack_batch(
31 |         self.example_batch, self.example_args
32 |       )
33 |       return self.energy(pass_batch, *pass_args)
34 | 
35 |   def move(self):
36 |     result = deepcopy(self)
37 |     result.energy = self.energy.clone_to(self.device)
38 |     return result
39 | 
40 |   def push(self):
41 |     self.energy.push()
42 | 
43 |   def pull(self):
44 |     self.energy.pull()
45 | 
46 |   def prepare(self, batch_size):
47 |     data = torch.rand(batch_size, *self.shape)  # self.shape is expected from subclasses
48 |     return self.sample_type(
49 |       data=data, args=None
50 |     )
51 | 
52 |   def batch_size(self, batch, args):
53 |     return batch.size(0)
54 | 
55 |   def pack_batch(self, batch, args):
56 |     args = args or []
57 |     return batch, args
58 | 
59 |   def unpack_batch(self, batch):
60 |     return batch
61 | 
62 |   def recombine_batch(self, batch, args, new_batch, new_args, drop):
63 |     batch[drop] = new_batch
64 |     if args:
65 |       args[drop] = new_args
66 |     return batch, args
67 | 
68 |   def reset(self, batch, args, energy):
69 |     size = self.batch_size(batch, args)
70 |     drop = self.keep_rate < torch.rand(size)
71 |     drop_count = int(drop.sum())
72 |     new_batch, new_args = self.prepare(drop_count)
73 |     batch, args = self.recombine_batch(
74 |       batch, args, new_batch, new_args, drop
75 |     )
76 |     return batch, args
77 | 
78 |   def schema(self):
79 |     args = self.example_args or [None]
80 |     return self.data_type(
81 |       batch=self.example_batch[0],
82 |       energy=self.example_energy[0],
83 |       args=args[0]
84 |     )
85 | 
86 |   def forward(self, data, *args):
87 |     return self.energy(data, *args)
88 | 
-------------------------------------------------------------------------------- /torchsupport/interacting/environments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/interacting/environments/__init__.py -------------------------------------------------------------------------------- /torchsupport/interacting/environments/cartpole.py: --------------------------------------------------------------------------------
1 | import torch
2 | import gym
3 | 
4 | from torchsupport.interacting.environments.environment import Environment
5 | 
6 | class CartPole(Environment):
7 |   def __init__(self, scale=False):
8 |     self.env = gym.make("CartPole-v1")
9 |     self.initialized = False
10 |     self.state = None
11 |     self.done = False
12 |     self.scale = scale
13 | 
14 |   def reset(self):
15 |     self.state = torch.tensor(self.env.reset(), dtype=torch.float)
16 |     self.initialized = True
17 |     self.done = False
18 | 
19 |   def act(self, action):
20 |     observation, reward, done, _ = self.env.step(int(action))
21 |     self.state = torch.tensor(observation, dtype=torch.float)
22 |     self.done = done
23 |     if self.scale:
24 |       reward = reward / 100
25 |     return torch.tensor([reward])
26 | 
27 |   def observe(self):
28 |     return self.state
29 | 
30 |   def is_done(self):
31 |     return self.done
32 | 
33 |   @property
34 |   def action_space(self):
35 |     return self.env.action_space
36 | 
37 |   @property
38 |   def 
observation_space(self): 39 | return self.env.observation_space 40 | 41 | def schema(self): 42 | state = torch.tensor([0.0] * 4) 43 | reward = torch.tensor(0.0) 44 | done = torch.tensor(0) 45 | action = torch.tensor(0) 46 | sample = self.data_type( 47 | state=state, action=action, rewards=reward, done=done 48 | ) 49 | return sample 50 | -------------------------------------------------------------------------------- /torchsupport/interacting/environments/coinrun.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gym 3 | 4 | from torchsupport.interacting.environments.environment import Environment 5 | 6 | class CoinRun(Environment): 7 | def __init__(self, history=1): 8 | self.env = gym.make("procgen:procgen-coinrun-v0", use_sequential_levels=True) 9 | self.initialized = False 10 | self.state = None 11 | self.done = False 12 | self.history = history 13 | self._history = torch.zeros(self.history, 3, 64, 64) 14 | 15 | def add_history(self, state): 16 | self._history = self._history.roll(1, dims=0) 17 | self._history[0] = state 18 | return self._history.view(self.history * 3, 64, 64) 19 | 20 | def reset(self): 21 | self._history.zero_() 22 | state = torch.tensor(self.env.reset(), dtype=torch.float) 23 | state = state.permute(2, 0, 1).contiguous() / 255 24 | self.state = self.add_history(state) 25 | self.initialized = True 26 | self.done = False 27 | 28 | def act(self, action): 29 | observation, reward, done, _ = self.env.step(int(action)) 30 | state = torch.tensor(observation, dtype=torch.float) / 255 31 | state = state.permute(2, 0, 1).contiguous() 32 | self.state = self.add_history(state) 33 | self.done = done 34 | return torch.tensor([reward]) 35 | 36 | def observe(self): 37 | return self.state 38 | 39 | def is_done(self): 40 | return self.done 41 | 42 | @property 43 | def action_space(self): 44 | return self.env.action_space 45 | 46 | @property 47 | def observation_space(self): 48 | return self.env.observation_space 49 | 50 | def schema(self): 51 | state = torch.zeros(self.history * 3, 64, 64) 52 | reward = torch.tensor(0.0) 53 | done = torch.tensor(0) 54 | action = torch.tensor(0) 55 | sample = self.data_type( 56 | state=state, action=action, rewards=reward, done=done 57 | ) 58 | return sample 59 | -------------------------------------------------------------------------------- /torchsupport/interacting/environments/environment.py: -------------------------------------------------------------------------------- 1 | from torchsupport.data.namedtuple import namedtuple 2 | 3 | class Environment: 4 | data_type = namedtuple("Data", [ 5 | "state", "action", "rewards", "done" 6 | ]) 7 | def reset(self): 8 | raise NotImplementedError 9 | 10 | def push_changes(self): 11 | pass 12 | 13 | def pull_changes(self): 14 | pass 15 | 16 | def action_space(self): 17 | raise NotImplementedError 18 | 19 | def observation_space(self): 20 | raise NotImplementedError 21 | 22 | def is_done(self): 23 | raise NotImplementedError 24 | 25 | def observe(self): 26 | raise NotImplementedError 27 | 28 | def act(self, action): 29 | raise NotImplementedError 30 | 31 | def schema(self): 32 | raise NotImplementedError 33 | -------------------------------------------------------------------------------- /torchsupport/interacting/interacting_training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | from torchsupport.training.training import 
Training 6 | 7 | class InteractingTraining(Training): 8 | def __init__(self, tasks, valid_tasks=None, **kwargs): 9 | super().__init__(**kwargs) 10 | self.tasks = tasks 11 | self.valid_tasks = valid_tasks or [] 12 | for task in self.tasks: 13 | task.register_training(self) 14 | 15 | def step(self): 16 | for task in self.tasks: 17 | task_data = task.sample() 18 | task.step(task_data) 19 | self.each_step() 20 | 21 | def checkpoint(self): 22 | for task in self.tasks: 23 | task.checkpoint() 24 | self.each_checkpoint() 25 | 26 | def validate(self): 27 | for task in self.valid_tasks: 28 | task.valid_step() 29 | 30 | def train(self): 31 | for task in self.tasks: 32 | task.initialize() 33 | for _ in range(self.max_steps): 34 | self.step() 35 | if self.step_id % self.report_interval == 0: 36 | self.validate() 37 | if self.step_id % self.checkpoint_interval == 0: 38 | self.checkpoint() 39 | self.step_id += 1 40 | 41 | task_trainables = { 42 | task.name: task.trainables 43 | for task in self.tasks 44 | } 45 | 46 | return task_trainables 47 | -------------------------------------------------------------------------------- /torchsupport/interacting/off_ebm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchsupport.interacting.off_energy_training import OffEnergyTraining 4 | 5 | class OffEBMTraining(OffEnergyTraining): 6 | def __init__(self, score, energy, data, 7 | off_energy_weight=0.0, 8 | off_energy_decay=0.1, 9 | auxiliary_steps=0, **kwargs): 10 | super().__init__( 11 | score, energy, data, 12 | auxiliary_steps=auxiliary_steps, 13 | **kwargs 14 | ) 15 | self.off_energy_weight = off_energy_weight 16 | self.off_energy_decay = off_energy_decay 17 | 18 | def run_score(self, sample, data): 19 | fake = sample.final_state 20 | fake_args = sample.args or [] 21 | real, *real_args = data 22 | fake_sample_energy = sample.final_energy 23 | fake_energy = self.score(fake, *fake_args) 24 | real_energy = self.score(real, *real_args) 25 | energy_difference = fake_energy - fake_sample_energy 26 | 27 | return real_energy, fake_energy, energy_difference 28 | 29 | def score_loss(self, real_energy, fake_energy, energy_difference): 30 | regularization = self.decay * ((real_energy ** 2).mean() + (fake_energy ** 2).mean()) 31 | 32 | fake_weight = torch.exp( 33 | -self.off_energy_weight * abs(energy_difference) 34 | ).detach() 35 | weight_sum = fake_weight.sum() 36 | weight_mean = weight_sum / fake_weight.size(0) 37 | 38 | real_mean = real_energy.mean() 39 | fake_mean = (fake_energy * fake_weight).sum() / weight_sum 40 | 41 | off_energy_loss = self.off_energy_decay * (energy_difference ** 2).mean() 42 | 43 | ebm = real_mean - fake_mean 44 | self.current_losses["real"] = float(real_mean) 45 | self.current_losses["weight"] = float(weight_mean) 46 | self.current_losses["off energy"] = float(off_energy_loss) 47 | self.current_losses["energy difference"] = float(abs(energy_difference).mean()) 48 | self.current_losses["fake"] = float(fake_mean) 49 | self.current_losses["fake raw"] = float(fake_energy.mean()) 50 | self.current_losses["regularization"] = float(regularization) 51 | self.current_losses["ebm"] = float(ebm) 52 | return regularization + ebm + off_energy_loss 53 | 54 | def run_auxiliary(self, data): 55 | pass 56 | 57 | def auxiliary_loss(self, *args): 58 | return 0.0 59 | -------------------------------------------------------------------------------- /torchsupport/interacting/policies/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/interacting/policies/__init__.py -------------------------------------------------------------------------------- /torchsupport/interacting/policies/basic.py: --------------------------------------------------------------------------------
1 | import random
2 | 
3 | import torch
4 | from torch.distributions import Categorical
5 | 
6 | from torchsupport.data.namedtuple import namedtuple, NamedTuple
7 | 
8 | from torchsupport.interacting.policies.policy import Policy, ModulePolicy
9 | 
10 | class RandomPolicy(Policy):
11 |   data_type = namedtuple("Data", ["logits", "outputs"])
12 |   def __init__(self, actions):
13 |     super().__init__()
14 |     self.logits = torch.ones(actions) / actions
15 | 
16 |   def push(self):
17 |     pass
18 | 
19 |   def pull(self):
20 |     pass
21 | 
22 |   def schema(self):
23 |     return self.data_type(
24 |       logits=self.logits, outputs=None
25 |     )
26 | 
27 |   def forward(self, state, hidden=None):
28 |     action = Categorical(logits=self.logits).sample()
29 |     return action, self.logits, hidden
30 | 
31 | class CategoricalPolicy(ModulePolicy):
32 |   def forward(self, state, hidden=None):
33 |     if isinstance(state, (list, tuple, NamedTuple)):
34 |       state = [
35 |         item.unsqueeze(0)
36 |         for item in state
37 |       ]
38 |     else:
39 |       state = state.unsqueeze(0)
40 |     hidden = hidden.unsqueeze(0) if hidden is not None else None
41 |     logits = self.policy(
42 |       state, hidden=hidden
43 |     )
44 |     outputs = [None]
45 |     if isinstance(logits, tuple):
46 |       logits, outputs = logits
47 |     action = Categorical(logits=logits).sample()
48 | 
49 |     return action[0], logits[0], outputs[0]
50 | 
51 | class EpsilonGreedyPolicy(ModulePolicy):
52 |   def __init__(self, policy, epsilon=0.1):
53 |     super().__init__(policy)
54 |     self.epsilon = epsilon
55 | 
56 |   def forward(self, state, hidden=None):
57 |     explore = random.random() < self.epsilon
58 |     state = state.unsqueeze(0)
59 |     hidden = hidden.unsqueeze(0) if hidden is not None else None
60 |     logits = self.policy(
61 |       state, hidden=hidden
62 |     )
63 |     outputs = [None]
64 |     if isinstance(logits, tuple):
65 |       logits, outputs = logits
66 | 
67 |     action = logits.argmax(dim=1)
68 | 
69 |     if explore:
70 |       logits = torch.ones_like(logits)
71 |       logits = logits / logits.size(1)
72 |       action = Categorical(logits=logits).sample()
73 | 
74 |     return action[0], logits[0], outputs[0]
75 | 
76 | class CategoricalGreedyPolicy(EpsilonGreedyPolicy):
77 |   def forward(self, state, hidden=None):
78 |     explore = random.random() < self.epsilon
79 |     state = state.unsqueeze(0)
80 |     hidden = hidden.unsqueeze(0) if hidden is not None else None
81 |     logits = self.policy(
82 |       state, hidden=hidden
83 |     )
84 |     outputs = [None]
85 |     if isinstance(logits, tuple):
86 |       logits, outputs = logits
87 |     action = logits.argmax(dim=1)
88 |     if explore:
89 |       action = Categorical(logits=logits).sample()
90 | 
91 |     return action[0], logits[0], outputs[0]
92 | 
-------------------------------------------------------------------------------- /torchsupport/interacting/policies/mcts.py: -------------------------------------------------------------------------------- 1 | # TODO -------------------------------------------------------------------------------- /torchsupport/interacting/policies/policy.py: --------------------------------------------------------------------------------
1 | from copy import copy, deepcopy
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional 
as func 6 | 7 | class Policy(nn.Module): 8 | def schema(self): 9 | raise NotImplementedError("Abstract.") 10 | 11 | def move(self): 12 | return self 13 | 14 | def push(self): 15 | raise NotImplementedError("Abstract.") 16 | 17 | def pull(self): 18 | raise NotImplementedError("Abstract.") 19 | 20 | def forward(self, state, inputs=None): 21 | raise NotImplementedError("Abstract.") 22 | 23 | class ModulePolicy(Policy): 24 | def __init__(self, policy, device="cpu"): 25 | super().__init__() 26 | self.policy = policy 27 | self.device = device 28 | 29 | def move(self): 30 | result = deepcopy(self) 31 | result.policy = self.policy.clone_to(self.device) 32 | return result 33 | 34 | def push(self): 35 | self.policy.push() 36 | 37 | def pull(self): 38 | self.policy.pull() 39 | 40 | def schema(self): 41 | return self.policy.schema() 42 | -------------------------------------------------------------------------------- /torchsupport/interacting/shared_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy, copy 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.multiprocessing as mp 7 | 8 | from torchsupport.data.io import to_device 9 | from torchsupport.interacting.control import ReadWriteControl 10 | 11 | class InertModule(nn.Module): 12 | def __init__(self, module): 13 | super().__init__() 14 | self.module = module 15 | 16 | def schema(self): 17 | return self.module.schema() 18 | 19 | def forward(self, *args, **kwargs): 20 | return self.module(*args, **kwargs) 21 | 22 | def _share_skip_inert(module): 23 | if isinstance(module, InertModule): 24 | pass 25 | else: 26 | module.share_memory() 27 | 28 | class SharedModule(nn.Module): 29 | def __init__(self, module, dynamic=False): 30 | super().__init__() 31 | self.ctrl = ReadWriteControl(self) 32 | self.dynamic = dynamic 33 | self.source_process = os.getpid() 34 | self.shared_module = deepcopy(module).cpu().share_memory() 35 | self._module = InertModule(module) 36 | 37 | def __deepcopy__(self, memo): 38 | cls = self.__class__ 39 | result = cls(deepcopy(self._module.module, memo), dynamic=self.dynamic) 40 | result.ctrl = self.ctrl.clone(result) 41 | result.source_process = self.source_process 42 | result.shared_module = self.shared_module 43 | return result 44 | 45 | def clone_to(self, target="cpu"): 46 | result = deepcopy(self) 47 | result._module = result._module.to(target) 48 | 49 | return result 50 | 51 | def is_clone(self): 52 | return os.getpid() != self.source_process 53 | 54 | def share_memory(self): 55 | self.apply(_share_skip_inert) 56 | 57 | def schema(self): 58 | return self._module.schema() 59 | 60 | def pull_changes(self): 61 | if self.ctrl.changed: 62 | shared_state_dict = self.shared_module.state_dict() 63 | self._module.module.load_state_dict(shared_state_dict) 64 | self.ctrl.advance() 65 | 66 | def push_changes(self): 67 | state_dict = self._module.module.state_dict() 68 | state_dict = to_device(state_dict, "cpu") 69 | self.shared_module.load_state_dict(state_dict) 70 | self.ctrl.change() 71 | 72 | def pull(self): 73 | with self.ctrl.read: 74 | pass # NOTE: just pull changes 75 | 76 | def push(self): 77 | with self.ctrl.write: 78 | pass # NOTE: just push changes 79 | 80 | def forward(self, *args, **kwargs): 81 | if self.dynamic: 82 | with self.ctrl.read: 83 | return self._module(*args, **kwargs) 84 | else: 85 | return self._module(*args, **kwargs) 86 | 87 | # class SharedPolicy: 88 | # def __call__(self, state, inputs=None): 89 | 90 | 
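A minimal sketch of the push/pull workflow SharedModule supports (the Linear module and the two-process layout are assumptions for illustration): a trainer process publishes its weights to shared memory, and worker clones synchronize on demand.

    import torch.nn as nn

    net = SharedModule(nn.Linear(4, 2))
    worker_net = net.clone_to("cpu")  # deepcopy that still references the same shared_module
    # in the trainer process, after an optimizer step:
    net.push()        # copy local weights into shared memory and bump the timestamp
    # in a worker process, before acting:
    worker_net.pull() # load the latest shared weights into the local module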
-------------------------------------------------------------------------------- /torchsupport/interacting/stats.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing as mp 3 | 4 | from torchsupport.interacting.control import ReadWriteControl 5 | 6 | class Statistics: 7 | def update(self, stats): 8 | pass 9 | 10 | class ExperienceStatistics(Statistics): 11 | def __init__(self, decay=0.6): 12 | self.ctrl = ReadWriteControl(self) 13 | self.decay = decay 14 | self._total = torch.tensor([0.0]) 15 | self._length = torch.tensor([0.0]) 16 | self._total_steps = torch.tensor([0]) 17 | self._total.share_memory_() 18 | self._length.share_memory_() 19 | self._total_steps.share_memory_() 20 | 21 | def pull_changes(self): 22 | pass 23 | 24 | def push_changes(self): 25 | pass 26 | 27 | def update(self, stats): 28 | with self.ctrl.write: 29 | self._total[0] = (1 - self.decay) * stats.total + self.decay * self._total[0] 30 | self._length[0] = (1 - self.decay) * stats.length + self.decay * self._length[0] 31 | self._total_steps[0] = self._total_steps[0] + stats.length 32 | 33 | @property 34 | def total(self): 35 | with self.ctrl.read: 36 | return self._total 37 | 38 | @property 39 | def length(self): 40 | with self.ctrl.read: 41 | return self._length 42 | 43 | @property 44 | def steps(self): 45 | with self.ctrl.read: 46 | return self._total_steps 47 | 48 | class EnergyStatistics(Statistics): 49 | pass 50 | -------------------------------------------------------------------------------- /torchsupport/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from torchsupport.modules.attention import * 2 | from torchsupport.modules.basic import * 3 | # from torchsupport.modules.capsule import * 4 | from torchsupport.modules.combination import * 5 | from torchsupport.modules.dynamic import * 6 | from torchsupport.modules.geometric_vector_perceptron import * 7 | from torchsupport.modules.gradient import * 8 | from torchsupport.modules.masked import * 9 | from torchsupport.modules.multiscale import * 10 | from torchsupport.modules.normalization import * 11 | from torchsupport.modules.polynomial import * 12 | from torchsupport.modules.recurrent import * 13 | from torchsupport.modules.reduction import * 14 | from torchsupport.modules.refine import * 15 | from torchsupport.modules.residual import * 16 | from torchsupport.modules.rezero import * 17 | from torchsupport.modules.invertible import * 18 | from torchsupport.modules.routing import * 19 | from torchsupport.modules.separable import * 20 | from torchsupport.modules.vector_quantisation import * 21 | from torchsupport.modules.generative import * 22 | -------------------------------------------------------------------------------- /torchsupport/modules/activations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/modules/activations/__init__.py -------------------------------------------------------------------------------- /torchsupport/modules/activations/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def torus(sin, cos): 4 | r"""Pins a pair of N-dimensional coordinates to 5 | proper coordinates on the N-torus. 6 | 7 | Args: 8 | sin (torch.Tensor): unnormalized sine. 
9 | cos (torch.Tensor): unnormalized cosine. 10 | 11 | Returns: 12 | Corresponding coordinates on a torus. 13 | """ 14 | return torch.atan2(sin, cos) 15 | -------------------------------------------------------------------------------- /torchsupport/modules/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .gan import * 2 | -------------------------------------------------------------------------------- /torchsupport/modules/backbones/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/modules/backbones/diffusion/__init__.py -------------------------------------------------------------------------------- /torchsupport/modules/backbones/gan/__init__.py: -------------------------------------------------------------------------------- 1 | from .dcgan import ( 2 | DCGANDiscriminator, DCGANGenerator 3 | ) 4 | 5 | from .resgan import ( 6 | ResGenerator, ResDiscriminator, ResGeneratorBlock 7 | ) 8 | -------------------------------------------------------------------------------- /torchsupport/modules/backbones/gan/discriminator_components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | class DistanceDiscriminator(nn.Module): 6 | def __init__(self, batch_size, out_size): 7 | super().__init__() 8 | self.batch_size = batch_size 9 | self.batch_combine = nn.Linear(batch_size, out_size) 10 | 11 | def forward(self, data): 12 | data = data.view(data.size(0), -1) 13 | dist = ((data[None, :] - data[:, None]).norm(dim=1) + 1e-6).log() 14 | return self.batch_combine(dist) 15 | 16 | class DynamicAugmentation(nn.Module): 17 | def __init__(self, transforms, target=0.6, 18 | p=0.0, step=0.01, every=4): 19 | super().__init__() 20 | self.transforms = transforms 21 | self.target = target 22 | self.p = p 23 | self.step = step 24 | self.every = every 25 | self.tick = 0 26 | 27 | def update(self, result): 28 | if self.tick % self.every == 0: 29 | with torch.no_grad(): 30 | sign = (result.sign() + 1).mean() / 2 31 | if sign < self.target: 32 | self.p -= self.step 33 | else: 34 | self.p += self.step 35 | self.p = max(0, min(1, self.p)) 36 | self.tick = 0 37 | self.tick += 1 38 | 39 | def forward(self, data): 40 | for transform in self.transforms: 41 | aug = transform(data) 42 | if isinstance(data, (list, tuple)): 43 | tmp = [] 44 | mask = (torch.rand(aug[0].size(0)) < self.p).to(aug[0].device) 45 | for item, aug_item in zip(data, aug): 46 | mm = mask.view(mask.size(0), *([1] * (item.dim() - 1))) 47 | item = ((~mm).float() * item + mm.float() * aug_item) 48 | tmp.append(item) 49 | data = tuple(tmp) 50 | else: 51 | mask = (torch.rand(aug.size(0)) < self.p).to(data.device) 52 | mask = mask.view(mask.size(0), *([1] * (data.dim() - 1))) 53 | data = ((~mask).float() * data + mask.float() * aug) 54 | return data 55 | -------------------------------------------------------------------------------- /torchsupport/modules/backbones/gan/resgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | class ResGeneratorBlock(nn.Module): 6 | def __init__(self, in_size, out_size, kernel_size=3, activation=None, weight=None): 7 | super().__init__() 8 | self.activation = 
activation or nn.LeakyReLU(0.2) 9 | hidden_size = min(in_size, out_size) 10 | self.blocks = nn.Sequential( 11 | self.activation, 12 | nn.Conv2d(in_size, hidden_size, kernel_size, padding=kernel_size // 2), 13 | self.activation, 14 | nn.Conv2d(hidden_size, out_size, kernel_size, padding=kernel_size // 2) 15 | ) 16 | self.skip = nn.Conv2d(in_size, out_size, 1, bias=False) 17 | self.weight = weight 18 | if weight is None: 19 | self.weight = nn.Parameter(torch.tensor(0.0), requires_grad=True) 20 | 21 | def forward(self, inputs): 22 | out = self.weight * self.blocks(inputs) + self.skip(inputs) 23 | return out 24 | 25 | class ResGenerator(nn.Module): 26 | def __init__(self, in_size=100, base_channels=64, channel_factors=None, 27 | kernel_size=3, activation=None, weight=None): 28 | super().__init__() 29 | self.initial = nn.Linear(in_size, 4 * 4 * base_channels * channel_factors[0]) 30 | self.blocks = nn.ModuleList([ 31 | ResGeneratorBlock( 32 | in_factor * base_channels, 33 | out_factor * base_channels, 34 | kernel_size=kernel_size, 35 | activation=activation, 36 | weight=weight 37 | ) 38 | for in_factor, out_factor in zip( 39 | channel_factors[:-1], channel_factors[1:] 40 | ) 41 | ]) 42 | 43 | def forward(self, inputs): 44 | out = self.initial(inputs).view(inputs.size(0), -1, 4, 4) 45 | for block in self.blocks: 46 | out = func.interpolate(out, scale_factor=2) 47 | out = block(out) 48 | return out 49 | 50 | class ResDiscriminator(nn.Module): 51 | def __init__(self, in_size=3, base_channels=64, channel_factors=None, 52 | kernel_size=3, activation=None, weight=None): 53 | super().__init__() 54 | self.preprocess = nn.Conv2d( 55 | in_size, base_channels * channel_factors[0], 3, padding=1 56 | ) 57 | self.blocks = nn.ModuleList([ 58 | ResGeneratorBlock( 59 | in_factor * base_channels, 60 | out_factor * base_channels, 61 | kernel_size=kernel_size, 62 | activation=activation, 63 | weight=weight 64 | ) 65 | for in_factor, out_factor in zip( 66 | channel_factors[:-1], channel_factors[1:] 67 | ) 68 | ]) 69 | 70 | def forward(self, inputs): 71 | out = self.preprocess(inputs) 72 | for block in self.blocks: 73 | out = block(out) 74 | out = func.avg_pool2d(out, 2) 75 | out = func.adaptive_avg_pool2d(out, 1) 76 | return out 77 | -------------------------------------------------------------------------------- /torchsupport/modules/backbones/multiscale/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/modules/backbones/multiscale/__init__.py -------------------------------------------------------------------------------- /torchsupport/modules/backbones/vae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/modules/backbones/vae/__init__.py -------------------------------------------------------------------------------- /torchsupport/modules/dynamic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | def _dynamic_convnd(inputs, weight, bias=None, N=2, **kwargs): 6 | conv = getattr(func, f"conv{N}d") 7 | if weight.dim() == N + 2: 8 | return conv(inputs, weight, bias=bias, **kwargs) 9 | batch_size = weight.size(0) 10 | inputs = inputs.view(-1, *inputs.shape[2:])[None] 11 | weight = 
weight.view(-1, *weight.shape[2:]) 12 | if bias is not None: 13 | bias = bias.view(-1) 14 | result = conv(inputs, weight, bias=bias, groups=batch_size, **kwargs) 15 | result = result.view(batch_size, -1, *result.shape[2:]) 16 | return result 17 | 18 | def dynamic_conv1d(inputs, weight, bias=None, **kwargs): 19 | r"""Dynamic 1d convolution. For details, see `torch.nn.functional.conv1d` 20 | 21 | Args: 22 | inputs (torch.Tensor :math:`(B, C_i, W)`): input tensor. 23 | weight (torch.Tensor :math:`(B, C_o, C_i, K)`): batch of weight tensors. 24 | bias (torch.Tensor :math:`B, C_o`): batch of bias tensors. 25 | """ 26 | return _dynamic_convnd(inputs, weight, bias=bias, N=1, **kwargs) 27 | 28 | def dynamic_conv2d(inputs, weight, bias=None, **kwargs): 29 | r"""Dynamic 2d convolution. For details, see `torch.nn.functional.conv2d` 30 | 31 | Args: 32 | inputs (torch.Tensor :math:`(B, C_i, H, W)`): input tensor. 33 | weight (torch.Tensor :math:`(B, C_o, C_i, K_H, K_W)`): batch of weight tensors. 34 | bias (torch.Tensor :math:`B, C_o`): batch of bias tensors. 35 | """ 36 | return _dynamic_convnd(inputs, weight, bias=bias, N=2, **kwargs) 37 | 38 | def dynamic_conv3d(inputs, weight, bias=None, **kwargs): 39 | r"""Dynamic 3d convolution. For details, see `torch.nn.functional.conv3d` 40 | 41 | Args: 42 | inputs (torch.Tensor :math:`(B, C_i, X, Y, Z)`): input tensor. 43 | weight (torch.Tensor :math:`(B, C_o, C_i, K_X, K_Y, K_Z)`): batch of weight tensors. 44 | bias (torch.Tensor :math:`B, C_o`): batch of bias tensors. 45 | """ 46 | return _dynamic_convnd(inputs, weight, bias=bias, N=3, **kwargs) 47 | 48 | def dynamic_linear(inputs, weight, bias=None): 49 | r"""Dynamic linear layer. For details, see `torch.nn.functional.linear` 50 | 51 | Args: 52 | inputs (torch.Tensor :math:`(B, C_i, W)`): input tensor. 53 | weight (torch.Tensor :math:`(B, C_o, C_i)`): batch of weight tensors. 54 | bias (torch.Tensor :math:`B, C_o`): batch of bias tensors. 55 | """ 56 | if weight.dim() == 2: 57 | return func.linear(inputs, weight, bias=bias) 58 | result = torch.bmm(inputs[:, None], weight.transpose(1, 2))[:, 0] 59 | if bias is not None: 60 | result = result + bias 61 | return result 62 | -------------------------------------------------------------------------------- /torchsupport/modules/dynet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | class DyNetNd(nn.Module): 6 | r'''Implements dynamic convolution as introduced by Zhang et al. 7 | (https://arxiv.org/abs/2004.10694). This version is implemented with conv2D. 
8 |   Args:
9 |     in_size (int): input size
10 |     out_size (int): desired output size
11 |     k_number (int): number of kernels
12 |     k_size (int): size of the kernel
13 |     N (int): dimensions for the convolution
14 |     **kwargs: for details see Conv1d/Conv2D/Conv3D
15 |   '''
16 |   def __init__(self, in_size, out_size, k_number, k_size, N=1, **kwargs):
17 |     super().__init__()
18 |     self.N = N
19 |     self.in_size = in_size
20 |     self.out_size = out_size
21 |     self.k_number = k_number
22 |     self.k_size = k_size
23 |     self.kwargs = kwargs
24 |     self.weight = nn.Parameter(torch.randn(k_number, out_size, in_size, *(N*[k_size])))
25 |     self.fc = nn.Linear(in_size, k_number)
26 | 
27 |   def forward(self, inputs):
28 |     conv = getattr(func, f"conv{self.N}d")  # functional convolution supports per-sample kernels via groups
29 |     adaptive_avg = getattr(func, f"adaptive_avg_pool{self.N}d")
30 |     avg = adaptive_avg(inputs, 1)  # (batch_size, in_size, 1, ..., 1)
31 |     avg = avg.view(inputs.size(0), -1)
32 |     kernel_weights = self.fc(avg)
33 |     kernel_weights = kernel_weights[tuple(2 * [slice(None)] + (self.N + 2) * [None])]  # adding singleton dimensions to broadcast against the kernel bank in the next line
34 |     dyn_kernels = torch.sum(kernel_weights * self.weight[None], dim=1)
35 |     dyn_kernels = dyn_kernels.view(-1, *dyn_kernels.shape[2:])  # fold batch into output channels for the grouped convolution
36 | 
37 |     batch_size = inputs.size(0)
38 |     inputs = inputs.view(1, -1, *inputs.shape[2:])
39 |     result = conv(inputs, groups=batch_size, weight=dyn_kernels, **self.kwargs)
40 |     return result.view(batch_size, -1, *result.shape[2:])
41 | 
42 | class DyNet2d(DyNetNd):
43 |   def __init__(self, in_size, out_size, k_number, k_size, **kwargs):
44 |     super().__init__(in_size, out_size, k_number, k_size, N=2, **kwargs)
45 | 
46 | class DyNet1d(DyNetNd):
47 |   def __init__(self, in_size, out_size, k_number, k_size, **kwargs):
48 |     super().__init__(in_size, out_size, k_number, k_size, N=1, **kwargs)
49 | 
50 | class DyNet3d(DyNetNd):
51 |   def __init__(self, in_size, out_size, k_number, k_size, **kwargs):
52 |     super().__init__(in_size, out_size, k_number, k_size, N=3, **kwargs) -------------------------------------------------------------------------------- /torchsupport/modules/geometric_vector_perceptron.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as func
4 | 
5 | class GeometricVectorPerceptron(nn.Module):
6 |   r"""Implements the Geometric Vector Perceptron layer (GVP) from
7 |   arXiv:2009.01411 (Jing et al. 2020). It transforms a set of scalar
8 |   and vector features in a rotation equivariant way.
9 | 
10 |   Args:
11 |     in_scalars (int): number of input scalar features.
12 |     in_vectors (int): number of input vector features.
13 |     out_scalars (int): number of output scalar features.
14 |     out_vectors (int): number of output vector features.
15 |     hidden_vectors (int): number of hidden vector features.
16 |     scalar_activation (callable): activation applied to scalar features.
17 |     vector_activation (callable): gating activation applied to vector norms to scale vector features.
18 | 19 | Shape: 20 | - Scalar inputs: :math:`(N, C_{S, in})` 21 | - Vector inputs: :math:`(N, 3, C_{V, in})` 22 | - Scalar outputs: :math:`(N, C_{S, out})` 23 | - Vector outputs: :math:`(N, 3, C_{V, out})` 24 | """ 25 | def __init__(self, in_scalars, in_vectors, 26 | out_scalars, out_vectors, 27 | hidden_vectors=None, 28 | scalar_activation=func.relu, 29 | vector_activation=torch.sigmoid): 30 | super().__init__() 31 | hidden_vectors = hidden_vectors or max(in_vectors, out_vectors) 32 | self.scalar_activation = scalar_activation 33 | self.vector_activation = vector_activation 34 | self.project_vectors = nn.Linear(in_vectors, hidden_vectors, bias=False) 35 | self.predict_vectors = nn.Linear(hidden_vectors, out_vectors, bias=False) 36 | self.project_scalars = nn.Linear(in_scalars + hidden_vectors, out_scalars) 37 | 38 | def forward(self, scalars, vectors): 39 | vector_projection = self.project_vectors(vectors) 40 | vector_prediction = self.predict_vectors(vector_projection) 41 | projection_norm = vector_projection.norm(p=2, dim=1) 42 | prediction_norm = vector_prediction.norm(p=2, dim=1, keepdim=True) 43 | scalar_projection = self.project_scalars( 44 | torch.cat((scalars, projection_norm), dim=1) 45 | ) 46 | scalar_result = self.scalar_activation(scalar_projection) 47 | vector_result = self.vector_activation(prediction_norm) * vector_prediction 48 | return scalar_result, vector_result 49 | -------------------------------------------------------------------------------- /torchsupport/modules/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/modules/losses/__init__.py -------------------------------------------------------------------------------- /torchsupport/modules/losses/clustering.py: -------------------------------------------------------------------------------- 1 | # DCC loss 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | class DCCWeightedELoss(nn.Module): 8 | def __init__(self, size_average=True): 9 | super(DCCWeightedELoss, self).__init__() 10 | self.size_average = size_average 11 | 12 | def forward(self, inputs, outputs, weights): 13 | out = (inputs - outputs).view(len(inputs), -1) 14 | out = torch.sum(weights * torch.norm(out, p=2, dim=1)**2) 15 | 16 | assert np.isfinite(out.data.cpu().numpy()).all(), 'Nan found in data' 17 | 18 | if self.size_average: 19 | out = out / inputs.nelement() 20 | 21 | return out 22 | 23 | class DCCLoss(nn.Module): 24 | def __init__(self, nsamples, ndim, initU, size_average=True): 25 | super(DCCLoss, self).__init__() 26 | self.dim = ndim 27 | self.nsamples = nsamples 28 | self.size_average = size_average 29 | self.U = nn.Parameter(torch.Tensor(self.nsamples, self.dim)) 30 | self.reset_parameters(initU+1e-6*np.random.randn(*initU.shape).astype(np.float32)) 31 | 32 | def reset_parameters(self, initU): 33 | assert np.isfinite(initU).all(), 'Nan found in initialization' 34 | self.U.data = torch.from_numpy(initU) 35 | 36 | def forward(self, enc_out, sampweights, pairweights, pairs, index, _sigma1, _sigma2, _lambda): 37 | centroids = self.U[index] 38 | 39 | out1 = torch.norm((enc_out - centroids).view(len(enc_out), -1), p=2, dim=1) ** 2 40 | out11 = torch.sum(_sigma1 * sampweights * out1 / (_sigma1 + out1)) 41 | 42 | out2 = torch.norm((centroids[pairs[:, 0]] - centroids[pairs[:, 1]]).view(len(pairs), -1), p=2, dim=1) ** 2 43 | 44 | out21 = _lambda * torch.sum(_sigma2 * 
pairweights * out2 / (_sigma2 + out2)) 45 | 46 | out = out11 + out21 47 | 48 | if self.size_average: 49 | out = out / enc_out.nelement() 50 | 51 | return out 52 | -------------------------------------------------------------------------------- /torchsupport/modules/losses/generative.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as func 4 | 5 | def normalized_distance(data, distance): 6 | data = data.view(data.size(0), -1) 7 | reference = data[:, None] 8 | comparison = data[None, :] 9 | result = distance(reference, comparison) 10 | result = result / result.sum(dim=1, keepdim=True).detach() 11 | return result 12 | 13 | class NormalizedDistance(nn.Module): 14 | def __init__(self, distance=None): 15 | super().__init__() 16 | self.distance = distance 17 | if self.distance is None: 18 | self.distance = lambda x, y: (x - y).norm(dim=-1) 19 | 20 | def forward(self, data): 21 | return normalized_distance(data, self.distance) 22 | 23 | def normalized_diversity_loss(x, y, d_x, d_y, alpha=1.0): 24 | D_x = normalized_distance(x, d_x) 25 | D_y = normalized_distance(y, d_y) 26 | result = func.relu(alpha * D_x - D_y) 27 | size = result.size(0) 28 | result[torch.arange(0, size), torch.arange(0, size)] = 0.0 29 | result = result.sum() / (size * (size - 1)) 30 | return result 31 | 32 | class NormalizedDiversityLoss(nn.Module): 33 | def __init__(self, d_x, d_y, alpha=1.0): 34 | super().__init__() 35 | self.d_x = d_x 36 | self.d_y = d_y 37 | self.alpha = alpha 38 | 39 | def forward(self, x, y): 40 | return normalized_diversity_loss(x, y, self.d_x, self.d_y, self.alpha) 41 | 42 | def mode_seeking_loss(x_0, x_1, z_0, z_1, d_x, d_z): 43 | result = d_x(x_0, x_1) / d_z(z_0, z_1) 44 | return 1 / (result + 1e-6) 45 | 46 | class ModeSeekingLoss(nn.Module): 47 | def __init__(self, d_x, d_z): 48 | super().__init__() 49 | self.d_x = d_x 50 | self.d_z = d_z 51 | 52 | def forward(self, xs, zs): 53 | return mode_seeking_loss(*xs, *zs, self.d_x, self.d_z) 54 | --------------------------------------------------------------------------------
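The mode-seeking loss above penalizes a generator whose outputs collapse for distinct latent codes, by rewarding a large output-space distance per unit of latent-space distance. A minimal usage sketch (not part of the library; the linear "generator" and all sizes are illustrative stand-ins):

import torch
import torch.nn as nn
from torchsupport.modules.losses.generative import ModeSeekingLoss

l1 = lambda a, b: (a - b).view(a.size(0), -1).norm(p=1, dim=1)
loss = ModeSeekingLoss(d_x=l1, d_z=l1)

generator = nn.Linear(128, 32)             # stand-in generator
z_0, z_1 = torch.randn(4, 128), torch.randn(4, 128)
x_0, x_1 = generator(z_0), generator(z_1)  # two samples per condition
ms = loss((x_0, x_1), (z_0, z_1)).mean()   # small when outputs vary with z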
/torchsupport/modules/masked.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | class MaskedConv1d(nn.Conv1d): 6 | r"""1D causally masked convolution for autoregressive 7 | convolutional models. 8 | 9 | Thin wrapper around :class:`nn.Conv1d`.""" 10 | def __init__(self, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | mask = torch.zeros_like(self.weight) 13 | center_x = self.kernel_size[0] // 2 14 | for idx in range(self.kernel_size[0]): 15 | mask[:, :, idx] = int(idx < center_x) 16 | self.register_buffer("mask", mask) 17 | 18 | def forward(self, inputs): 19 | self.weight.data = self.mask * self.weight.data 20 | return super().forward(inputs) 21 | 22 | class MaskedConv2d(nn.Conv2d): 23 | r"""2D causally masked convolution for autoregressive 24 | convolutional models. 25 | 26 | Thin wrapper around :class:`nn.Conv2d`.""" 27 | def __init__(self, *args, **kwargs): 28 | super().__init__(*args, **kwargs) 29 | mask = torch.zeros_like(self.weight) 30 | center_x = self.kernel_size[0] // 2 31 | center_y = self.kernel_size[1] // 2 32 | for idx in range(self.kernel_size[0]): 33 | for idy in range(self.kernel_size[1]): 34 | mask[:, :, idx, idy] = int(idx < center_x or (idx == center_x and idy < center_y)) 35 | self.register_buffer("mask", mask) 36 | 37 | def forward(self, inputs): 38 | self.weight.data = self.mask * self.weight.data 39 | return super().forward(inputs) 40 | 41 | class MaskedConv3d(nn.Conv3d): 42 | r"""3D causally masked convolution for autoregressive 43 | convolutional models. 44 | 45 | Thin wrapper around :class:`nn.Conv3d`.""" 46 | def __init__(self, *args, **kwargs): 47 | super().__init__(*args, **kwargs) 48 | mask = torch.zeros_like(self.weight) 49 | center_x = self.kernel_size[0] // 2 50 | center_y = self.kernel_size[1] // 2 51 | center_z = self.kernel_size[2] // 2 52 | for idx in range(self.kernel_size[0]): 53 | for idy in range(self.kernel_size[1]): 54 | for idz in range(self.kernel_size[2]): 55 | mask[:, :, idx, idy, idz] = int( 56 | idx < center_x or \ 57 | (idx == center_x and idy < center_y) or \ 58 | (idx == center_x and idy == center_y and idz < center_z) 59 | ) 60 | self.register_buffer("mask", mask) 61 | 62 | def forward(self, inputs): 63 | self.weight.data = self.mask * self.weight.data 64 | return super().forward(inputs) 65 | -------------------------------------------------------------------------------- /torchsupport/modules/recurrent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | class ConvGRUCellNd(nn.Module): 6 | def __init__(self, in_size, out_size, kernel_size, N=1, **kwargs): 7 | super(ConvGRUCellNd, self).__init__() 8 | conv = getattr(nn, f"Conv{N}d") 9 | self.conv_ir = conv(in_size, out_size, kernel_size, **kwargs) 10 | self.conv_hr = conv(out_size, out_size, kernel_size, **kwargs) 11 | self.conv_iz = conv(in_size, out_size, kernel_size, **kwargs) 12 | self.conv_hz = conv(out_size, out_size, kernel_size, **kwargs) 13 | self.conv_in = conv(in_size, out_size, kernel_size, **kwargs) 14 | self.conv_hn = conv(out_size, out_size, kernel_size, **kwargs) 15 | 16 | def forward(self, inputs, state): 17 | r = torch.sigmoid(self.conv_ir(inputs) + self.conv_hr(state)) 18 | z = torch.sigmoid(self.conv_iz(inputs) + self.conv_hz(state)) 19 | n = torch.tanh(self.conv_in(inputs) + self.conv_hn(state * r)) 20 | return z * state + (1 - z) * n 21 | 22 | class ConvGRUCell1d(ConvGRUCellNd): 23 | def __init__(self, in_size, out_size, kernel_size, **kwargs): 24 | super().__init__(in_size, out_size, kernel_size, N=1, **kwargs) 25 | 26 | 27 | class ConvGRUCell2d(ConvGRUCellNd): 28 | def __init__(self, in_size, out_size, kernel_size, **kwargs): 29 | super().__init__(in_size, out_size, kernel_size, N=2, **kwargs) 30 | 31 | 32 | class ConvGRUCell3d(ConvGRUCellNd): 33 | def __init__(self, in_size, out_size, kernel_size, **kwargs): 34 | super().__init__(in_size, out_size, kernel_size, N=3, **kwargs) 35 | --------------------------------------------------------------------------------
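A minimal usage sketch for the convolutional GRU above (not part of the library): rolling ConvGRUCell2d over a sequence of feature maps. Sizes are illustrative; padding=1 keeps the spatial shape for a 3x3 kernel, and the hidden state has out_size channels.

import torch
from torchsupport.modules.recurrent import ConvGRUCell2d

cell = ConvGRUCell2d(8, 8, 3, padding=1)    # in_size, out_size, kernel_size
state = torch.zeros(2, 8, 16, 16)           # initial hidden state
for frame in torch.randn(5, 2, 8, 16, 16):  # 5 time steps, batch of 2
    state = cell(frame, state)
print(state.shape)                          # torch.Size([2, 8, 16, 16])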
/torchsupport/modules/refine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | from torchsupport.modules.normalization import AdaptiveBatchNorm 6 | 7 | class ResBlock(nn.Module): 8 | def __init__(self, size, ada_size): 9 | super().__init__() 10 | self.bn = AdaptiveBatchNorm(size, ada_size) 11 | self.convs = nn.Sequential( 12 | nn.ReLU(), 13 | nn.Conv2d(size, size, 3, dilation=1, padding=1), 14 | nn.ReLU(), 15 | nn.Conv2d(size, size, 3, dilation=2, padding=2) 16 | ) 17 | 18 | def forward(self, inputs, condition): 19 | out = self.bn(inputs, condition) 20 | out = self.convs(out) 21 | return out + inputs 22 | 23 | class PoolBlock(nn.Module): 24 | def __init__(self, size, ada_size, width=5, depth=3): 25 | super().__init__() 26 | self.width = width 27 | self.bn = nn.ModuleList([ 28 | AdaptiveBatchNorm(size, ada_size) 29 | for idx in range(depth) 30 | ]) 31 | self.blocks = nn.ModuleList([ 32 | nn.Conv2d(size, size, 3, padding=1) 33 | for idx in range(depth) 34 | ]) 35 | 36 | def forward(self, inputs, condition): 37 | out = inputs 38 | for bn, block in zip(self.bn, self.blocks): 39 | inner = func.max_pool2d(bn(out, condition), self.width, stride=1, padding=2) 40 | inner = block(inner) 41 | out = out + inner 42 | return out 43 | 44 | class RefineBlock(nn.Module): 45 | def __init__(self, size, ada_size, width=5, depth=3): 46 | super().__init__() 47 | self.res_low = nn.ModuleList([ 48 | ResBlock(size, ada_size) 49 | for idx in range(2) 50 | ]) 51 | self.res_high = nn.ModuleList([ 52 | ResBlock(size, ada_size) 53 | for idx in range(2) 54 | ]) 55 | self.low = nn.Conv2d(size, size, 3, padding=1) 56 | self.high = nn.Conv2d(size, size, 3, padding=1) 57 | self.pool = PoolBlock(size, ada_size, width=width, depth=depth) 58 | 59 | def forward(self, low, high, condition): 60 | for block in self.res_low: 61 | low = block(low, condition) 62 | low = func.interpolate(low, scale_factor=2, mode="bilinear", align_corners=False) 63 | for block in self.res_high: 64 | high = block(high, condition) 65 | out = self.low(low) + self.high(high) 66 | out = self.pool(out, condition) 67 | return out 68 | -------------------------------------------------------------------------------- /torchsupport/modules/rezero.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | class ReZero(nn.Module): 6 | r'''Implements ReZero normalization proposed by Bachlechner et al. 7 | (https://arxiv.org/pdf/2003.04887.pdf). 8 | Args: 9 | out_size (int): number of output channels. initial_value (float): initial value of the learned alpha (0.0 in the paper).''' 10 | def __init__(self, out_size=1, initial_value=0.0): 11 | super().__init__() 12 | self.out_size = out_size 13 | self.alpha = nn.Parameter(torch.ones( 14 | self.out_size, dtype=torch.float, requires_grad=True 15 | )) 16 | with torch.no_grad(): 17 | self.alpha *= initial_value 18 | 19 | def forward(self, inputs, result): 20 | dimension = inputs.dim() - 2 21 | alpha = self.alpha[[None, slice(None)] + dimension * [None]] 22 | return inputs + alpha * result 23 | --------------------------------------------------------------------------------
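A minimal usage sketch for ReZero above (not part of the library): gating a residual branch. At initialization alpha is zero, so the block is exactly the identity; the branch is blended in as alpha grows during training.

import torch
import torch.nn as nn
from torchsupport.modules.rezero import ReZero

block = nn.Conv2d(8, 8, 3, padding=1)   # stand-in residual branch
rezero = ReZero(8)
inputs = torch.randn(2, 8, 16, 16)
out = rezero(inputs, block(inputs))     # == inputs at initialization
print(torch.allclose(out, inputs))      # True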
/torchsupport/modules/separable.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | class DepthWiseSeparableConv1d(nn.Module): 6 | r"""Depthwise separable 1D convolution. 7 | 8 | Analogous functionality to :class:`torch.nn.Conv1d`. 9 | 10 | Args: 11 | in_channels (int): number of input channels. 12 | out_channels (int): number of output channels. 13 | kernel_size (int): kernel size. 14 | kwargs: additional keyword arguments. See `Conv1d` for details. 15 | """ 16 | def __init__(self, in_channels, out_channels, kernel_size, 17 | stride=1, padding=0, dilation=1, bias=True): 18 | super(DepthWiseSeparableConv1d, self).__init__() 19 | self.depth_conv = nn.Conv1d(in_channels, in_channels, kernel_size, 20 | stride=stride, padding=padding, dilation=dilation, 21 | groups=in_channels, bias=bias) 22 | self.point_conv = nn.Conv1d(in_channels, out_channels, 1) 23 | 24 | def forward(self, inputs): 25 | return self.point_conv(self.depth_conv(inputs)) 26 | 27 | class DepthWiseSeparableConv2d(nn.Module): 28 | r"""Depthwise separable 2D convolution. 29 | 30 | Analogous functionality to :class:`torch.nn.Conv2d`. 31 | 32 | Args: 33 | in_channels (int): number of input channels. 34 | out_channels (int): number of output channels. 35 | kernel_size (int or (int, int)): kernel size. 36 | kwargs: additional keyword arguments. See `Conv2d` for details. 37 | """ 38 | def __init__(self, in_channels, out_channels, kernel_size, 39 | stride=1, padding=1, dilation=1, bias=True): 40 | super(DepthWiseSeparableConv2d, self).__init__() 41 | self.depth_conv = nn.Conv2d(in_channels, in_channels, kernel_size, 42 | stride=stride, padding=padding, dilation=dilation, 43 | groups=in_channels, bias=bias) 44 | self.point_conv = nn.Conv2d(in_channels, out_channels, 1) 45 | 46 | def forward(self, inputs): 47 | return self.point_conv(self.depth_conv(inputs)) 48 | -------------------------------------------------------------------------------- /torchsupport/modules/zoom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | from torchsupport.modules.gradient import hard_k_hot 6 | 7 | def zoom2d(target, logits, offset=64, num_samples=2, temperature=0.1): 8 | shape = logits.shape[2:] 9 | logits = logits.view(logits.size(0), -1) 10 | k_hot = hard_k_hot(logits, num_samples, temperature=temperature) 11 | k_hot = k_hot.view(logits.size(0), *shape) 12 | ind, x, y = k_hot.nonzero().t() 13 | 14 | x_ind = torch.arange(0, offset)[None] 15 | x_ind = x_ind.repeat_interleave(num_samples * logits.size(0), dim=0) 16 | x_ind = x_ind + x[:, None] 17 | 18 | y_ind = torch.arange(0, offset)[None] 19 | y_ind = y_ind.repeat_interleave(num_samples * logits.size(0), dim=0) 20 | y_ind = y_ind + y[:, None] 21 | 22 | result = target[ind[:, None, None], :, x_ind[:, :, None], y_ind[:, None, :]] 23 | result = result.unsqueeze(1).transpose(1, -1).squeeze(-1) 24 | result = result * k_hot[ind, x, y][:, None, None, None] 25 | result = result.reshape(-1, num_samples, *result.shape[1:]) 26 | return result 27 | 28 | class Zoom2d(nn.Module): 29 | def __init__(self, offset=64, num_samples=2, temperature=0.1): 30 | super().__init__() 31 | self.offset = offset 32 | self.num_samples = num_samples 33 | self.temperature = temperature 34 | 35 | def forward(self, target, logits): 36 | return zoom2d( 37 | target, logits, offset=self.offset, 38 | num_samples=self.num_samples, 39 | temperature=self.temperature 40 | ) 41 | -------------------------------------------------------------------------------- /torchsupport/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from torchsupport.networks.gan import * 2 | --------------------------------------------------------------------------------
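A quick sketch of why the depthwise separable factorization above is attractive: a per-channel k x k depthwise convolution plus a 1 x 1 pointwise mixing convolution needs far fewer parameters than a full k x k convolution over all channel pairs. Sizes below are illustrative, and the counts assume the grouped depthwise convolution as defined above.

import torch.nn as nn
from torchsupport.modules.separable import DepthWiseSeparableConv2d

full = nn.Conv2d(64, 128, 3, padding=1)
separable = DepthWiseSeparableConv2d(64, 128, 3, padding=1)
count = lambda module: sum(p.numel() for p in module.parameters())
print(count(full), count(separable))  # 73856 vs. 8960, roughly 8x fewer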
/torchsupport/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from torchsupport.modules import * 2 | -------------------------------------------------------------------------------- /torchsupport/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/ops/__init__.py -------------------------------------------------------------------------------- /torchsupport/ops/shape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def flatten(input, batch=False): 4 | if batch: 5 | return input.view(input.size()[0], -1) 6 | else: 7 | return input.view(-1) 8 | 9 | def batchexpand(input, batch): 10 | result = input.unsqueeze(0).expand( 11 | batch.size()[0], 12 | *input.size() 13 | ) 14 | return result 15 | 16 | def deshape(inputs): 17 | dimension = inputs.dim() 18 | drop = dimension - 2 19 | if drop == 0: 20 | return inputs, None 21 | permutation = [0] + [2 + idx for idx in range(drop)] + [1] 22 | permuted = inputs.permute(*permutation) 23 | shape = permuted.shape 24 | return permuted.reshape(-1, shape[-1]), shape 25 | 26 | def reshape(inputs, shape): 27 | if shape is None: 28 | return inputs 29 | inputs = inputs.reshape(*shape[:-1], inputs.size(-1)) 30 | dimension = inputs.dim() 31 | drop = dimension - 2 32 | permutation = [0, -1] + list(range(1, drop + 1)) 33 | inputs = inputs.permute(*permutation).contiguous() 34 | return inputs 35 | -------------------------------------------------------------------------------- /torchsupport/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from torchsupport.optim.radam import RAdam -------------------------------------------------------------------------------- /torchsupport/reporting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/reporting/__init__.py -------------------------------------------------------------------------------- /torchsupport/reporting/reporting.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import torch 4 | import numpy as np 5 | from matplotlib import pyplot as plt 6 | 7 | def tensorplot(writer, name, fig, step): 8 | fig.canvas.draw() 9 | buf = fig.canvas.tostring_rgb() 10 | ncols, nrows = fig.canvas.get_width_height() 11 | shape = (nrows, ncols, 3) 12 | array = np.frombuffer(buf, dtype=np.uint8).reshape(shape) 13 | tensor = torch.Tensor(array.transpose(2, 0, 1)) 14 | writer.add_image(name, tensor, step) 15 | -------------------------------------------------------------------------------- /torchsupport/structured/__init__.py: -------------------------------------------------------------------------------- 1 | from .structures import * 2 | from .modules import * 3 | from .packedtensor import * 4 | from .chunkable import * 5 | from . import scatter 6 | --------------------------------------------------------------------------------
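A minimal usage sketch for tensorplot above (not part of the library): rendering a matplotlib figure into a tensorboardX run. The writer path and figure content are stand-ins; a matplotlib backend with tostring_rgb support (such as Agg) is assumed.

import torch
from matplotlib import pyplot as plt
from tensorboardX import SummaryWriter
from torchsupport.reporting.reporting import tensorplot

writer = SummaryWriter("runs/demo")  # stand-in log directory
fig, ax = plt.subplots()
ax.plot(torch.linspace(0, 1, 100).numpy() ** 2)
tensorplot(writer, "curve", fig, step=0)
plt.close(fig)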
/torchsupport/structured/chunkable.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.parallel.scatter_gather import Scatter 3 | 4 | def chunk_sizes(lengths, num_targets): 5 | num_entities = len(lengths) 6 | chops = num_entities // num_targets 7 | result = [ 8 | sum(lengths[idx * chops:(idx + 1) * chops]) 9 | for idx in range(num_targets) 10 | ] 11 | return result 12 | 13 | def chunk_tensor(tensor, lengths, targets, dim=0): 14 | return Scatter.apply(targets, lengths, dim, tensor) 15 | 16 | class Chunkable: 17 | def chunk(self, targets): 18 | raise NotImplementedError("Abstract.") 19 | 20 | def scatter_chunked(inputs, target_gpus, dim=0): 21 | r""" 22 | Slices tensors into approximately equal chunks and 23 | distributes them across given GPUs. Duplicates 24 | references to objects that are not tensors. 25 | """ 26 | def scatter_map(obj): 27 | if isinstance(obj, Chunkable): 28 | return obj.chunk(target_gpus) 29 | if isinstance(obj, torch.Tensor): 30 | return Scatter.apply(target_gpus, None, dim, obj) 31 | 32 | if isinstance(obj, tuple) and len(obj) > 0: 33 | return list(zip(*map(scatter_map, obj))) 34 | if isinstance(obj, list) and len(obj) > 0: 35 | return list(map(list, zip(*map(scatter_map, obj)))) 36 | if isinstance(obj, dict) and len(obj) > 0: 37 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 38 | return [obj for _ in target_gpus] 39 | 40 | try: 41 | return scatter_map(inputs) 42 | finally: 43 | scatter_map = None 44 | 45 | def scatter_chunked_kwargs(inputs, kwargs, target_gpus, dim=0): 46 | r"""Scatter with support for kwargs dictionary""" 47 | inputs = scatter_chunked(inputs, target_gpus, dim) if inputs else [] 48 | kwargs = scatter_chunked(kwargs, target_gpus, dim) if kwargs else [] 49 | if len(inputs) < len(kwargs): 50 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 51 | elif len(kwargs) < len(inputs): 52 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 53 | inputs = tuple(inputs) 54 | kwargs = tuple(kwargs) 55 | return inputs, kwargs 56 | -------------------------------------------------------------------------------- /torchsupport/structured/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import * 2 | from .parallel import * 3 | from .pooling import * 4 | from .spectral import * 5 | from .transformer import * 6 | from .sequence_transformer import * 7 | from .rezero_transformer import RezeroTransformerBlock 8 | -------------------------------------------------------------------------------- /torchsupport/structured/modules/parallel.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from torchsupport.structured.chunkable import scatter_chunked_kwargs 4 | from torchsupport.data.collate import gather_collated 5 | 6 | class DataParallel(nn.DataParallel): 7 | def scatter(self, inputs, kwargs, device_ids): 8 | return scatter_chunked_kwargs(inputs, kwargs, device_ids, dim=self.dim) 9 | 10 | def gather(self, outputs, output_device): 11 | return gather_collated(outputs, output_device, dim=self.dim) 12 | --------------------------------------------------------------------------------
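A minimal sketch of the Chunkable protocol above (not part of the library): scatter_chunked calls chunk(target_gpus) on any Chunkable it meets, so structured containers decide themselves how they split across devices. The SizeList class is hypothetical, and plain integers stand in for device ids.

from torchsupport.structured.chunkable import Chunkable

class SizeList(Chunkable):
    def __init__(self, sizes):
        self.sizes = sizes

    def chunk(self, targets):
        # split entries evenly, one slice per target device
        step = len(self.sizes) // len(targets)
        return [
            SizeList(self.sizes[idx * step:(idx + 1) * step])
            for idx in range(len(targets))
        ]

chunks = SizeList([3, 1, 4, 1, 5, 9]).chunk([0, 1])
print([chunk.sizes for chunk in chunks])  # [[3, 1, 4], [1, 5, 9]]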
/torchsupport/structured/modules/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | from torchsupport.structured import structures as cs 6 | from torchsupport.structured import scatter 7 | 8 | class DeletionPool(nn.Module): 9 | def __init__(self, size): 10 | super(DeletionPool, self).__init__() 11 | self.project = nn.Linear(size, 1) 12 | 13 | def forward(self, data, structure): 14 | vals = self.project(data) 15 | median_val = torch.median(vals) 16 | keep_nodes = (vals.squeeze(-1) > median_val).nonzero().squeeze(-1) 17 | pooled_data = vals[keep_nodes] * data[keep_nodes] 18 | pooled_structure = cs.ConnectMissing(structure, keep_nodes) 19 | return pooled_data, pooled_structure 20 | 21 | class SelectionPool(nn.Module): 22 | pass 23 | 24 | class CliquePool(nn.Module): 25 | def __init__(self): 26 | super(CliquePool, self).__init__() 27 | 28 | class GraphPool(nn.Module): 29 | def __init__(self): 30 | super(GraphPool, self).__init__() 31 | 32 | def combine(self, nodes, indices): 33 | raise NotImplementedError("Abstract.") 34 | 35 | def forward(self, nodes, indices): 36 | return self.combine(nodes, indices) 37 | 38 | class MILAttention(GraphPool): 39 | def __init__(self, in_size, out_size, attention_size, heads): 40 | super(MILAttention, self).__init__() 41 | self.heads = heads 42 | self.query = nn.Linear(attention_size, heads) 43 | self.gate = nn.Linear(in_size, attention_size) 44 | self.key = nn.Linear(in_size, attention_size) 45 | self.value = nn.Linear(in_size, heads * attention_size) 46 | self.out = nn.Linear(heads * attention_size, out_size) 47 | 48 | def combine(self, nodes, indices): 49 | logits = self.query(torch.tanh(self.key(nodes)) * torch.sigmoid(self.gate(nodes))) 50 | weight = scatter.softmax(logits, indices) 51 | value = self.value(nodes) 52 | weight = weight.view(*weight.shape[:-1], 1, self.heads) 53 | value = value.view(*value.shape[:-1], -1, self.heads) 54 | result = scatter.add(weight * value, indices) 55 | return self.out(result.view(*result.shape[:-2], -1)) 56 | 57 | class MILRouting(GraphPool): 58 | def __init__(self, in_size, k=3): 59 | super(MILRouting, self).__init__() 60 | self.k = k 61 | 62 | def forward(self, nodes, indices): 63 | weights = torch.zeros(nodes.size(0), 1, device=nodes.device) 64 | for idx in range(self.k): 65 | smax = scatter.softmax(weights, indices) 66 | sigma = scatter.add(smax * nodes, indices) 67 | norm = torch.norm(sigma, dim=1, keepdim=True) 68 | norm2 = norm ** 2 69 | sval = sigma / norm * norm2 / (1 + norm2) 70 | weights = weights + (nodes * sval[indices]).sum(dim=1, keepdim=True) 71 | return sigma 72 | --------------------------------------------------------------------------------
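A usage sketch for MILAttention above (not part of the library), assuming the scatter ops normalize and aggregate per bag as suggested by their use in combine: ten instances are pooled into two bag embeddings according to an index vector.

import torch
from torchsupport.structured.modules.pooling import MILAttention

pool = MILAttention(in_size=16, out_size=8, attention_size=32, heads=4)
nodes = torch.randn(10, 16)                # 10 instances
indices = torch.tensor([0] * 6 + [1] * 4)  # bag assignment per instance
pooled = pool(nodes, indices)
print(pooled.shape)                        # torch.Size([2, 8])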
/torchsupport/structured/modules/rezero_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | from torchsupport.modules.basic import MLP 6 | from torchsupport.modules.rezero import ReZero 7 | from torchsupport.structured.modules.sequence_transformer import SequenceMultiHeadAttention 8 | 9 | class RezeroTransformerBlock(nn.Module): 10 | r'''Implements a ReZero transformer block as described in the paper: 11 | https://arxiv.org/pdf/2003.04887.pdf. The layer normalization step is 12 | skipped entirely; instead, the multi-head attention result and the feed 13 | forward result of the standard transformer are scaled by the trainable 14 | alpha parameter (described in rezero.py). 15 | Args: 16 | size (int): input and output feature size 17 | n_heads (int): number of attention heads 18 | hidden_size (int): size of the hidden layers between input and 19 | output in the feed-forward network 20 | attention_size (int): number of features used to compare query and key 21 | in the attention kernel 22 | value_size (int): value size (for multi-head attention) 23 | depth (int): depth of the feed-forward network 24 | dropout (float): dropout probability 25 | ''' 26 | def __init__(self, size, n_heads=8, hidden_size=128, 27 | attention_size=128, value_size=128, depth=2, dropout=0.1): 28 | super().__init__() 29 | self.attention = SequenceMultiHeadAttention( 30 | size, size, 31 | attention_size=attention_size, 32 | hidden_size=value_size, 33 | heads=n_heads 34 | ) 35 | self.ff = MLP(size, size, hidden_size=hidden_size, depth=depth, batch_norm=False) 36 | self.rezero = ReZero(size) 37 | 38 | self.dropout_1 = nn.Dropout(dropout) 39 | self.dropout_2 = nn.Dropout(dropout) 40 | 41 | def forward(self, inputs, index): 42 | result_1 = self.dropout_1(self.attention(inputs, index)) 43 | inputs = self.rezero(inputs, result_1) 44 | result_2 = self.dropout_2(self.ff(inputs)) 45 | inputs = self.rezero(inputs, result_2) 46 | return inputs 47 | -------------------------------------------------------------------------------- /torchsupport/structured/packedtensor.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | import torch 3 | from torchsupport.data.collate import Collatable 4 | from torchsupport.data.io import DeviceMovable 5 | from torchsupport.data.tensor_provider import TensorProvider 6 | from torchsupport.structured.chunkable import ( 7 | Chunkable, chunk_sizes, chunk_tensor 8 | ) 9 | 10 | class PackedTensor(DeviceMovable, Collatable, Chunkable, TensorProvider): 11 | def __init__(self, tensors, lengths=None, split=True, box=False): 12 | self.tensor = tensors 13 | self.split = split 14 | self.box = box 15 | self.lengths = [len(tensors)] 16 | if isinstance(self.tensor, (list, tuple)): 17 | self.lengths = list(map(lambda x: x.size(0), tensors)) 18 | self.tensor = torch.cat(self.tensor, dim=0) 19 | if lengths is not None: 20 | self.lengths = lengths 21 | 22 | @classmethod 23 | def collate(cls, tensors): 24 | data = [ 25 | tensor.tensor 26 | for tensor in tensors 27 | ] 28 | lengths = [ 29 | length 30 | for tensor in tensors 31 | for length in tensor.lengths 32 | ] 33 | if not tensors[0].split: 34 | return torch.cat(data, dim=0) 35 | return PackedTensor(data, lengths=lengths, box=tensors[0].box) 36 | 37 | def move_to(self, device): 38 | the_copy = copy(self) 39 | the_copy.tensor = self.tensor.to(device) 40 | return the_copy 41 | 42 | def tensors(self): 43 | return [self.tensor] 44 | 45 | def chunk(self, targets): 46 | sizes = chunk_sizes(self.lengths, len(targets)) 47 | chunks = chunk_tensor(self.tensor, sizes, targets, dim=0) 48 | result = [] 49 | step = len(self.lengths) // len(targets) 50 | for idx, chunk in enumerate(chunks): 51 | the_tensor = PackedTensor(chunk, split=self.split, box=self.box) 52 | the_tensor.lengths = self.lengths[idx * step:(idx + 1) * step] 53 | the_tensor = the_tensor if self.box else the_tensor.tensor 54 | result.append(the_tensor) 55 | return result 56 | 57 | def detach(self): 58 | return PackedTensor( 59 | self.tensor.detach(), 60 | lengths=list(self.lengths), 61 | split=self.split, 62 | box=self.box 63 | ) 64 | 65 | def clone(self): 66 | return 
PackedTensor( 67 | self.tensor.clone(), 68 | lengths=list(self.lengths), 69 | split=self.split, 70 | box=self.box 71 | ) 72 | 73 | def __len__(self): 74 | return len(self.lengths) 75 | 76 | def __getitem__(self, idx): 77 | before = sum(self.lengths[:idx]) 78 | window = slice(before, before + self.lengths[idx]) 79 | tensor = self.tensor[window] 80 | lengths = [self.lengths[idx]] 81 | 82 | return PackedTensor( 83 | tensor, 84 | lengths=lengths, 85 | split=self.split, 86 | box=self.box 87 | ) 88 | -------------------------------------------------------------------------------- /torchsupport/structured/structures/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import * 2 | from .connection import * 3 | -------------------------------------------------------------------------------- /torchsupport/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/test/__init__.py -------------------------------------------------------------------------------- /torchsupport/test/test_gradient.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torchsupport.modules import ( 4 | replace_gradient, reinforce, straight_through 5 | ) 6 | 7 | @pytest.fixture 8 | def grad_sink(): 9 | return torch.ones(10, requires_grad=True) 10 | 11 | def test_replace_gradient(grad_sink): 12 | gradient_provider = grad_sink * 42 13 | replacement_value = torch.zeros_like(grad_sink) 14 | result = replace_gradient(replacement_value, gradient_provider) 15 | assert bool((result == replacement_value).all()) 16 | replaced_grad = torch.autograd.grad( 17 | result, grad_sink, 18 | grad_outputs=torch.ones_like(result), 19 | retain_graph=True 20 | )[0] 21 | desired_grad = torch.autograd.grad( 22 | gradient_provider, grad_sink, 23 | grad_outputs=torch.ones_like(gradient_provider), 24 | retain_graph=True 25 | )[0] 26 | assert bool((replaced_grad == desired_grad).all()) 27 | 28 | def test_reinforce(grad_sink): 29 | grad_sink = grad_sink[None].expand(10, 10) 30 | op = lambda x: x.sum(dim=1, keepdim=True) 31 | reinforced_sum = reinforce(op) 32 | dist = torch.distributions.Normal(grad_sink, 0.1 * torch.ones_like(grad_sink)) 33 | reinforce_dist = torch.distributions.Normal(grad_sink, 0.1) 34 | torch.random.manual_seed(1234) 35 | reinforced_result = reinforced_sum(reinforce_dist).mean(dim=0) 36 | reparam_result = dist.rsample().sum(dim=1).mean(dim=0) 37 | 38 | assert torch.allclose( 39 | reinforced_result[0], reparam_result, 40 | rtol=0.1, atol=0.1 41 | ) 42 | 43 | # replaced_grad = torch.autograd.grad( 44 | # reinforced_result, grad_sink, 45 | # grad_outputs=torch.ones_like(reinforced_result), 46 | # retain_graph=True 47 | # )[0] 48 | # desired_grad = torch.autograd.grad( 49 | # reparam_result, grad_sink, 50 | # grad_outputs=torch.ones_like(reparam_result), 51 | # retain_graph=True 52 | # )[0] 53 | # assert torch.allclose( 54 | # replaced_grad, desired_grad, 55 | # rtol=0.1, atol=0.1 56 | # ) 57 | -------------------------------------------------------------------------------- /torchsupport/test/test_mlp.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.nn.utils import spectral_norm 4 | from torchsupport.modules import MLP 5 | 6 | @pytest.mark.parametrize( 7 | 'count, in_size, 
out_size, hidden_size, depth, batch_norm, normalization', [ 8 | (count, in_size, out_size, hidden_size, depth, batch_norm, normalization) 9 | for count in (2, 10) 10 | for in_size in (1, 3) 11 | for out_size in (1, 4) 12 | for hidden_size in (1, 50, [10, 20, 30]) 13 | for depth in (1, 3) 14 | for batch_norm in (True, False) 15 | for normalization in (lambda x: x, spectral_norm) 16 | ] 17 | ) 18 | def test_mlp(count, in_size, out_size, 19 | hidden_size, depth, batch_norm, 20 | normalization): 21 | mlp = MLP( 22 | in_size, out_size, 23 | hidden_size=hidden_size, 24 | depth=depth, 25 | batch_norm=batch_norm, 26 | normalization=normalization 27 | ) 28 | inputs = torch.randn(count, in_size) 29 | result = mlp(inputs) 30 | if isinstance(hidden_size, (list, tuple)): 31 | intermediate = inputs 32 | for idx, block in enumerate(mlp.blocks[:-1]): 33 | intermediate = block(intermediate) 34 | assert intermediate.size(1) == hidden_size[idx] 35 | else: 36 | intermediate = mlp.blocks[0](inputs) 37 | assert intermediate.size(1) == hidden_size 38 | assert result.size(1) == out_size 39 | -------------------------------------------------------------------------------- /torchsupport/test/test_onehot.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.nn.utils import spectral_norm 4 | from torchsupport.modules import one_hot_encode, OneHotEncoder 5 | 6 | CASES = [ 7 | case 8 | for numeric in (True, False) 9 | for case in [ 10 | ( 11 | "ABCAABAC", "ABC", 12 | torch.tensor([0, 1, 2, 0, 0, 1, 0, 2]), 13 | numeric 14 | ), 15 | ( 16 | list("ABCAABAC"), "ABC", 17 | torch.tensor([0, 1, 2, 0, 0, 1, 0, 2]), 18 | numeric 19 | ), 20 | ( 21 | torch.tensor([0, 1, 2, 0, 0, 1, 0, 2]), [0, 1, 2], 22 | torch.tensor([0, 1, 2, 0, 0, 1, 0, 2]), 23 | numeric 24 | ) 25 | ] 26 | ] 27 | 28 | @pytest.mark.parametrize( 29 | "data, code, expected, numeric", CASES 30 | ) 31 | def test_one_hot_encode_shape(data, code, expected, numeric): 32 | encoding = one_hot_encode(data, code, numeric=numeric) 33 | if numeric: 34 | assert encoding.dim() == 1 35 | assert encoding.size(0) == len(data) 36 | else: 37 | assert encoding.dim() == 2 38 | assert encoding.size(0) == len(code) 39 | assert encoding.size(1) == len(data) 40 | 41 | @pytest.mark.parametrize( 42 | "data, code, expected, numeric", CASES 43 | ) 44 | def test_one_hot_encode_value(data, code, expected, numeric): 45 | encoding = one_hot_encode(data, code, numeric=numeric) 46 | if numeric: 47 | assert bool((encoding == expected).all()) 48 | else: 49 | expected_one_hot = torch.zeros(expected.size(0), len(code)) 50 | ind = torch.arange(expected.size(0)) 51 | expected_one_hot[ind, expected] = 1 52 | assert bool((encoding == expected_one_hot.t()).all()) 53 | 54 | @pytest.mark.parametrize( 55 | "data, code, expected, numeric", CASES 56 | ) 57 | def test_one_hot_encode_consistent(data, code, expected, numeric): 58 | encoding = one_hot_encode(data, code, numeric=numeric) 59 | if not numeric: 60 | encoding = encoding.argmax(dim=0) 61 | decoding = [] 62 | for base, encoded in zip(data, encoding): 63 | decoded = code[int(encoded)] 64 | decoding.append(decoded) 65 | assert decoded == base 66 | 67 | # check consistency 68 | re_encoding = one_hot_encode(decoding, code, numeric=True) 69 | print(re_encoding) 70 | assert bool((encoding == re_encoding).all()) 71 | 72 | def test_create_encoder(): 73 | OneHotEncoder("ABC", numeric=True) 74 | OneHotEncoder("ABC", numeric=False) 75 | OneHotEncoder(list("ABC")) 76 | 
OneHotEncoder(torch.arange(10, dtype=torch.long)) 77 | -------------------------------------------------------------------------------- /torchsupport/test/test_scatter.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torchsupport.structured import scatter 4 | 5 | -------------------------------------------------------------------------------- /torchsupport/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/training/__init__.py -------------------------------------------------------------------------------- /torchsupport/training/consistent_gan.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as func 6 | from torch.distributions import Normal, RelaxedOneHotCategorical 7 | 8 | from tensorboardX import SummaryWriter 9 | 10 | from torchsupport.training.training import Training 11 | import torchsupport.modules.losses.vae as vl 12 | from torchsupport.data.io import netwrite, make_differentiable 13 | from torchsupport.data.collate import DataLoader 14 | 15 | from torchsupport.training.neural_conditioner import NeuralConditionerTraining 16 | 17 | class ConsistentGANTraining(NeuralConditionerTraining): 18 | def __init__(self, generator, discriminator, data, levels=None, gamma=1, **kwargs): 19 | super().__init__(generator, discriminator, data, **kwargs) 20 | self.gamma = gamma 21 | self.levels = levels 22 | 23 | def sample(self, data): 24 | noise = super().sample(data) 25 | return noise 26 | 27 | def zoom(self, data, idx): 28 | shape = data.shape[-1] 29 | pos = (shape - self.levels[idx]) // 2 30 | return data[:, :, pos:pos + self.levels[idx], pos:pos + self.levels[idx]] 31 | 32 | def stage(self, data): 33 | inputs = data 34 | stages = [ 35 | self.zoom(layer, idx) 36 | for idx, layer in enumerate(inputs[:-1]) 37 | ] 38 | return stages 39 | 40 | def restrict_inputs(self, inputs, mask): 41 | return [ 42 | inp * msk 43 | for inp, msk in zip(inputs, mask) 44 | ] 45 | 46 | def reconstruction_loss(self, generated, stages): 47 | l1_loss = 0.0 48 | for idx, stage in enumerate(stages): 49 | compare = generated[idx + 1] 50 | compare = func.adaptive_avg_pool2d(compare, stage.shape[-1]) 51 | diff = (compare - stage).view(compare.size(0), -1).norm(p=1, dim=1) 52 | l1_loss += diff.mean() 53 | self.current_losses["reconstruction"] = float(l1_loss) 54 | return l1_loss 55 | 56 | def run_generator(self, data): 57 | sample = self.sample(data) 58 | inputs, available, requested = data 59 | stages = self.stage(inputs) 60 | restricted_inputs = self.restrict_inputs(inputs, available) 61 | generated, stages = self.generator( 62 | sample, stages, restricted_inputs, 63 | available, requested 64 | ) 65 | 66 | return inputs, generated, stages, available, requested 67 | 68 | def run_discriminator(self, data): 69 | with torch.no_grad(): 70 | fake = self.run_generator(data) 71 | make_differentiable(fake) 72 | make_differentiable(data) 73 | _, fake_batch, _, _, _ = fake 74 | inputs, available, requested = data 75 | fake_result = self._run_discriminator_aux( 76 | inputs, fake_batch, 77 | available, requested 78 | ) 79 | real_result = self._run_discriminator_aux( 80 | inputs, inputs, 81 | available, requested 82 | ) 83 | return fake, data, fake_result, 
real_result 84 | 85 | def generator_step_loss(self, data, generated, stages, avl, req): 86 | gan_loss = super().generator_step_loss(data, generated, avl, req) 87 | reconstruction_loss = self.reconstruction_loss(generated, stages) 88 | return gan_loss + self.gamma * reconstruction_loss 89 | -------------------------------------------------------------------------------- /torchsupport/training/distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | import torch.distributed as distributed 5 | 6 | from torchsupport.data.io import netwrite 7 | from torchsupport.training.training import Training, BasicTraining 8 | 9 | class SynchronousDistributedTraining(BasicTraining): 10 | """Distributes a given training process over a set of nodes, 11 | via gradient averaging. 12 | """ 13 | def __init__(self, *args, **kwargs): 14 | super(SynchronousDistributedTraining, self).__init__(*args, **kwargs) 15 | self.world_size = distributed.get_world_size() 16 | self.rank = distributed.get_rank() 17 | self.group = distributed.new_group(ranks=list(range(self.world_size))) 18 | 19 | def step(self, data, label): 20 | predictions = self.net(data) 21 | loss_val = self.loss(predictions, label) 22 | loss_val.backward() 23 | _average_gradients(self.net, self.world_size, self.group) 24 | self.optimizer.step() 25 | self.training_loss = loss_val.item() 26 | 27 | def checkpoint(self): 28 | if self.rank == 0: 29 | super(SynchronousDistributedTraining, self).checkpoint() 30 | 31 | class AsynchronousDistributedTraining(BasicTraining): 32 | """Distributes a given training process over a set of nodes, 33 | via GossipGraD distributed training. 34 | """ 35 | def __init__(self, *args, **kwargs): 36 | super(AsynchronousDistributedTraining, self).__init__(*args, **kwargs) 37 | self.gossip_step = 0 38 | self.world_size = distributed.get_world_size() 39 | self.rank = distributed.get_rank() 40 | self.groups = [] 41 | for idx in range(self.world_size - 1): 42 | partner = (self.rank + idx + 1) % self.world_size 43 | group = distributed.new_group(ranks=[self.rank, partner]) 44 | self.groups.append(group) 45 | 46 | def step(self, data, label): 47 | predictions = self.net(data) 48 | loss_val = self.loss(predictions, label) 49 | loss_val.backward() 50 | _gossip_grad(self.net, self.world_size, self.rank, 51 | self.groups, self.gossip_step) 52 | self.gossip_step += 1 53 | if self.gossip_step == self.world_size - 1: 54 | self.gossip_step = 0 55 | self.optimizer.step() 56 | self.training_loss = loss_val.item() 57 | 58 | def checkpoint(self): 59 | if self.rank == 0: 60 | super(AsynchronousDistributedTraining, self).checkpoint() 61 | 62 | def _average_gradients(net, world_size, group, cuda=False): 63 | for p in net.parameters(): 64 | tensor = p.grad.data.cpu() 65 | distributed.all_reduce(tensor, 66 | op=distributed.ReduceOp.SUM, 67 | group=group) 68 | tensor /= float(world_size) 69 | if cuda: 70 | p.grad.data = tensor.cuda() 71 | else: 72 | p.grad.data = tensor 73 | 74 | def _gossip_grad(net, world_size, rank, groups, step, cuda=False): 75 | group = groups[step] 76 | for p in net.parameters(): 77 | tensor = p.grad.data.cpu() 78 | distributed.all_reduce(tensor, 79 | op=distributed.ReduceOp.SUM, 80 | group=group) 81 | tensor /= 2.0 82 | if cuda: 83 | p.grad.data = tensor.cuda() 84 | else: 85 | p.grad.data = tensor 86 | --------------------------------------------------------------------------------
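A self-contained sketch of the averaging step _average_gradients implements, without a process group: summing per-replica gradients and dividing by the world size equals the gradient of the replica-averaged loss, which is what all_reduce with ReduceOp.SUM followed by the division computes in the distributed setting.

import torch

replicas = [torch.randn(3, requires_grad=True) for _ in range(2)]
for weight in replicas:
    (weight ** 2).sum().backward()  # per-replica loss and gradient
average = sum(weight.grad for weight in replicas) / len(replicas)
print(average)  # matches all_reduce(SUM) / world_size on each node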
/torchsupport/training/few_shot_gan.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as func 7 | from torch.distributions import Normal, RelaxedOneHotCategorical 8 | 9 | from torchsupport.training.state import ( 10 | NetState, NetNameListState, TrainingState 11 | ) 12 | from torchsupport.training.gan import RothGANTraining 13 | import torchsupport.modules.losses.vae as vl 14 | from torchsupport.data.io import netwrite, to_device, detach, make_differentiable 15 | from torchsupport.data.collate import DataLoader 16 | 17 | class FewShotGANTraining(RothGANTraining): 18 | def mixing_key(self, data): 19 | return data[1] 20 | 21 | def sample(self, data): 22 | the_generator = self.generator 23 | if isinstance(the_generator, nn.DataParallel): 24 | the_generator = the_generator.module 25 | return to_device(the_generator.sample(data), self.device) 26 | 27 | def divergence_loss(self, sample): 28 | _, encoder_parameters = sample 29 | result = vl.normal_kl_norm_loss(*encoder_parameters) 30 | return result 31 | 32 | def generator_step_loss(self, data, generated, sample): 33 | gan_loss = super().generator_step_loss(data, generated, sample) 34 | sample_divergence_loss = self.divergence_loss(sample) 35 | self.current_losses["kullback leibler"] = float(sample_divergence_loss) 36 | return gan_loss + sample_divergence_loss 37 | -------------------------------------------------------------------------------- /torchsupport/training/log/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/training/log/__init__.py -------------------------------------------------------------------------------- /torchsupport/training/log/log_types.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from matplotlib import pyplot as plt 3 | from tensorboardX import SummaryWriter 4 | import warnings 5 | 6 | class LogType: 7 | @property 8 | def data(self): 9 | return 0 10 | 11 | def log(self, logger, name, step): 12 | raise NotImplementedError("Abstract") 13 | 14 | class LogImage(LogType): 15 | def __init__(self, img): 16 | super().__init__() 17 | if torch.is_tensor(img): 18 | img = img.detach().cpu() 19 | if img.max() > 1.0 or img.min() < 0.0: 20 | img = (img - img.min()) / (img.max() - img.min()) 21 | self.img = img 22 | 23 | def log(self, logger, name, step): 24 | if self.img.dim() > 3: 25 | logger.log_image_batch(name, self.img, step) 26 | else: 27 | logger.log_image(name, self.img, step) 28 | 29 | class LogNumber(LogType): 30 | def __init__(self, number): 31 | super().__init__() 32 | if torch.is_tensor(number): 33 | number = float(number.detach().cpu()) 34 | self.number = number 35 | 36 | def log(self, logger, name, step): 37 | logger.log_number(name, self.number, step) 38 | 39 | class LogText(LogType): 40 | def __init__(self, text): 41 | super().__init__() 42 | self.text = text 43 | 44 | def log(self, logger, name, step): 45 | logger.log_text(name, self.text, step) 46 | 47 | class LogFigure(LogType): 48 | def __init__(self, figure): 49 | super().__init__() 50 | self.figure = figure 51 | 52 | def log(self, logger, name, step): 53 | logger.log_figure(name, self.figure, step) 54 | 55 | class LogEmbedding(LogType): 56 | def __init__(self, embedding): 57 | super().__init__() 
58 | if torch.is_tensor(embedding): 59 | embedding = embedding.detach().cpu() 60 | embedding = embedding.reshape(embedding.shape[0], -1) 61 | self.embedding = embedding 62 | 63 | def log(self, logger, name, step): 64 | logger.log_embedding(name, self.embedding, step) 65 | 66 | class Logger: 67 | def log_image(self, name, data, step): 68 | raise NotImplementedError("Abstract") 69 | 70 | def log_image_batch(self, name, data, step): 71 | raise NotImplementedError("Abstract") 72 | 73 | def log_number(self, name, data, step): 74 | raise NotImplementedError("Abstract") 75 | 76 | def log_text(self, name, data, step): 77 | raise NotImplementedError("Abstract") 78 | 79 | def log_figure(self, name, data, step): 80 | raise NotImplementedError("Abstract") 81 | 82 | def log_embedding(self, name, data, step): 83 | raise NotImplementedError("Abstract") 84 | 85 | def log(self, name, data, step): 86 | if isinstance(data, (float, int)): 87 | self.log_number(name, data, step) 88 | elif isinstance(data, str): 89 | self.log_text(name, data, step) 90 | elif isinstance(data, LogType): 91 | data.log(self, name, step) 92 | else: 93 | warnings.warn(f"{name} of type {type(data)} could not be logged.\n" 94 | f"Consider implementing a custom LogType.") 95 | -------------------------------------------------------------------------------- /torchsupport/training/log/tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | from torchsupport.training.log.log_types import Logger 3 | 4 | class TensorboardLogger(Logger): 5 | def __init__(self, path): 6 | self.writer = SummaryWriter(path) 7 | 8 | def log_image(self, name, data, step): 9 | self.writer.add_image(name, data, step) 10 | 11 | def log_image_batch(self, name, data, step): 12 | self.writer.add_images(name, data, step) 13 | 14 | def log_number(self, name, data, step): 15 | self.writer.add_scalar(name, data, step) 16 | 17 | def log_text(self, name, data, step): 18 | self.writer.add_text(name, data, step) 19 | 20 | def log_figure(self, name, data, step): 21 | self.writer.add_figure(name, data, step) 22 | 23 | def log_embedding(self, name, data, step): 24 | self.writer.add_embedding(data, global_step=step, tag=name) 25 | -------------------------------------------------------------------------------- /torchsupport/training/neural_conditioner.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as func 6 | from torch.distributions import Normal, RelaxedOneHotCategorical 7 | 8 | from tensorboardX import SummaryWriter 9 | 10 | from torchsupport.training.training import Training 11 | import torchsupport.modules.losses.vae as vl 12 | from torchsupport.data.io import netwrite, make_differentiable 13 | from torchsupport.data.collate import DataLoader 14 | 15 | from torchsupport.training.gan import GANTraining, RothGANTraining 16 | 17 | def _gradient_norm(inputs, parameters): 18 | out = torch.ones(inputs.size()).to(inputs.device) 19 | gradients = torch.autograd.grad( 20 | inputs, parameters, create_graph=True, retain_graph=True, 21 | grad_outputs=out 22 | ) 23 | grad_sum = 0.0 24 | for gradient in gradients: 25 | grad_sum += (gradient ** 2).view(gradient.size(0), -1).sum(dim=1) 26 | grad_sum = torch.sqrt(grad_sum + 1e-16) 27 | return grad_sum, out 28 | 29 | class NeuralConditionerTraining(RothGANTraining): 30 | def __init__(self, generator, discriminator, data, 
**kwargs): 31 | super(NeuralConditionerTraining, self).__init__( 32 | generator, 33 | discriminator, 34 | data, **kwargs 35 | ) 36 | 37 | def mixing_key(self, data): 38 | if len(data) == 4: 39 | return data[1] 40 | else: 41 | return data[0] 42 | 43 | def regularization(self, fake, real, generated_result, real_result): 44 | real_norm, real_out = _gradient_norm(real_result, self.mixing_key(real)) 45 | fake_norm, fake_out = _gradient_norm(generated_result, self.mixing_key(fake)) 46 | 47 | real_penalty = real_norm ** 2 48 | fake_penalty = fake_norm ** 2 49 | 50 | penalty = 0.5 * (real_penalty + fake_penalty).mean() 51 | 52 | out = (real_out, fake_out) 53 | 54 | return penalty, out 55 | 56 | def generator_loss(self, inputs, generated, available, requested): 57 | discriminator_result = self._run_discriminator_aux( 58 | inputs, generated, available, requested 59 | ) 60 | loss_val = func.binary_cross_entropy_with_logits( 61 | discriminator_result, 62 | torch.zeros_like(discriminator_result).to(self.device) 63 | ) 64 | 65 | return loss_val 66 | 67 | def restrict_inputs(self, data, mask): 68 | return data * mask 69 | 70 | def run_generator(self, data): 71 | sample = self.sample(data) 72 | inputs, available, requested = data 73 | restricted_inputs = self.restrict_inputs(inputs, available) 74 | generated = self.generator( 75 | sample, restricted_inputs, 76 | available, requested 77 | ) 78 | 79 | return inputs, generated, available, requested 80 | 81 | def _run_discriminator_aux(self, x, x_p, a, r): 82 | avail = self.restrict_inputs(x, a) 83 | reqst = self.restrict_inputs(x_p, r) 84 | result = self.discriminator(avail, reqst, a, r) 85 | return result 86 | 87 | def run_discriminator(self, data): 88 | with torch.no_grad(): 89 | fake = self.run_generator(data) 90 | make_differentiable(fake) 91 | make_differentiable(data) 92 | _, fake_batch, _, _ = fake 93 | inputs, available, requested = data 94 | fake_result = self._run_discriminator_aux( 95 | inputs, fake_batch, 96 | available, requested 97 | ) 98 | real_result = self._run_discriminator_aux( 99 | inputs, inputs, 100 | available, requested 101 | ) 102 | return fake, data, fake_result, real_result 103 | -------------------------------------------------------------------------------- /torchsupport/training/neural_processes.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as func 2 | from torchsupport.training.vae import VAETraining 3 | import torchsupport.modules.losses.vae as vl 4 | 5 | class NPTraining(VAETraining): 6 | def __init__(self, encoder, decoder, aggregator, data, 7 | rec_loss=func.binary_cross_entropy_with_logits, **kwargs): 8 | super(NPTraining, self).__init__(encoder, decoder, data, **kwargs) 9 | self.aggregator = aggregator 10 | self.rec_loss = rec_loss 11 | 12 | def loss(self, source_parameters, total_parameters, reconstruction, target): 13 | loss_val = self.rec_loss(reconstruction, target) 14 | kld = vl.normal_kl_loss(total_parameters, source_parameters) 15 | return loss_val + kld 16 | 17 | def run_networks(self, data): 18 | xs, ys, source_indices, target_indices = data 19 | representation = self.encoder(xs, ys) 20 | s_mean, s_logvar = self.aggregator(representation, source_indices) 21 | t_mean, t_logvar = self.aggregator(representation, source_indices + target_indices) 22 | target_access = target_indices.nonzero() 23 | target = ys[target_access] 24 | reconstruction = self.decoder(xs[target_indices.nonzero()]) 25 | return (s_mean, s_logvar), (t_mean, t_logvar), 
reconstruction, target 26 | -------------------------------------------------------------------------------- /torchsupport/training/score_supervised.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as func 8 | 9 | from torchsupport.data.io import netwrite, to_device, make_differentiable 10 | from torchsupport.training.energy import DenoisingScoreTraining 11 | from torchsupport.training.samplers import AnnealedLangevin 12 | 13 | class ScoreSupervisedTraining(DenoisingScoreTraining): 14 | def logit_energy(self, logits): 15 | return -logits.logsumexp(dim=-1) 16 | 17 | def create_score(self): 18 | def _score(data, sigma, *args): 19 | score, logits = self.score(data, sigma, *args) 20 | return score 21 | return _score 22 | 23 | def classifier_loss(self, logits, labels): 24 | return func.cross_entropy(logits, labels) 25 | 26 | def sample(self): 27 | self.score.eval() 28 | with torch.no_grad(): 29 | integrator = AnnealedLangevin([ 30 | self.sigma * self.factor ** idx for idx in range(self.n_sigma) 31 | ]) 32 | prep = to_device(self.prepare_sample(), self.device) 33 | data, *args = self.data_key(prep) 34 | result = integrator.integrate( 35 | self.create_score(), 36 | data, *args 37 | ).detach() 38 | self.score.train() 39 | return to_device((result, data, *args), self.device) 40 | 41 | def run_energy(self, data): 42 | data, labels = data 43 | data, *args = self.data_key(data) 44 | noisy, sigma = self.noise(data) 45 | score, logits = self.score(noisy, sigma, *args) 46 | 47 | return score, data, noisy, sigma, logits, labels 48 | 49 | def energy_loss(self, score, data, noisy, sigma, logits, labels): 50 | energy = super().energy_loss(score, data, noisy, sigma) 51 | classifier = self.classifier_loss(logits, labels) 52 | self.current_losses["classifier"] = float(classifier) 53 | return energy + classifier 54 | -------------------------------------------------------------------------------- /torchsupport/training/state.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class SaveStateError(Exception): 4 | pass 5 | 6 | class State: 7 | def __init__(self, name): 8 | self.name = name 9 | 10 | def read_action(self, training, data): 11 | setattr(training, self.name, data[self.name]) 12 | 13 | def write_action(self, training, data): 14 | data[self.name] = getattr(training, self.name) 15 | 16 | class PathState(State): 17 | def __init__(self, path): 18 | self.path = path[:-1] 19 | self.last = path[-1] 20 | self.name = "/".join(path) 21 | 22 | def walk(self, training): 23 | walk = training 24 | for step in self.path: 25 | walk = getattr(walk, step) 26 | return walk 27 | 28 | def read_action(self, training, data): 29 | setattr(self.walk(training), self.last, data[self.name]) 30 | 31 | def write_action(self, training, data): 32 | data[self.name] = getattr(self.walk(training), self.last) 33 | 34 | class NetState(State): 35 | def read_action(self, training, data): 36 | getattr(training, self.name).load_state_dict(data[self.name]) 37 | 38 | def write_action(self, training, data): 39 | network = getattr(training, self.name) 40 | if isinstance(network, torch.nn.Module): 41 | for param in network.parameters(): 42 | if torch.isnan(param).any(): 43 | raise SaveStateError("Encountered NaN weights!") 44 | data[self.name] = network.state_dict() 45 | 46 | class NetNameListState(NetState): 47 | def read_action(self, 
training, data): 48 | for key in data[self.name]: 49 | getattr(training, key).load_state_dict(data[self.name][key]) 50 | 51 | def write_action(self, training, data): 52 | net_dict = {} 53 | for key in getattr(training, self.name): 54 | network = getattr(training, key) 55 | if isinstance(network, torch.nn.Module): 56 | for param in network.parameters(): 57 | if torch.isnan(param).any(): 58 | raise SaveStateError("Encountered NaN weights!") 59 | net_dict[key] = network.state_dict() 60 | data[self.name] = net_dict 61 | 62 | class TrainingState(State): 63 | training_parameters = ["epoch_id", "step_id"] 64 | def __init__(self): 65 | super().__init__("training_state") 66 | 67 | def read_action(self, training, data): 68 | for name in self.training_parameters: 69 | setattr(training, name, data[name]) 70 | 71 | def write_action(self, training, data): 72 | for name in self.training_parameters: 73 | data[name] = getattr(training, name) 74 | -------------------------------------------------------------------------------- /torchsupport/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjendrusch/torchsupport/ac95f088da7e26e35d1e2ea4a22181f2f3682c8d/torchsupport/utils/__init__.py -------------------------------------------------------------------------------- /torchsupport/utils/memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gc 3 | import operator 4 | 5 | def memory_used(): 6 | result = {} 7 | for obj in gc.get_objects(): 8 | try: 9 | if torch.is_tensor(obj): 10 | if type(obj.data) not in result: 11 | result[type(obj.data)] = 0 12 | count = 1 13 | for elem in list(obj.data.size()): 14 | count *= elem 15 | result[type(obj.data)] += count * obj.data.element_size() 16 | elif hasattr(obj, 'data') and torch.is_tensor(obj.data): 17 | if type(obj.data) not in result: 18 | result[type(obj.data)] = 0 19 | count = 1 20 | for elem in list(obj.data.size()): 21 | count *= elem 22 | result[type(obj.data)] += count * obj.data.element_size() 23 | except Exception: 24 | print("could not track ...") 25 | 26 | return result 27 | --------------------------------------------------------------------------------
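A minimal usage sketch for memory_used above: the returned dictionary maps tensor types to total bytes currently alive, so it can be printed as a rough per-type memory report. Reported numbers depend on what the garbage collector tracks at the time of the call.

import torch
from torchsupport.utils.memory import memory_used

x = torch.randn(1024, 1024)  # roughly 4 MiB of float32 data
for tensor_type, size in memory_used().items():
    print(tensor_type.__name__, f"{size / 1024 ** 2:.2f} MiB")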