├── requirements.txt
├── tyche
│   ├── __init__.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── vmap.py
│   │   ├── cross_conv.py
│   │   └── init.py
│   ├── validation.py
│   └── model.py
├── CITATION.bib
├── NOTICE
├── README.md
└── LICENSE

/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.7.0
2 | pydantic==1.10.7
3 | torch==1.13.1
4 | validation==0.8.3
5 | 
--------------------------------------------------------------------------------
/tyche/__init__.py:
--------------------------------------------------------------------------------
1 | from . import model
2 | from .model import TysegXC, CrossOp, CrossOpTarget, tychets
--------------------------------------------------------------------------------
/tyche/nn/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Code from the UniverSeg authors.
3 | Please check: https://github.com/JJGO/UniverSeg/
4 | """
5 | from .cross_conv import CrossConv2d
6 | from .init import reset_conv2d_parameters
7 | from .vmap import vmap, Vmap
8 | 
--------------------------------------------------------------------------------
/CITATION.bib:
--------------------------------------------------------------------------------
1 | @inproceedings{rakic2024tyche,
2 | title={Tyche: Stochastic In-Context Learning for Medical Image Segmentation},
3 | author={Marianne Rakic and Hallee E. Wong and Jose Javier Gonzalez Ortiz and Beth Cimini and John V. Guttag and Adrian V. Dalca},
4 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
5 | year={2024},
6 | }
7 | 
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright [2024] [Tyche authors]
2 | 
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | 
7 | http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | 
--------------------------------------------------------------------------------
/tyche/validation.py:
--------------------------------------------------------------------------------
1 | """
2 | Coded by the authors of UniverSeg.
3 | Please check: https://github.com/JJGO/UniverSeg
4 | Module containing utility functions for validating arguments using Pydantic.
5 | 
6 | Functions:
7 | - as_2tuple(val: size2t) -> Tuple[int, int]: Convert integer or 2-tuple to 2-tuple format.
8 | - validate_arguments_init(class_) -> class_: Decorator to validate the arguments of the __init__ method using Pydantic.
9 | """
10 | 
11 | from typing import Any, Dict, Tuple, Union
12 | 
13 | from pydantic import validate_arguments
14 | 
15 | size2t = Union[int, Tuple[int, int]]
16 | Kwargs = Dict[str, Any]
17 | 
18 | 
19 | def as_2tuple(val: size2t) -> Tuple[int, int]:
20 |     """
21 |     Convert integer or 2-tuple to 2-tuple format.
22 | 
23 |     Args:
24 |         val (Union[int, Tuple[int, int]]): The value to convert.
25 | 
26 |     Returns:
27 |         Tuple[int, int]: The converted 2-tuple.
28 | 
29 |     Raises:
30 |         AssertionError: If val is not an integer or a 2-tuple with length 2.
31 |     """
32 |     if isinstance(val, int):
33 |         return (val, val)
34 |     assert isinstance(val, (list, tuple)) and len(val) == 2
35 |     return tuple(val)
36 | 
37 | 
38 | def validate_arguments_init(class_):
39 |     """
40 |     Decorator to validate the arguments of the __init__ method using Pydantic.
41 | 
42 |     Args:
43 |         class_ (Any): The class to decorate.
44 | 
45 |     Returns:
46 |         class_: The decorated class with a validated __init__ method.
47 |     """
48 |     class_.__init__ = validate_arguments(class_.__init__)
49 |     return class_
50 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tyche: Stochastic In-Context Learning for Medical Image Segmentation
2 | 
3 | 
4 | Official website for: [_Tyche: Stochastic In-Context Learning for Medical Image Segmentation_](http://arxiv.org/abs/2401.13650).
5 | [Marianne Rakic](https://mariannerakic.github.io/), [Hallee E. Wong](https://halleewong.github.io/), [Jose Javier Gonzalez Ortiz](https://josejg.com/),
6 | [Beth Cimini](https://www.broadinstitute.org/bios/beth-cimini), [John V. Guttag](https://people.csail.mit.edu/guttag/) & [Adrian V. Dalca](https://www.mit.edu/~adalca/)
7 | 
8 | 
9 | ## Abstract
10 | Existing learning-based solutions to medical image segmentation have two important shortcomings. First, a new model usually has to be trained or fine-tuned for each new segmentation task. This requires extensive resources and machine-learning expertise, and is therefore often infeasible for medical researchers and clinicians. Second, most existing segmentation methods produce a single deterministic segmentation mask for a given image. However, in practice, there is often considerable uncertainty about what constitutes the _correct_ segmentation, and different expert annotators will often segment the same image differently. We tackle both of these problems with _Tyche_, a model that uses a context set to generate stochastic predictions for previously unseen tasks without the need to retrain. Tyche differs from other in-context segmentation methods in two important ways.
11 | 
12 | 1. We introduce a novel convolution block architecture that enables interactions among predictions.
13 | 2. We introduce in-context test-time augmentation, a new mechanism to provide prediction stochasticity.
14 | 
15 | When combined with appropriate model design and loss functions, Tyche can predict a set of plausible, diverse segmentation candidates for new or unseen medical images and segmentation tasks without the need to retrain.
16 | 
17 | 
18 | [*Check out our website* :D](http://tyche.csail.mit.edu/) (demo included)
--------------------------------------------------------------------------------
/tyche/nn/vmap.py:
--------------------------------------------------------------------------------
1 | """
2 | Code from the UniverSeg authors.
3 | Please check https://github.com/JJGO/UniverSeg/
4 | """
5 | 
6 | from typing import Callable
7 | 
8 | import einops as E
9 | import torch
10 | from torch import nn
11 | 
12 | 
13 | def vmap(module: Callable, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
14 |     """
15 |     Applies the given module over the initial batch dimension and the second (group) dimension.
16 | 
17 |     Args:
18 |         module: a callable that is applied over the batch dimension and the second (group) dimension;
19 |             must support batch operations.
20 |         x: tensor of shape (batch_size, group_size, ...).
21 |         args: positional arguments to pass to `module`.
22 |         kwargs: keyword arguments to pass to `module`.
23 | 
24 |     Returns:
25 |         The output tensor, with the leading (batch_size, group_size) dimensions preserved.
26 |     """
27 |     batch_size, group_size, *_ = x.shape
28 |     grouped_input = E.rearrange(x, "B S ... -> (B S) ...")  # fold the group dim into the batch dim
29 |     grouped_output = module(grouped_input, *args, **kwargs)  # one batched call to the module
30 |     output = E.rearrange(
31 |         grouped_output, "(B S) ... -> B S ...", B=batch_size, S=group_size
32 |     )
33 |     return output
34 | 
35 | 
36 | def vmap_fn(fn: Callable) -> Callable:
37 |     """
38 |     Returns a callable that applies the input function over the initial batch dimension and the second (group) dimension.
39 | 
40 |     Args:
41 |         fn: function to apply over the batch dimension and the second (group) dimension.
42 | 
43 |     Returns:
44 |         A callable that applies the input function over the initial batch dimension and the second (group) dimension.
45 |     """
46 | 
47 |     def vmapped_fn(*args, **kwargs):
48 |         return vmap(fn, *args, **kwargs)
49 | 
50 |     return vmapped_fn
51 | 
52 | 
53 | class Vmap(nn.Module):
54 |     def __init__(self, module: nn.Module):
55 |         """
56 |         Applies the given module over the initial batch dimension and the second (group) dimension.
57 | 
58 |         Args:
59 |             module: module to apply over the batch dimension and the second (group) dimension.
60 |         """
61 |         super().__init__()
62 |         self.vmapped = module
63 | 
64 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
65 |         """
66 |         Applies the given module over the initial batch dimension and the second (group) dimension.
67 | 
68 |         Args:
69 |             x: tensor of shape (batch_size, group_size, ...).
70 | 
71 |         Returns:
72 |             The output tensor, with the leading (batch_size, group_size) dimensions preserved.
73 |         """
74 |         return vmap(self.vmapped, x)
75 | 
76 | 
77 | def vmap_cls(module_type: type) -> Callable:
78 |     """
79 |     Returns a callable that applies the input module type over the initial batch dimension and the second (group) dimension.
80 | 
81 |     Args:
82 |         module_type: module type to apply over the batch dimension and the second (group) dimension.
83 | 
84 |     Returns:
85 |         A callable that applies the input module type over the initial batch dimension and the second (group) dimension.
86 |     """
87 | 
88 |     def vmapped_cls(*args, **kwargs):
89 |         module = module_type(*args, **kwargs)
90 |         return Vmap(module)
91 | 
92 |     return vmapped_cls
93 | 
--------------------------------------------------------------------------------
/tyche/nn/cross_conv.py:
--------------------------------------------------------------------------------
1 | """
2 | Code from UniverSeg authors.
3 | Please check: https://github.com/JJGO/UniverSeg/
4 | """
5 | 
6 | from typing import Tuple, Union
7 | 
8 | import einops as E
9 | import torch
10 | import torch.nn as nn
11 | from pydantic import validate_arguments
12 | 
13 | size2t = Union[int, Tuple[int, int]]
14 | 
15 | 
16 | class CrossConv2d(nn.Conv2d):
17 |     """
18 |     Compute pairwise convolutions between all elements of x and all elements of y.
19 |     x and y are tensors of size (B, _, C, H, W), where _ may differ between x and y;
20 |     essentially, we take a meshgrid of the elements to get (B, Sx, Sy, C, H, W)
21 |     tensors, and then convolve pairwise.
22 |     Args:
23 |         x (tensor): B,Sx,Cx,H,W
24 |         y (tensor): B,Sy,Cy,H,W
25 |     Returns:
26 |         tensor: B,Sx,Sy,Cout,H,W
27 | 
28 | 
29 |     CrossConv2d is a convolutional layer that performs pairwise convolutions between elements of two input tensors.
30 | 
31 |     Parameters
32 |     ----------
33 |     in_channels : int or tuple of ints
34 |         Number of channels in the input tensor(s).
35 |         If the tensors have different numbers of channels, in_channels must be a tuple.
36 |     out_channels : int
37 |         Number of output channels.
38 |     kernel_size : int or tuple of ints
39 |         Size of the convolutional kernel.
40 |     stride : int or tuple of ints, optional
41 |         Stride of the convolution. Default is 1.
42 |     padding : int or tuple of ints, optional
43 |         Zero-padding added to both sides of the input. Default is 0.
44 |     dilation : int or tuple of ints, optional
45 |         Spacing between kernel elements. Default is 1.
46 |     groups : int, optional
47 |         Number of blocked connections from input channels to output channels. Default is 1.
48 |     bias : bool, optional
49 |         If True, adds a learnable bias to the output. Default is True.
50 |     padding_mode : str, optional
51 |         Padding mode. Default is "zeros".
52 |     device : str, optional
53 |         Device on which to allocate the tensor. Default is None.
54 |     dtype : torch.dtype, optional
55 |         Data type assigned to the tensor. Default is None.
56 | 
57 |     Returns
58 |     -------
59 |     torch.Tensor
60 |         Tensor resulting from the pairwise convolution between the elements of x and y.
61 | 
62 |     Notes
63 |     -----
64 |     x and y are tensors of size (B, Sx, Cx, H, W) and (B, Sy, Cy, H, W), respectively.
65 |     The function takes the Cartesian product of the elements of x and y to obtain a tensor
66 |     of size (B, Sx, Sy, Cx + Cy, H, W), and then performs the same convolution for all
67 |     (B, Sx, Sy) in the batch dimension. Runtime and memory are O(Sx * Sy).
68 | 
69 |     Examples
70 |     --------
71 |     >>> x = torch.randn(2, 3, 4, 32, 32)
72 |     >>> y = torch.randn(2, 5, 6, 32, 32)
73 |     >>> conv = CrossConv2d(in_channels=(4, 6), out_channels=7, kernel_size=3, padding=1)
74 |     >>> output = conv(x, y)
75 |     >>> output.shape  # torch.Size([2, 3, 5, 7, 32, 32])
76 |     """
77 | 
78 |     @validate_arguments
79 |     def __init__(
80 |         self,
81 |         in_channels: size2t,
82 |         out_channels: int,
83 |         kernel_size: size2t,
84 |         stride: size2t = 1,
85 |         padding: size2t = 0,
86 |         dilation: size2t = 1,
87 |         groups: int = 1,
88 |         bias: bool = True,
89 |         padding_mode: str = "zeros",
90 |         device=None,
91 |         dtype=None,
92 |     ) -> None:
93 | 
94 |         if isinstance(in_channels, (list, tuple)):
95 |             concat_channels = sum(in_channels)  # the conv sees x- and y-channels concatenated
96 |         else:
97 |             concat_channels = 2 * in_channels  # same channel count for x and y
98 | 
99 |         super().__init__(
100 |             in_channels=concat_channels,
101 |             out_channels=out_channels,
102 |             kernel_size=kernel_size,
103 |             stride=stride,
104 |             padding=padding,
105 |             dilation=dilation,
106 |             groups=groups,
107 |             bias=bias,
108 |             padding_mode=padding_mode,
109 |             device=device,
110 |             dtype=dtype,
111 |         )
112 | 
113 |     def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
114 |         """
115 |         Compute pairwise convolution between all elements of x and all elements of y.
116 | 
117 |         Parameters
118 |         ----------
119 |         x : torch.Tensor
120 |             Input tensor of size (B, Sx, Cx, H, W).
121 |         y : torch.Tensor
122 |             Input tensor of size (B, Sy, Cy, H, W).
123 | 
124 |         Returns
125 |         -------
126 |         torch.Tensor
127 |             Tensor resulting from the cross-convolution between the elements of x and y.
128 |             Has size (B, Sx, Sy, Co, H, W), where Co is the number of output channels.
129 | """ 130 | B, Sx, *_ = x.shape 131 | _, Sy, *_ = y.shape 132 | 133 | xs = E.repeat(x, "B Sx Cx H W -> B Sx Sy Cx H W", Sy=Sy) 134 | ys = E.repeat(y, "B Sy Cy H W -> B Sx Sy Cy H W", Sx=Sx) 135 | 136 | xy = torch.cat([xs, ys], dim=3,) 137 | 138 | batched_xy = E.rearrange(xy, "B Sx Sy C2 H W -> (B Sx Sy) C2 H W") 139 | 140 | batched_output = super().forward(batched_xy) 141 | 142 | output = E.rearrange( 143 | batched_output, "(B Sx Sy) Co H W -> B Sx Sy Co H W", B=B, Sx=Sx, Sy=Sy 144 | ) 145 | return output 146 | -------------------------------------------------------------------------------- /tyche/nn/init.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code from the UniverSeg authors 3 | Please check: https://github.com/JJGO/UniverSeg/ 4 | """ 5 | 6 | import warnings 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import init 12 | 13 | 14 | def initialize_weight( 15 | weight: torch.Tensor, 16 | distribution: Optional[str], 17 | nonlinearity: Optional[str] = "LeakyReLU", 18 | ) -> None: 19 | """Initialize the weight tensor with a chosen distribution and nonlinearity. 20 | 21 | Args: 22 | weight (torch.Tensor): The weight tensor to initialize. 23 | distribution (Optional[str]): The distribution to use for initialization. Can be one of "zeros", 24 | "kaiming_normal", "kaiming_uniform", "kaiming_normal_fanout", "kaiming_uniform_fanout", 25 | "glorot_normal", "glorot_uniform", or "orthogonal". 26 | nonlinearity (Optional[str]): The type of nonlinearity to use. Can be one of "LeakyReLU", "Sine", 27 | "Tanh", "Silu", or "Gelu". 28 | 29 | Returns: 30 | None 31 | """ 32 | 33 | if distribution is None: 34 | return 35 | 36 | if nonlinearity: 37 | nonlinearity = nonlinearity.lower() 38 | if nonlinearity == "leakyrelu": 39 | nonlinearity = "leaky_relu" 40 | 41 | if nonlinearity == "sine": 42 | warnings.warn("sine gain not implemented, defaulting to tanh") 43 | nonlinearity = "tanh" 44 | 45 | if nonlinearity is None: 46 | nonlinearity = "linear" 47 | 48 | if nonlinearity in ("silu", "gelu"): 49 | nonlinearity = "leaky_relu" 50 | 51 | gain = 1 if nonlinearity is None else init.calculate_gain(nonlinearity) 52 | 53 | if distribution == "zeros": 54 | init.zeros_(weight) 55 | elif distribution == "kaiming_normal": 56 | init.kaiming_normal_(weight, nonlinearity=nonlinearity) 57 | elif distribution == "kaiming_uniform": 58 | init.kaiming_uniform_(weight, nonlinearity=nonlinearity) 59 | elif distribution == "kaiming_normal_fanout": 60 | init.kaiming_normal_(weight, nonlinearity=nonlinearity, mode="fan_out") 61 | elif distribution == "kaiming_uniform_fanout": 62 | init.kaiming_uniform_(weight, nonlinearity=nonlinearity, mode="fan_out") 63 | elif distribution == "glorot_normal": 64 | init.xavier_normal_(weight, gain=gain) 65 | elif distribution == "glorot_uniform": 66 | init.xavier_uniform_(weight, gain) 67 | elif distribution == "orthogonal": 68 | init.orthogonal_(weight, gain) 69 | else: 70 | raise ValueError(f"Unsupported distribution '{distribution}'") 71 | 72 | 73 | def initialize_bias( 74 | bias: torch.Tensor, 75 | distribution: Optional[float] = 0, 76 | nonlinearity: Optional[str] = "LeakyReLU", 77 | weight: Optional[torch.Tensor] = None, 78 | ) -> None: 79 | """Initialize the bias tensor with a constant or a chosen distribution and nonlinearity. 80 | 81 | Args: 82 | bias (torch.Tensor): The bias tensor to initialize. 83 | distribution (Optional[float]): The constant value to initialize the bias to. 
84 | nonlinearity (Optional[str]): The type of nonlinearity to use when initializing the bias. 85 | weight (Optional[torch.Tensor]): The weight tensor to use when initializing the bias. 86 | 87 | Returns: 88 | None 89 | """ 90 | 91 | if distribution is None: 92 | return 93 | 94 | if isinstance(distribution, (int, float)): 95 | init.constant_(bias, distribution) 96 | else: 97 | raise NotImplementedError(f"Unsupported distribution '{distribution}'") 98 | 99 | 100 | def initialize_layer( 101 | layer: nn.Module, 102 | distribution: Optional[str] = "kaiming_normal", 103 | init_bias: Optional[float] = 0, 104 | nonlinearity: Optional[str] = "LeakyReLU", 105 | ) -> None: 106 | """Initialize the weight and bias tensors of a linear or convolutional layer. 107 | 108 | Args: 109 | layer (nn.Module): The layer to initialize. 110 | distribution (Optional[str]): The distribution to use for weight initialization. 111 | init_bias (Optional[float]): The value to use for bias initialization. 112 | nonlinearity (Optional[str]): The type of nonlinearity to use when initializing the layer. 113 | 114 | Returns: 115 | None 116 | """ 117 | 118 | assert isinstance( 119 | layer, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d) 120 | ), f"Can only be applied to linear and conv layers, given {layer.__class__.__name__}" 121 | 122 | initialize_weight(layer.weight, distribution, nonlinearity) 123 | if layer.bias is not None: 124 | initialize_bias( 125 | layer.bias, init_bias, nonlinearity=nonlinearity, weight=layer.weight 126 | ) 127 | 128 | 129 | def reset_conv2d_parameters( 130 | model: nn.Module, 131 | init_distribution: Optional[str], 132 | init_bias: Optional[float], 133 | nonlinearity: Optional[str], 134 | ) -> None: 135 | """Reset the parameters of all convolutional layers in the model. 136 | 137 | Args: 138 | model (nn.Module): The model to reset the convolutional layers of. 139 | init_distribution (Optional[str]): The distribution to use for weight initialization. 140 | init_bias (Optional[float]): The value to use for bias initialization. 141 | nonlinearity (Optional[str]): The type of nonlinearity to use when initializing the layers. 142 | 143 | Returns: 144 | None 145 | """ 146 | 147 | for name, module in model.named_modules(): 148 | if isinstance(module, nn.Conv2d): 149 | initialize_layer( 150 | module, 151 | distribution=init_distribution, 152 | init_bias=init_bias, 153 | nonlinearity=nonlinearity, 154 | ) 155 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/tyche/model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | 
3 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union
4 | 
5 | import einops as E
6 | import torch
7 | from torch import nn
8 | 
9 | from .nn.cross_conv import CrossConv2d
10 | from .nn.init import reset_conv2d_parameters
11 | from .nn.vmap import Vmap, vmap
12 | from .validation import (
13 |     Kwargs,
14 |     as_2tuple,
15 |     size2t,
16 |     validate_arguments,
17 |     validate_arguments_init,
18 | )
19 | 
20 | 
21 | def get_nonlinearity(nonlinearity: Optional[str]) -> nn.Module:
22 |     if nonlinearity is None:
23 |         return nn.Identity()
24 |     if nonlinearity == "Softmax":
25 |         # For Softmax, we need to specify the channel dimension
26 |         return nn.Softmax(dim=1)
27 |     if hasattr(nn, nonlinearity):
28 |         return getattr(nn, nonlinearity)()
29 |     raise ValueError(f"nonlinearity {nonlinearity} not found")
30 | 
31 | 
32 | @validate_arguments_init
33 | @dataclass(eq=False, repr=False)
34 | class ConvOp(nn.Sequential):
35 | 
36 |     in_channels: int
37 |     out_channels: int
38 |     kernel_size: size2t = 3
39 |     nonlinearity: Optional[str] = "LeakyReLU"
40 |     init_distribution: Optional[str] = "kaiming_normal"
41 |     init_bias: Union[None, float, int] = 0.0
42 | 
43 |     def __post_init__(self):
44 |         super().__init__()
45 |         self.conv = nn.Conv2d(
46 |             self.in_channels,
47 |             self.out_channels,
48 |             kernel_size=self.kernel_size,
49 |             padding=self.kernel_size // 2,
50 |             padding_mode="zeros",
51 |             bias=True,
52 |         )
53 | 
54 |         if self.nonlinearity is not None:
55 |             self.nonlin = get_nonlinearity(self.nonlinearity)
56 | 
57 |         reset_conv2d_parameters(
58 |             self, self.init_distribution, self.init_bias, self.nonlinearity
59 |         )
60 | 
61 | 
62 | @validate_arguments_init
63 | @dataclass(eq=False, repr=False)
64 | class CrossOp(nn.Module):
65 | 
66 |     in_channels: size2t
67 |     out_channels: int
68 |     kernel_size: size2t = 3
69 |     nonlinearity: Optional[str] = "LeakyReLU"
70 |     init_distribution: Optional[str] = "kaiming_normal"
71 |     init_bias: Union[None, float, int] = 0.0
72 | 
73 |     def __post_init__(self):
74 |         super().__init__()
75 | 
76 |         self.cross_conv = CrossConv2d(
77 |             in_channels=as_2tuple(self.in_channels),
78 |             out_channels=self.out_channels,
79 |             kernel_size=self.kernel_size,
80 |             padding=self.kernel_size // 2,
81 |         )
82 | 
83 |         if self.nonlinearity is not None:
84 |             self.nonlin = get_nonlinearity(self.nonlinearity)
85 | 
86 | 
87 |         reset_conv2d_parameters(
88 |             self, self.init_distribution, self.init_bias, self.nonlinearity
89 |         )
90 | 
91 |     def forward(self, target, support):
92 |         interaction = self.cross_conv(target, support).squeeze(dim=1)  # (B, 1, Sy, C, H, W) -> (B, Sy, C, H, W)
93 | 
94 |         if self.nonlinearity is not None:
95 |             interaction = vmap(self.nonlin, interaction)
96 | 
97 |         new_target = interaction.mean(dim=1, keepdims=True)  # pool the interaction over the support set
98 | 
99 |         return new_target, interaction
100 | 
101 | 
102 | @validate_arguments_init
103 | @dataclass(eq=False, repr=False)
104 | class CrossOpTarget(nn.Module):
105 | 
106 |     in_channels: size2t
107 |     out_channels: int
108 |     kernel_size: size2t = 3
109 |     nonlinearity: Optional[str] = "LeakyReLU"
110 |     init_distribution: Optional[str] = "kaiming_normal"
111 |     init_bias: Union[None, float, int] = 0.0
112 | 
113 |     def __post_init__(self):
114 |         super().__init__()
115 | 
116 |         self.cross_conv = CrossConv2d(
117 |             in_channels=as_2tuple(self.in_channels),
118 |             out_channels=self.out_channels,
119 |             kernel_size=self.kernel_size,
120 |             padding=self.kernel_size // 2,
121 |         )
122 | 
123 |         if self.nonlinearity is not None:
124 |             self.nonlin = get_nonlinearity(self.nonlinearity)
125 | 
126 | 
127 |         reset_conv2d_parameters(
128 |             self, self.init_distribution, self.init_bias, self.nonlinearity
129 |         )
130 | 
131 |     def forward(self, target, support):
132 |         interaction = self.cross_conv(target, support).squeeze(dim=1)
133 | 
134 |         if self.nonlinearity is not None:
135 |             interaction = vmap(self.nonlin, interaction)
136 | 
137 |         return interaction
138 | 
139 | 
140 | class Residual(nn.Module):
141 |     @validate_arguments
142 |     def __init__(
143 |         self, module, in_channels: int, out_channels: int,
144 |     ):
145 |         super().__init__()
146 |         self.main = module
147 |         self.in_channels = in_channels
148 |         self.out_channels = out_channels
149 |         if in_channels == out_channels:
150 |             self.shortcut = nn.Identity()
151 |         else:
152 |             # TODO do we want to init these to 1, like controlnet's zeroconv
153 |             # TODO do we want to initialize these like the other conv layers
154 |             self.shortcut = nn.Conv2d(
155 |                 in_channels, out_channels, kernel_size=1, bias=False
156 |             )
157 |             reset_conv2d_parameters(self.shortcut, "kaiming_normal", 0.0, nonlinearity=None)
158 | 
159 |     def forward(self, input):
160 |         return self.main(input) + self.shortcut(input)
161 | 
162 | 
163 | class VResidual(Residual):
164 |     def forward(self, input):
165 |         return self.main(input) + vmap(self.shortcut, input)
166 | 
167 | 
168 | @validate_arguments_init
169 | @dataclass(eq=False, repr=False)
170 | class CrossBlockTarget(nn.Module):
171 | 
172 |     in_channels: size2t
173 |     cross_features: int
174 |     conv_features: Optional[int] = None
175 |     cross_kws: Optional[Dict[str, Any]] = None
176 |     conv_kws: Optional[Dict[str, Any]] = None
177 | 
178 |     def __post_init__(self):
179 |         super().__init__()
180 | 
181 |         conv_features = self.conv_features or self.cross_features
182 |         cross_kws = self.cross_kws or {}
183 |         conv_kws = self.conv_kws or {}
184 | 
185 |         self.cross = CrossOp(self.in_channels, self.cross_features, **cross_kws)
186 | 
187 |         if isinstance(self.in_channels, tuple):
188 |             mean_convs = self.in_channels[0] + self.cross_features  # target channels + cross features after concatenation
189 |         else:
190 |             mean_convs = self.in_channels + self.cross_features
191 | 
192 | 
193 |         self.meanconv = Vmap(ConvOp(mean_convs, conv_features, **conv_kws))
194 | 
195 |         self.target = Vmap(ConvOp(self.cross_features, conv_features, **conv_kws))
196 |         self.support = Vmap(ConvOp(self.cross_features, conv_features, **conv_kws))
197 | 
198 |     def forward(self, target, support):
199 |         mean_img = target.mean(dim=1, keepdims=True)  # average over the K candidate predictions
200 | 
201 |         mean_img, support = self.cross(mean_img, support)
202 | 
203 |         K = target.shape[1]
204 | 
205 |         mean_img = E.repeat(mean_img, 'B 1 C H W -> B K C H W', K=K)
206 | 
207 |         target = torch.cat([target, mean_img], dim=2)
208 |         target = self.meanconv(target)
209 | 
210 |         target = self.target(target)
211 |         support = self.support(support)
212 |         return target, support
213 | 
214 | 
215 | @validate_arguments_init
216 | @dataclass(eq=False, repr=False)
217 | class TysegXC(nn.Module):
218 | 
219 |     encoder_blocks: List[size2t]
220 |     decoder_blocks: Optional[List[size2t]] = None
221 |     cross_relu: bool = True
222 | 
223 |     def __post_init__(self):
224 |         super().__init__()
225 | 
226 |         self.downsample = nn.MaxPool2d(2, 2)
227 |         self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
228 | 
229 |         self.enc_blocks = nn.ModuleList()
230 |         self.dec_blocks = nn.ModuleList()
231 | 
232 |         encoder_blocks = list(map(as_2tuple, self.encoder_blocks))
233 |         decoder_blocks = self.decoder_blocks or encoder_blocks[-2::-1]
234 |         decoder_blocks = list(map(as_2tuple, decoder_blocks))
235 | 
236 |         if self.cross_relu:
237 |             block_kws = dict(cross_kws=dict(nonlinearity="LeakyReLU"))
238 |         else:
239 |             block_kws = dict(cross_kws=dict(nonlinearity=None))
240 | 
241 |         in_ch = (2, 2)
242 |         out_channels = 1  # 3 outputs to compare to SAM
243 |         out_activation = None
244 | 
245 |         # Encoder
246 |         skip_outputs = []
247 |         for (cross_ch, conv_ch) in encoder_blocks:
248 |             block = CrossBlockTarget(in_ch, cross_ch, conv_ch, **block_kws)
249 |             in_ch = conv_ch
250 |             self.enc_blocks.append(block)
251 |             skip_outputs.append(in_ch)
252 | 
253 |         # Decoder
254 |         skip_chs = skip_outputs[-2::-1]
255 |         for (cross_ch, conv_ch), skip_ch in zip(decoder_blocks, skip_chs):
256 |             block = CrossBlockTarget(in_ch + skip_ch, cross_ch, conv_ch, **block_kws)
257 |             in_ch = conv_ch
258 |             self.dec_blocks.append(block)
259 | 
260 |         self.out_conv = ConvOp(
261 |             in_ch, out_channels, kernel_size=1, nonlinearity=out_activation,
262 |         )
263 | 
264 |     def get_model_inputs(self, x, sx, sy, target_size, aug=False):
265 |         """
266 |         Gather all the inputs for the model and put them in the right shape.
267 |         """
268 |         x = self.format_target(x, target_size)
269 |         noise = torch.randn_like(x)  # fresh noise per call drives the stochastic predictions
270 | 
271 |         sx, sy = [self.format_support(i) for i in [sx, sy]]
272 | 
273 |         return {'support_images': sx, 'support_labels': sy,
274 |                 'target_image': x, 'noise_image': noise}
275 | 
276 | 
277 |     def format_target(self, x, target_size):
278 |         """
279 |         This is meant exclusively for an inference setting with batch size of 1.
280 |         The target should have shape (1, 1, H, W) or (1, H, W).
281 |         For model input, it needs to have shape (B, K, 1, H, W) with B=1.
282 |         """
283 |         x_shape = x.shape
284 | 
285 |         if x_shape[0] > 1:
286 |             if len(x_shape) == 5:  # B 1 1 H W
287 |                 assert x_shape[2] == 1, x_shape
288 |                 assert x_shape[1] == 1, x_shape
289 |                 x = E.repeat(x, 'B 1 1 H W -> B K 1 H W', K=target_size)
290 |                 return x
291 | 
292 |             elif len(x_shape) == 4:  # B 1 H W
293 |                 x = E.repeat(x, 'B 1 H W -> B K H W', K=target_size)
294 |                 return x[:, :, None, :, :]
295 | 
296 |             elif len(x_shape) == 3:  # B H W
297 |                 x = E.rearrange(x, "B H W -> B 1 1 H W")
298 |                 x = E.repeat(x, 'B 1 1 H W -> B K 1 H W', K=target_size)
299 |                 return x
300 | 
301 |             else:
302 |                 raise ValueError(f'Unsupported target shape {tuple(x_shape)}.')
303 | 
304 | 
305 |         assert 3 <= len(x_shape) <= 4, 'Input should have shape (1, 1, H, W) or (1, H, W).'
306 | 
307 |         if len(x_shape) == 3:
308 |             x = x[None]
309 | 
310 |         x = E.repeat(x, '1 1 H W -> K 1 H W', K=target_size)
311 |         x = E.rearrange(x, "K 1 H W -> 1 K 1 H W")
312 |         return x
313 | 
314 |     def format_support(self, sx):
315 |         """
316 |         This is meant exclusively for an inference setting with batch size of 1.
317 |         The support should have shape (S, 1, H, W) or (1, S, 1, H, W).
318 |         For model input, it needs to have shape (B, S, 1, H, W) with B=1.
319 |         """
320 |         sx_shape = sx.shape
321 |         if len(sx_shape) == 4:
322 |             assert sx_shape[1] == 1, 'Support should have shape (1, S, 1, H, W) or (S, 1, H, W).'
323 |             sx = sx[None]
324 | 
325 |         elif len(sx_shape) == 5:
326 |             assert sx_shape[2] == 1, 'Support should have shape (S, 1, H, W) or (1, S, 1, H, W).'
327 | 
328 |         return sx
329 | 
330 |     def format_pred(self, yhat):
331 |         assert len(yhat.shape) == 5
332 |         assert yhat.shape[2] == 1, 'Prediction should have shape (1, K, 1, H, W).'
333 |         return yhat[:, :, 0]
334 | 
335 |     def format_label(self, y, target_size):
336 |         '''
337 |         Expand a label of shape (1, 1, H, W) or (1, H, W) to (1, K, H, W).
338 |         '''
339 |         y_shape = y.shape
340 |         assert y_shape[0] == 1, 'Input should have batch size and channel of 1.'
341 |         assert 3 <= len(y_shape) <= 4, 'Input should have shape (1, 1, H, W) or (1, H, W).'
342 | 
343 |         if len(y_shape) == 3:
344 |             y = y[None]
345 | 
346 |         y = E.repeat(y, '1 1 H W -> 1 K H W', K=target_size)
347 | 
348 |         return y
349 | 
350 |     def pred_ged_stats(self, m_inputs, sigmoid=True):
351 |         model_inputs = self.get_model_inputs(**m_inputs)
352 |         yhat_tmp = self.forward(**model_inputs)
353 |         yhat_tmp = self.format_pred(yhat_tmp)
354 | 
355 |         if sigmoid:
356 |             yhat_tmp = torch.sigmoid(yhat_tmp)
357 | 
358 |         return yhat_tmp
359 | 
360 |     def forward(self, support_images, support_labels, target_image, noise_image):
361 | 
362 |         # The noise channel is what differentiates the K candidate predictions.
363 |         target = torch.cat([target_image, noise_image], dim=2)
364 |         support = torch.cat([support_images, support_labels], dim=2)
365 | 
366 |         B, K, _, _, _ = target.shape
367 | 
368 |         pass_through = []
369 | 
370 |         for i, encoder_block in enumerate(self.enc_blocks):
371 |             target, support = encoder_block(target, support)
372 |             if i == len(self.encoder_blocks) - 1:
373 |                 break
374 |             pass_through.append((target, support))
375 |             target = vmap(self.downsample, target)
376 |             support = vmap(self.downsample, support)
377 | 
378 |         for decoder_block in self.dec_blocks:
379 |             target_skip, support_skip = pass_through.pop()
380 |             target = torch.cat([vmap(self.upsample, target), target_skip], dim=2)
381 |             support = torch.cat([vmap(self.upsample, support), support_skip], dim=2)
382 |             target, support = decoder_block(target, support)
383 | 
384 |         target = E.rearrange(target, "B K C H W -> (B K) C H W")
385 |         target = self.out_conv(target)
386 | 
387 |         target = E.rearrange(target, "(B K) C H W -> B K C H W", B=B, K=K)
388 | 
389 |         return target
390 | 
391 | 
392 | @validate_arguments
393 | def tychets(version: Literal["v1"] = "v1", pretrained: bool = False) -> nn.Module:
394 |     weights = {
395 |         "v1": "https://github.com/mariannerakic/Tyche/releases/download/weights/tyche_v1_model_weights_CVPR.pt"
396 |     }
397 |     if version == "v1":
398 |         model = TysegXC(encoder_blocks=[64, 64, 64, 64])
399 | 
400 |     if pretrained:
401 |         state_dict = torch.hub.load_state_dict_from_url(weights[version])
402 |         model.load_state_dict(state_dict['model'])
403 | 
404 |     return model
405 | 
--------------------------------------------------------------------------------
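
A minimal inference sketch using the API above (`tychets`, `TysegXC.get_model_inputs`, and `TysegXC.format_pred`). The random context set, the 128x128 resolution, and K=8 candidates are illustrative placeholders, not data from the repository; `pretrained=True` assumes the release URL hard-coded in `tychets` is reachable.

```python
import torch

from tyche import tychets

# Load the "v1" architecture; pretrained=True fetches the released CVPR weights.
model = tychets(version="v1", pretrained=True)
model.eval()

# Illustrative inputs: one grayscale target image plus a context (support) set
# of 16 image/label pairs, following the shapes documented in format_target
# and format_support: target (1, 1, H, W), support (1, S, 1, H, W).
target = torch.randn(1, 1, 128, 128)
support_images = torch.randn(1, 16, 1, 128, 128)
support_labels = (torch.rand(1, 16, 1, 128, 128) > 0.5).float()

with torch.no_grad():
    # get_model_inputs replicates the target K=target_size times and draws a
    # fresh noise tensor, which is what makes the K candidates differ.
    inputs = model.get_model_inputs(
        x=target, sx=support_images, sy=support_labels, target_size=8
    )
    logits = model(**inputs)                          # (1, K, 1, H, W)
    probs = torch.sigmoid(model.format_pred(logits))  # (1, K, H, W)
```

Each of the K slices of `probs` is one plausible segmentation candidate; rerunning `get_model_inputs` draws new noise and yields a different set.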