├── .gitignore ├── LICENSE ├── README.md ├── example_usage.ipynb ├── featup ├── __init__.py ├── adaptive_conv_cuda │ ├── __init__.py │ ├── adaptive_conv.cpp │ ├── adaptive_conv.py │ ├── adaptive_conv_cuda.cpp │ └── adaptive_conv_kernel.cu ├── configs │ ├── implicit_upsampler.yaml │ ├── jbu_upsampler.yaml │ └── train_probe.yaml ├── datasets │ ├── COCO.py │ ├── DAVIS.py │ ├── EmbeddingFile.py │ ├── HighResEmbs.py │ ├── ImageNetSubset.py │ ├── JitteredImage.py │ ├── SampleImage.py │ ├── __init__.py │ └── util.py ├── downsamplers.py ├── featurizers │ ├── CLIP.py │ ├── DINO.py │ ├── DINOv2.py │ ├── DeepLabV3.py │ ├── MAE.py │ ├── MIDAS.py │ ├── MaskCLIP.py │ ├── ResNet.py │ ├── __init__.py │ ├── dinov2 │ │ ├── __init__.py │ │ └── layers │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── block.py │ │ │ ├── dino_head.py │ │ │ ├── drop_path.py │ │ │ ├── layer_scale.py │ │ │ ├── mlp.py │ │ │ ├── patch_embed.py │ │ │ └── swiglu_ffn.py │ ├── maskclip │ │ ├── README.md │ │ ├── __init__.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── clip.py │ │ ├── interpolate.py │ │ ├── model.py │ │ └── simple_tokenizer.py │ ├── modules │ │ ├── __init__.py │ │ ├── layers.py │ │ ├── resnet.py │ │ └── vgg.py │ └── util.py ├── layers.py ├── losses.py ├── plotting.py ├── train_implicit_upsampler.py ├── train_jbu_upsampler.py ├── train_probes.py ├── upsamplers.py └── util.py ├── gradio_app.py ├── hubconf.py ├── manifest.in ├── sample-images ├── bird_full.jpg ├── bird_left.jpg ├── bird_right.jpg ├── cones2.jpg ├── cones3.jpg ├── dog_man_1_crop.jpg ├── plant.png ├── skate.jpg └── teaser_wide.png ├── setup.py └── simple_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Mark Hamilton. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a 6 | copy of this software and associated documentation files (the 7 | "Software"), to deal in the Software without restriction, including 8 | without limitation the rights to use, copy, modify, merge, publish, 9 | distribute, sublicense, and/or sell copies of the Software, and to 10 | permit persons to whom the Software is furnished to do so, subject to 11 | the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included 14 | in all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 17 | OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 19 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 20 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 22 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FeatUp: A Model-Agnostic Framework for Features at Any Resolution 2 | ### ICLR 2024 3 | 4 | 5 | [![Website](https://img.shields.io/badge/FeatUp-%F0%9F%8C%90Website-purple?style=flat)](https://aka.ms/featup) [![arXiv](https://img.shields.io/badge/arXiv-2403.10516-b31b1b.svg)](https://arxiv.org/abs/2403.10516) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mhamilton723/FeatUp/blob/main/example_usage.ipynb) 6 | [![Huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-FeatUp-orange)](https://huggingface.co/spaces/mhamilton723/FeatUp) 7 | [![Huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Paper%20Page-orange)](https://huggingface.co/papers/2403.10516) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/featup-a-model-agnostic-framework-for/feature-upsampling-on-imagenet)](https://paperswithcode.com/sota/feature-upsampling-on-imagenet?p=featup-a-model-agnostic-framework-for) 9 | 10 | 11 | 12 | [Stephanie Fu*](https://stephanie-fu.github.io/), 13 | [Mark Hamilton*](https://mhamilton.net/), 14 | [Laura Brandt](https://people.csail.mit.edu/lebrandt/), 15 | [Axel Feldman](https://feldmann.nyc/), 16 | [Zhoutong Zhang](https://ztzhang.info/), 17 | [William T. Freeman](https://billf.mit.edu/about/bio) 18 | *Equal Contribution. 19 | 20 | ![FeatUp Overview Graphic](https://mhamilton.net/images/website_hero_small-p-1080.jpg) 21 | 22 | *TL;DR*:FeatUp improves the spatial resolution of any model's features by 16-32x without changing their semantics. 23 | 24 | https://github.com/mhamilton723/FeatUp/assets/6456637/8fb5aa7f-4514-4a97-aebf-76065163cdfd 25 | 26 | 27 | ## Contents 28 | 29 | * [Install](#install) 30 | * [Using Pretrained Upsamplers](#using-pretrained-upsamplers) 31 | * [Fitting an Implicit Upsampler](#fitting-an-implicit-upsampler-to-an-image) 32 | * [Coming Soon](coming-soon) 33 | * [Citation](#citation) 34 | * [Contact](#contact) 35 | 36 | 37 | ## Install 38 | 39 | ### Pip 40 | For those just looking to quickly use the FeatUp APIs install via: 41 | ```shell script 42 | pip install git+https://github.com/mhamilton723/FeatUp 43 | ``` 44 | 45 | ### Local Development 46 | To install FeatUp for local development and to get access to the sample images install using the following: 47 | ```shell script 48 | git clone https://github.com/mhamilton723/FeatUp.git 49 | cd FeatUp 50 | pip install -e . 51 | ``` 52 | 53 | ## Using Pretrained Upsamplers 54 | 55 | To see examples of pretrained model usage please see our [Collab notebook](https://colab.research.google.com/github/mhamilton723/FeatUp/blob/main/example_usage.ipynb). We currently supply the following pretrained versions of FeatUp's JBU upsampler: 56 | 57 | | Model Name | Checkpoint | Checkpoint (No LayerNorm) | Torch Hub Repository | Torch Hub Name | 58 | |------------|----------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------|----------------------|----------------| 59 | | DINO | [Download](https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/pretrained/dino16_jbu_stack_cocostuff.ckpt) | [Download](https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/pretrained/no_norm/dino16_jbu_stack_cocostuff.ckpt) | mhamilton723/FeatUp | dino16 | 60 | | DINO v2 | [Download](https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/pretrained/dinov2_jbu_stack_cocostuff.ckpt) | [Download](https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/pretrained/no_norm/dinov2_jbu_stack_cocostuff.ckpt) | mhamilton723/FeatUp | dinov2 | 61 | | CLIP | [Download](https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/pretrained/clip_jbu_stack_cocostuff.ckpt) | [Download](https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/pretrained/no_norm/clip_jbu_stack_cocostuff.ckpt) | mhamilton723/FeatUp | clip | 62 | | MaskCLIP | n/a | [Download](https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/pretrained/no_norm/maskclip_jbu_stack_cocostuff.ckpt) | mhamilton723/FeatUp | maskclip | 63 | | ViT | [Download](https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/pretrained/vit_jbu_stack_cocostuff.ckpt) | [Download](https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/pretrained/no_norm/vit_jbu_stack_cocostuff.ckpt) | mhamilton723/FeatUp | vit | 64 | | ResNet50 | [Download](https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/pretrained/resnet50_jbu_stack_cocostuff.ckpt) | [Download](https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/pretrained/no_norm/resnet50_jbu_stack_cocostuff.ckpt) | mhamilton723/FeatUp | resnet50 | 65 | 66 | For example, to load the FeatUp JBU upsampler for the DINO backbone without an additional LayerNorm on the spatial features: 67 | 68 | ```python 69 | upsampler = torch.hub.load("mhamilton723/FeatUp", 'dino16', use_norm=False) 70 | ``` 71 | 72 | To load upsamplers trained on backbones with additional LayerNorm operations which makes training and transfer learning a bit more stable: 73 | 74 | ```python 75 | upsampler = torch.hub.load("mhamilton723/FeatUp", 'dino16') 76 | ``` 77 | 78 | ## Fitting an Implicit Upsampler to an Image 79 | 80 | To train an implicit upsampler for a given image and backbone first clone the repository and install it for 81 | [local development](#local-development). Then run 82 | 83 | ```python 84 | cd featup 85 | python train_implicit_upsampler.py 86 | ``` 87 | 88 | Parameters for this training operation can be found in the [implicit_upsampler config file](featup/configs/implicit_upsampler.yaml). 89 | 90 | ## Local Gradio Demo 91 | 92 | To run our [HuggingFace Spaces hosted FeatUp demo](https://huggingface.co/spaces/mhamilton723/FeatUp) locally first install FeatUp for local development. Then run: 93 | 94 | ```shell 95 | python gradio_app.py 96 | ``` 97 | 98 | Wait a few seconds for the demo to spin up, then navigate to [http://localhost:7860/](http://localhost:7860/) to view the demo. 99 | 100 | 101 | ## Coming Soon: 102 | 103 | - Training your own FeatUp joint bilateral upsampler 104 | - Simple API for Implicit FeatUp training 105 | 106 | 107 | ## Citation 108 | 109 | ``` 110 | @inproceedings{ 111 | fu2024featup, 112 | title={FeatUp: A Model-Agnostic Framework for Features at Any Resolution}, 113 | author={Stephanie Fu and Mark Hamilton and Laura E. Brandt and Axel Feldmann and Zhoutong Zhang and William T. Freeman}, 114 | booktitle={The Twelfth International Conference on Learning Representations}, 115 | year={2024}, 116 | url={https://openreview.net/forum?id=GkJiNn2QDF} 117 | } 118 | ``` 119 | 120 | ## Contact 121 | 122 | For feedback, questions, or press inquiries please contact [Stephanie Fu](mailto:fus@mit.edu) and [Mark Hamilton](mailto:markth@mit.edu) 123 | -------------------------------------------------------------------------------- /featup/__init__.py: -------------------------------------------------------------------------------- 1 | from featup.upsamplers import JBULearnedRange -------------------------------------------------------------------------------- /featup/adaptive_conv_cuda/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/featup/adaptive_conv_cuda/__init__.py -------------------------------------------------------------------------------- /featup/adaptive_conv_cuda/adaptive_conv.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | using torch::Tensor; 5 | 6 | Tensor adaptive_conv_forward(Tensor input, Tensor filters) { 7 | 8 | assert(input.dtype() == filters.dtype()); 9 | 10 | auto B = input.sizes()[0]; 11 | auto C_in = input.sizes()[1]; 12 | auto H_in = input.sizes()[2]; 13 | auto W_in = input.sizes()[3]; 14 | 15 | assert(filters.sizes()[0] == B); 16 | auto H_out = filters.sizes()[1]; 17 | auto W_out = filters.sizes()[2]; 18 | auto I = filters.sizes()[3]; 19 | auto J = filters.sizes()[4]; 20 | 21 | assert(I == J); 22 | assert(H_out + I - 1 == H_in); 23 | assert(W_out + J - 1 == W_in); 24 | 25 | auto out = torch::zeros({ B, C_in, H_out, W_out }, input.dtype()); 26 | 27 | // output stationary 28 | for (uint32_t b = 0; b < B; b++) { 29 | for (uint32_t c = 0; c < C_in; c++) { 30 | for (uint32_t h = 0; h < H_out; h++) { 31 | for (uint32_t w = 0; w < W_out; w++) { 32 | // produce output pixel b, h, w, c 33 | for (uint32_t i = 0; i < I; i++) { 34 | for (uint32_t j = 0; j < J; j++) { 35 | auto weight = filters[b][h][w][i][j]; 36 | assert(h+i < H_in); 37 | assert(w+j < W_in); 38 | auto input_val = input[b][c][h+i][w+j]; 39 | out[b][c][h][w] += weight * input_val; 40 | } 41 | } 42 | } 43 | } 44 | } 45 | } 46 | return out; 47 | } 48 | 49 | Tensor adaptive_conv_grad_input(Tensor grad_output, Tensor filters) { 50 | 51 | auto B = grad_output.sizes()[0]; 52 | auto C = grad_output.sizes()[1]; 53 | auto H_out = grad_output.sizes()[2]; 54 | auto W_out = grad_output.sizes()[3]; 55 | 56 | assert(filters.sizes()[0] == B); 57 | assert(filters.sizes()[1] == H_out); 58 | assert(filters.sizes()[2] == W_out); 59 | auto I = filters.sizes()[3]; 60 | auto J = filters.sizes()[4]; 61 | assert(I == J); 62 | 63 | auto H_in = H_out + I - 1; 64 | auto W_in = W_out + J - 1; 65 | 66 | assert(grad_output.dtype() == filters.dtype()); 67 | 68 | auto out = torch::zeros({ B, C, H_in, W_in }, grad_output.dtype()); 69 | 70 | for (int32_t b = 0; b < B; b++) { 71 | for (int32_t c = 0; c < C; c++) { 72 | for (int32_t h = 0; h < H_in; h++) { 73 | for (int32_t w = 0; w < W_in; w++) { 74 | for (int32_t i = 0; i < I; i++) { 75 | for (int32_t j = 0; j < J; j++) { 76 | 77 | int32_t h_out = h - i; 78 | int32_t w_out = w - j; 79 | 80 | if ((h_out >= 0) && (w_out >= 0) && (h_out < H_out) && (w_out < W_out)) { 81 | auto grad = grad_output[b][c][h_out][w_out]; 82 | auto weight = filters[b][h_out][w_out][i][j]; 83 | 84 | out[b][c][h][w] += grad * weight; 85 | } 86 | } 87 | } 88 | } 89 | } 90 | } 91 | } 92 | return out; 93 | } 94 | 95 | Tensor adaptive_conv_grad_filters(Tensor grad_output, Tensor input) { 96 | 97 | auto B = grad_output.sizes()[0]; 98 | auto C = grad_output.sizes()[1]; 99 | auto H_out = grad_output.sizes()[2]; 100 | auto W_out = grad_output.sizes()[3]; 101 | 102 | assert(input.sizes()[0] == B); 103 | assert(input.sizes()[1] == C); 104 | auto H_in = input.sizes()[2]; 105 | auto W_in = input.sizes()[3]; 106 | 107 | assert(H_in > H_out); 108 | assert(W_in > W_out); 109 | 110 | auto I = W_in - W_out + 1; 111 | auto J = H_in - H_out + 1; 112 | 113 | assert(grad_output.dtype() == input.dtype()); 114 | 115 | auto out = torch::zeros({ B, H_out, W_out, I, J }, grad_output.dtype()); 116 | 117 | for (uint32_t b = 0; b < B; b++) { 118 | for (uint32_t h = 0; h < H_out; h++) { 119 | for (uint32_t w = 0; w < W_out; w++) { 120 | for (uint32_t i = 0; i < I; i++) { 121 | for (uint32_t j = 0; j < J; j++) { 122 | for (uint32_t c = 0; c < C; c++) { 123 | auto grad = grad_output[b][c][h][w]; 124 | assert(h + i < H_in); 125 | assert(w + j < W_in); 126 | auto input_val = input[b][c][h+i][w+j]; 127 | out[b][h][w][i][j] += grad * input_val; 128 | } 129 | } 130 | } 131 | } 132 | } 133 | } 134 | 135 | return out; 136 | } 137 | 138 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 139 | m.def("forward", &adaptive_conv_forward, "adaptive_conv forward"); 140 | m.def("grad_input", &adaptive_conv_grad_input, "adaptive_conv grad_input"); 141 | m.def("grad_filters", &adaptive_conv_grad_filters, "adaptive_conv grad_filters"); 142 | } 143 | -------------------------------------------------------------------------------- /featup/adaptive_conv_cuda/adaptive_conv.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function 2 | import torch 3 | 4 | import adaptive_conv_cuda_impl as cuda_impl 5 | import adaptive_conv_cpp_impl as cpp_impl 6 | 7 | torch.manual_seed(42) 8 | 9 | 10 | class AdaptiveConv(Function): 11 | 12 | @staticmethod 13 | def forward(ctx, input, filters): 14 | ctx.save_for_backward(filters, input) 15 | b, h2, w2, f1, f2 = filters.shape 16 | assert f1 == f2 17 | 18 | if input.is_cuda: 19 | assert filters.is_cuda 20 | result = cuda_impl.forward(input, filters) 21 | else: 22 | result = cpp_impl.forward(input, filters) 23 | 24 | return result 25 | 26 | @staticmethod 27 | def backward(ctx, grad_output): 28 | filters, input = ctx.saved_tensors 29 | grad_input = grad_filters = None 30 | b, h2, w2, f1, f2 = filters.shape 31 | assert f1 == f2 32 | 33 | grad_output = grad_output.contiguous() 34 | if grad_output.is_cuda: 35 | assert input.is_cuda 36 | assert filters.is_cuda 37 | if ctx.needs_input_grad[0]: 38 | grad_input = cuda_impl.grad_input(grad_output, filters) 39 | if ctx.needs_input_grad[1]: 40 | grad_filters = cuda_impl.grad_filters(grad_output, input) 41 | else: 42 | if ctx.needs_input_grad[0]: 43 | grad_input = cpp_impl.grad_input(grad_output, filters) 44 | if ctx.needs_input_grad[1]: 45 | grad_filters = cpp_impl.grad_filters(grad_output, input) 46 | 47 | return grad_input, grad_filters 48 | -------------------------------------------------------------------------------- /featup/adaptive_conv_cuda/adaptive_conv_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | using torch::Tensor; 3 | 4 | // CUDA forward declarations 5 | 6 | Tensor adaptive_conv_cuda_forward(Tensor input, Tensor filters); 7 | Tensor adaptive_conv_cuda_grad_input(Tensor grad_output, Tensor filters); 8 | Tensor adaptive_conv_cuda_grad_filters(Tensor grad_output, Tensor input); 9 | 10 | // C++ interface 11 | 12 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 13 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 14 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 15 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 16 | 17 | Tensor adaptive_conv_forward(Tensor input, Tensor filters) { 18 | //CHECK_INPUT(input); 19 | //CHECK_INPUT(filters); 20 | return adaptive_conv_cuda_forward(input, filters); 21 | } 22 | 23 | Tensor adaptive_conv_grad_input(Tensor grad_output, Tensor filters) { 24 | //CHECK_INPUT(grad_output); 25 | //CHECK_INPUT(filters); 26 | return adaptive_conv_cuda_grad_input(grad_output, filters); 27 | } 28 | 29 | Tensor adaptive_conv_grad_filters(Tensor grad_output, Tensor input) { 30 | //CHECK_INPUT(grad_output); 31 | //CHECK_INPUT(input); 32 | return adaptive_conv_cuda_grad_filters(grad_output, input); 33 | } 34 | 35 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 36 | m.def("forward", &adaptive_conv_forward, "adaptive_conv forward"); 37 | m.def("grad_input", &adaptive_conv_grad_input, "adaptive_conv grad_input"); 38 | m.def("grad_filters", &adaptive_conv_grad_filters, "adaptive_conv grad_filters"); 39 | } 40 | -------------------------------------------------------------------------------- /featup/adaptive_conv_cuda/adaptive_conv_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | constexpr uint32_t kernel_channel_depth = 2; 8 | 9 | using torch::Tensor; 10 | using namespace at; 11 | 12 | template 13 | __launch_bounds__(1024) __global__ void adaptive_conv_forward_kernel( 14 | torch::PackedTensorAccessor64 out, 15 | torch::PackedTensorAccessor64 input, 16 | torch::PackedTensorAccessor64 filters, 17 | uint32_t batch) { 18 | 19 | const auto w = blockIdx.x * blockDim.x + threadIdx.x; 20 | const auto h = blockIdx.y * blockDim.y + threadIdx.y; 21 | const auto c_lo = blockIdx.z * kernel_channel_depth; 22 | const auto c_hi = min(c_lo + kernel_channel_depth, (uint32_t) input.size(1)); 23 | 24 | const uint32_t I = filters.size(3); 25 | const uint32_t J = filters.size(4); 26 | 27 | if (w < out.size(3) && h < out.size(2)) { 28 | for (uint32_t c = c_lo; c < c_hi; c++) { 29 | scalar_t output_val = 0.0; 30 | for (uint32_t i = 0; i < I; i++) { 31 | for (uint32_t j = 0; j < J; j++) { 32 | 33 | auto weight = filters[batch][h][w][i][j]; 34 | auto input_val = input[batch][c][h+i][w+j]; 35 | 36 | output_val += (weight * input_val); 37 | } 38 | } 39 | out[batch][c][h][w] = output_val; 40 | } 41 | } 42 | } 43 | 44 | template 45 | __launch_bounds__(1024) __global__ void adaptive_conv_grad_input_kernel( 46 | torch::PackedTensorAccessor64 out, 47 | torch::PackedTensorAccessor64 grad_output, 48 | torch::PackedTensorAccessor64 filters, 49 | uint32_t batch) { 50 | 51 | const int32_t w = blockIdx.x * blockDim.x + threadIdx.x; 52 | const int32_t h = blockIdx.y * blockDim.y + threadIdx.y; 53 | 54 | const int32_t H_out = out.size(2); 55 | const int32_t W_out = out.size(3); 56 | 57 | // thread's output index is outside output tensor 58 | if (w >= W_out || h >= H_out) return; 59 | 60 | const int32_t c_lo = blockIdx.z * kernel_channel_depth; 61 | const int32_t c_hi = min(c_lo + kernel_channel_depth, (int32_t) out.size(1)); 62 | 63 | const int32_t I = filters.size(3); 64 | const int32_t J = filters.size(4); 65 | 66 | const int32_t H_grad = grad_output.size(2); 67 | const int32_t W_grad = grad_output.size(3); 68 | 69 | for (int32_t c = c_lo; c < c_hi; c++) { 70 | 71 | scalar_t output_val = 0.0; 72 | 73 | for (int32_t i = 0; i < I; i++) { 74 | for (int32_t j = 0; j < J; j++) { 75 | const int32_t h_grad = h - i; 76 | const int32_t w_grad = w - j; 77 | 78 | if (h_grad >= 0 && w_grad >= 0 && h_grad < H_grad && w_grad < W_grad) { 79 | output_val += grad_output[batch][c][h_grad][w_grad] * filters[batch][h_grad][w_grad][i][j]; 80 | } 81 | } 82 | } 83 | out[batch][c][h][w] = output_val; 84 | } 85 | } 86 | 87 | 88 | template 89 | __launch_bounds__(1024) __global__ void adaptive_conv_grad_filters_kernel( 90 | torch::PackedTensorAccessor64 out, 91 | torch::PackedTensorAccessor64 grad_output, 92 | torch::PackedTensorAccessor64 input, 93 | uint32_t batch) { 94 | 95 | const uint32_t w = blockIdx.x * blockDim.x + threadIdx.x; 96 | const uint32_t h = blockIdx.y * blockDim.y + threadIdx.y; 97 | const uint32_t f = blockIdx.z * blockIdx.z + threadIdx.z; 98 | 99 | const uint32_t H = out.size(1); 100 | const uint32_t W = out.size(2); 101 | const uint32_t I = out.size(3); 102 | const uint32_t J = out.size(4); 103 | 104 | assert(I == J); 105 | 106 | const uint32_t C = input.size(1); 107 | 108 | if (h >= H || w >= W || f >= (I * J)) return; 109 | 110 | const uint32_t i = f / I; 111 | const uint32_t j = f % I; 112 | 113 | scalar_t output_val = 0.0; 114 | for (uint32_t c = 0; c < C; c++) { 115 | auto grad = grad_output[batch][c][h][w]; 116 | auto input_val = input[batch][c][h+i][w+j]; 117 | output_val += grad * input_val; 118 | } 119 | out[batch][h][w][i][j] = output_val; 120 | } 121 | 122 | 123 | template 124 | T div_round_up(T a, T b) { 125 | return (a + b - 1) / b; 126 | } 127 | 128 | Tensor adaptive_conv_cuda_forward(Tensor input, Tensor filters) { 129 | at::cuda::set_device(input.device().index()); 130 | 131 | // Check for error in the input tensors 132 | TORCH_CHECK(input.dim() == 4, "input must have 4 dimensions"); 133 | TORCH_CHECK(filters.dim() == 5, "filters must have 5 dimensions"); 134 | TORCH_CHECK(input.dtype() == filters.dtype(), "input and filters must have the same data type"); 135 | 136 | const uint32_t B = input.size(0); 137 | const uint32_t C = input.size(1); 138 | const uint32_t H_in = input.size(2); 139 | const uint32_t W_in = input.size(3); 140 | 141 | TORCH_CHECK(filters.size(0) == B, "Inconsistent batch size between input and filters"); 142 | const uint32_t H_out = filters.size(1); 143 | const uint32_t W_out = filters.size(2); 144 | const uint32_t I = filters.size(3); 145 | const uint32_t J = filters.size(4); 146 | 147 | TORCH_CHECK(I == J, "filters dimension I and J must be equal"); 148 | TORCH_CHECK(H_out + I - 1 == H_in, "Inconsistent height between input and filters"); 149 | TORCH_CHECK(W_out + J - 1 == W_in, "Inconsistent width between input and filters"); 150 | 151 | auto options = torch::TensorOptions() 152 | .dtype(input.dtype()) 153 | .device(torch::kCUDA); 154 | 155 | auto out = torch::zeros({ B, C, H_out, W_out }, options); 156 | 157 | const dim3 tpb(32, 32); 158 | const dim3 blocks(div_round_up(W_out, tpb.x), 159 | div_round_up(H_out, tpb.y), 160 | div_round_up(C, kernel_channel_depth)); 161 | 162 | for (uint32_t b = 0; b < B; b++) { 163 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(out.scalar_type(), "adaptive_conv_forward_cuda", ([&] { 164 | adaptive_conv_forward_kernel<<>>( 165 | out.packed_accessor64(), 166 | input.packed_accessor64(), 167 | filters.packed_accessor64(), 168 | b); 169 | })); 170 | cudaError_t err = cudaGetLastError(); 171 | if (err != cudaSuccess) { 172 | printf("Error in adaptive_conv_forward_kernel: %s\n", cudaGetErrorString(err)); 173 | } 174 | } 175 | return out; 176 | } 177 | 178 | 179 | Tensor adaptive_conv_cuda_grad_input(Tensor grad_output, Tensor filters) { 180 | at::cuda::set_device(grad_output.device().index()); 181 | 182 | // Check for error in the input tensors 183 | TORCH_CHECK(grad_output.dim() == 4, "grad_output must have 4 dimensions"); 184 | TORCH_CHECK(filters.dim() == 5, "filters must have 5 dimensions"); 185 | 186 | const uint32_t B = grad_output.size(0); 187 | const uint32_t C = grad_output.size(1); 188 | const uint32_t H_out = grad_output.size(2); 189 | const uint32_t W_out = grad_output.size(3); 190 | 191 | TORCH_CHECK(filters.size(0) == B, "Inconsistent batch size between filters and grad_output"); 192 | TORCH_CHECK(filters.size(1) == H_out, "Inconsistent height between filters and grad_output"); 193 | TORCH_CHECK(filters.size(2) == W_out, "Inconsistent width between filters and grad_output"); 194 | 195 | const uint32_t I = filters.size(3); 196 | const uint32_t J = filters.size(4); 197 | TORCH_CHECK(I == J, "filters dimension I and J must be equal"); 198 | 199 | const uint32_t H_in = H_out + I - 1; 200 | const uint32_t W_in = W_out + J - 1; 201 | 202 | TORCH_CHECK(grad_output.dtype() == filters.dtype(), "grad_output and filters must have the same data type"); 203 | 204 | auto options = torch::TensorOptions() 205 | .dtype(filters.dtype()) 206 | .device(torch::kCUDA); 207 | 208 | auto out = torch::zeros({ B, C, H_in, W_in }, options); 209 | 210 | const dim3 tpb(32, 32); 211 | const dim3 blocks(div_round_up(W_in, tpb.x), 212 | div_round_up(H_in, tpb.y), 213 | div_round_up(C, kernel_channel_depth)); 214 | 215 | for (uint32_t b = 0; b < B; b++) { 216 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(out.scalar_type(), "adaptive_conv_grad_input_cuda", ([&] { 217 | adaptive_conv_grad_input_kernel<<>>( 218 | out.packed_accessor64(), 219 | grad_output.packed_accessor64(), 220 | filters.packed_accessor64(), 221 | b); 222 | })); 223 | cudaError_t err = cudaGetLastError(); 224 | if (err != cudaSuccess) { 225 | printf("Error in adaptive_conv_grad_input_kernel: %s\n", cudaGetErrorString(err)); 226 | } 227 | } 228 | return out; 229 | } 230 | 231 | Tensor adaptive_conv_cuda_grad_filters(Tensor grad_output, Tensor input) { 232 | at::cuda::set_device(grad_output.device().index()); 233 | 234 | // Check for error in the input tensors 235 | TORCH_CHECK(grad_output.dim() == 4, "grad_output must have 4 dimensions"); 236 | TORCH_CHECK(input.dim() == 4, "input must have 4 dimensions"); 237 | 238 | const uint32_t B = grad_output.size(0); 239 | const uint32_t C = grad_output.size(1); 240 | const uint32_t H_out = grad_output.size(2); 241 | const uint32_t W_out = grad_output.size(3); 242 | 243 | TORCH_CHECK(input.size(0) == B, "Inconsistent batch size between input and grad_output"); 244 | TORCH_CHECK(input.size(1) == C, "Inconsistent number of channels between input and grad_output"); 245 | 246 | const uint32_t H_in = input.size(2); 247 | const uint32_t W_in = input.size(3); 248 | 249 | TORCH_CHECK(H_in > H_out, "Input height must be greater than grad_output height"); 250 | TORCH_CHECK(W_in > W_out, "Input width must be greater than grad_output width"); 251 | 252 | const uint32_t I = W_in - W_out + 1; 253 | const uint32_t J = H_in - H_out + 1; 254 | 255 | TORCH_CHECK(grad_output.dtype() == input.dtype(), "grad_output and input must have the same data type"); 256 | 257 | auto options = torch::TensorOptions() 258 | .dtype(input.dtype()) 259 | .device(torch::kCUDA); 260 | 261 | auto out = torch::zeros({ B, H_out, W_out, I, J }, options); 262 | 263 | const dim3 tpb(32, 32, 1); 264 | const dim3 blocks(div_round_up(W_out, tpb.x), 265 | div_round_up(H_out, tpb.y), 266 | div_round_up(I * J, tpb.z)); 267 | 268 | 269 | 270 | for (uint32_t b = 0; b < B; b++) { 271 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(out.scalar_type(), "adaptive_conv_grad_filters_cuda", ([&] { 272 | adaptive_conv_grad_filters_kernel<<>>( 273 | out.packed_accessor64(), 274 | grad_output.packed_accessor64(), 275 | input.packed_accessor64(), 276 | b); 277 | })); 278 | cudaError_t err = cudaGetLastError(); 279 | if (err != cudaSuccess) { 280 | printf("Error in adaptive_conv_grad_filters_kernel: %s\n", cudaGetErrorString(err)); 281 | } 282 | } 283 | return out; 284 | } 285 | 286 | -------------------------------------------------------------------------------- /featup/configs/implicit_upsampler.yaml: -------------------------------------------------------------------------------- 1 | # Environment Args 2 | output_root: '../../' 3 | pytorch_data_dir: '/pytorch-data' 4 | submitting_to_aml: false 5 | summarize: true 6 | experiment_name: "exp1" 7 | 8 | # Dataset args 9 | dataset: "sample" 10 | split: "val" 11 | partition: 0 12 | total_partitions: 1 13 | 14 | # Model Args 15 | model_type: "maskclip" 16 | activation_type: "token" 17 | 18 | # Upsampler args 19 | outlier_detection: True 20 | downsampler_type: "attention" 21 | blur_attn: True 22 | mag_tv_weight: 0.05 23 | mag_weight: 0.001 24 | color_feats: true 25 | pca_batch: 50 26 | proj_dim: 128 27 | max_pad: 30 28 | use_flips: true 29 | max_zoom: 1.8 30 | blur_pin: 0.1 31 | n_freqs: 30 32 | param_type: "implicit" 33 | use_norm: false 34 | 35 | # Training args 36 | steps: 1200 37 | n_images: 3000 38 | 39 | # No need to change 40 | hydra: 41 | run: 42 | dir: "." 43 | output_subdir: ~ 44 | 45 | -------------------------------------------------------------------------------- /featup/configs/jbu_upsampler.yaml: -------------------------------------------------------------------------------- 1 | # Environment Args 2 | output_root: '../../' 3 | pytorch_data_dir: '/pytorch-data' 4 | submitting_to_aml: false 5 | 6 | # Dataset args 7 | dataset: "cocostuff" 8 | 9 | # Model Args 10 | model_type: "vit" 11 | activation_type: "token" 12 | 13 | # Upsampling args 14 | outlier_detection: True 15 | upsampler_type: "jbu_stack" 16 | downsampler_type: "attention" 17 | max_pad: 20 18 | max_zoom: 2 19 | n_jitters: 5 20 | random_projection: 30 21 | crf_weight: 0.001 22 | filter_ent_weight: 0.0 23 | tv_weight: 0.0 24 | 25 | implicit_sup_weight: 1.0 26 | 27 | # Training args 28 | batch_size: 4 29 | epochs: 1 30 | num_gpus: 1 31 | num_workers: 24 32 | lr: 1e-3 33 | 34 | # No need to change 35 | hydra: 36 | run: 37 | dir: "." 38 | output_subdir: ~ 39 | 40 | -------------------------------------------------------------------------------- /featup/configs/train_probe.yaml: -------------------------------------------------------------------------------- 1 | # Environment Args 2 | output_root: '../../' 3 | pytorch_data_dir: '/pytorch-data' 4 | submitting_to_aml: false 5 | 6 | # Dataset args 7 | task: "seg" 8 | 9 | # Model Args 10 | model_type: "vit" 11 | activation_type: "token" 12 | 13 | # Upsampling args 14 | outlier_detection: True 15 | upsampler_type: "jbu_stack" 16 | downsampler_type: "attention" 17 | max_pad: 20 18 | max_zoom: 2 19 | n_jitters: 5 20 | random_projection: 30 21 | crf_weight: 0.001 22 | filter_ent_weight: 0.0 23 | tv_weight: 0.0 24 | 25 | # Training args 26 | batch_size: 2 27 | epochs: 200 28 | num_workers: 24 29 | lr: 1e-3 30 | dropout: .5 31 | wd: 0.0 32 | 33 | # No need to change 34 | hydra: 35 | run: 36 | dir: "." 37 | output_subdir: ~ 38 | 39 | -------------------------------------------------------------------------------- /featup/datasets/COCO.py: -------------------------------------------------------------------------------- 1 | import random 2 | from os.path import join 3 | 4 | import numpy as np 5 | import torch 6 | import torch.multiprocessing 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | 10 | 11 | def bit_get(val, idx): 12 | """Gets the bit value. 13 | Args: 14 | val: Input value, int or numpy int array. 15 | idx: Which bit of the input val. 16 | Returns: 17 | The "idx"-th bit of input val. 18 | """ 19 | return (val >> idx) & 1 20 | 21 | 22 | def create_pascal_label_colormap(): 23 | """Creates a label colormap used in PASCAL VOC segmentation benchmark. 24 | Returns: 25 | A colormap for visualizing segmentation results. 26 | """ 27 | colormap = np.zeros((512, 3), dtype=int) 28 | ind = np.arange(512, dtype=int) 29 | 30 | for shift in reversed(list(range(8))): 31 | for channel in range(3): 32 | colormap[:, channel] |= bit_get(ind, channel) << shift 33 | ind >>= 3 34 | 35 | return colormap 36 | 37 | 38 | class Coco(Dataset): 39 | def __init__(self, 40 | root, 41 | split, 42 | transform, 43 | target_transform, 44 | include_labels=True, 45 | coarse_labels=False, 46 | exclude_things=False, 47 | subset=None): 48 | super(Coco, self).__init__() 49 | self.split = split 50 | self.root = join(root, "cocostuff") 51 | self.coarse_labels = coarse_labels 52 | self.transform = transform 53 | self.label_transform = target_transform 54 | self.subset = subset 55 | self.exclude_things = exclude_things 56 | self.include_labels = include_labels 57 | 58 | if self.subset is None: 59 | self.image_list = "Coco164kFull_Stuff_Coarse.txt" 60 | elif self.subset == 6: # IIC Coarse 61 | self.image_list = "Coco164kFew_Stuff_6.txt" 62 | elif self.subset == 7: # IIC Fine 63 | self.image_list = "Coco164kFull_Stuff_Coarse_7.txt" 64 | 65 | assert self.split in ["train", "val", "train+val"] 66 | split_dirs = { 67 | "train": ["train2017"], 68 | "val": ["val2017"], 69 | "train+val": ["train2017", "val2017"] 70 | } 71 | 72 | self.image_files = [] 73 | self.label_files = [] 74 | for split_dir in split_dirs[self.split]: 75 | with open(join(self.root, "curated", split_dir, self.image_list), "r") as f: 76 | img_ids = [fn.rstrip() for fn in f.readlines()] 77 | for img_id in img_ids: 78 | self.image_files.append(join(self.root, "images", split_dir, img_id + ".jpg")) 79 | self.label_files.append(join(self.root, "annotations", split_dir, img_id + ".png")) 80 | 81 | self.fine_to_coarse = {0: 9, 1: 11, 2: 11, 3: 11, 4: 11, 5: 11, 6: 11, 7: 11, 8: 11, 9: 8, 10: 8, 11: 8, 12: 8, 82 | 13: 8, 14: 8, 15: 7, 16: 7, 17: 7, 18: 7, 19: 7, 20: 7, 21: 7, 22: 7, 23: 7, 24: 7, 83 | 25: 6, 26: 6, 27: 6, 28: 6, 29: 6, 30: 6, 31: 6, 32: 6, 33: 10, 34: 10, 35: 10, 36: 10, 84 | 37: 10, 38: 10, 39: 10, 40: 10, 41: 10, 42: 10, 43: 5, 44: 5, 45: 5, 46: 5, 47: 5, 48: 5, 85 | 49: 5, 50: 5, 51: 2, 52: 2, 53: 2, 54: 2, 55: 2, 56: 2, 57: 2, 58: 2, 59: 2, 60: 2, 86 | 61: 3, 62: 3, 63: 3, 64: 3, 65: 3, 66: 3, 67: 3, 68: 3, 69: 3, 70: 3, 71: 0, 72: 0, 87 | 73: 0, 74: 0, 75: 0, 76: 0, 77: 1, 78: 1, 79: 1, 80: 1, 81: 1, 82: 1, 83: 4, 84: 4, 88 | 85: 4, 86: 4, 87: 4, 88: 4, 89: 4, 90: 4, 91: 17, 92: 17, 93: 22, 94: 20, 95: 20, 96: 22, 89 | 97: 15, 98: 25, 99: 16, 100: 13, 101: 12, 102: 12, 103: 17, 104: 17, 105: 23, 106: 15, 90 | 107: 15, 108: 17, 109: 15, 110: 21, 111: 15, 112: 25, 113: 13, 114: 13, 115: 13, 116: 13, 91 | 117: 13, 118: 22, 119: 26, 120: 14, 121: 14, 122: 15, 123: 22, 124: 21, 125: 21, 126: 24, 92 | 127: 20, 128: 22, 129: 15, 130: 17, 131: 16, 132: 15, 133: 22, 134: 24, 135: 21, 136: 17, 93 | 137: 25, 138: 16, 139: 21, 140: 17, 141: 22, 142: 16, 143: 21, 144: 21, 145: 25, 146: 21, 94 | 147: 26, 148: 21, 149: 24, 150: 20, 151: 17, 152: 14, 153: 21, 154: 26, 155: 15, 156: 23, 95 | 157: 20, 158: 21, 159: 24, 160: 15, 161: 24, 162: 22, 163: 25, 164: 15, 165: 20, 166: 17, 96 | 167: 17, 168: 22, 169: 14, 170: 18, 171: 18, 172: 18, 173: 18, 174: 18, 175: 18, 176: 18, 97 | 177: 26, 178: 26, 179: 19, 180: 19, 181: 24} 98 | 99 | self._label_names = [ 100 | "ground-stuff", 101 | "plant-stuff", 102 | "sky-stuff", 103 | ] 104 | self.cocostuff3_coarse_classes = [23, 22, 21] 105 | self.first_stuff_index = 12 106 | 107 | def __len__(self): 108 | return len(self.image_files) 109 | 110 | def __getitem__(self, index): 111 | image_path = self.image_files[index] 112 | label_path = self.label_files[index] 113 | seed = np.random.randint(2147483647) 114 | batch = {} 115 | 116 | random.seed(seed) 117 | torch.manual_seed(seed) 118 | img = self.transform(Image.open(image_path).convert("RGB")) 119 | batch["img"] = img 120 | batch["img_path"] = image_path 121 | 122 | if self.include_labels: 123 | random.seed(seed) 124 | torch.manual_seed(seed) 125 | label = self.label_transform(Image.open(label_path)).squeeze(0) 126 | label[label == 255] = -1 # to be consistent with 10k 127 | coarse_label = torch.zeros_like(label) 128 | for fine, coarse in self.fine_to_coarse.items(): 129 | coarse_label[label == fine] = coarse 130 | coarse_label[label == -1] = -1 131 | 132 | if self.coarse_labels: 133 | coarser_labels = -torch.ones_like(label) 134 | for i, c in enumerate(self.cocostuff3_coarse_classes): 135 | coarser_labels[coarse_label == c] = i 136 | batch["label"] = coarser_labels 137 | else: 138 | if self.exclude_things: 139 | batch["label"] = coarse_label - self.first_stuff_index 140 | else: 141 | batch["label"] = coarse_label 142 | 143 | return batch 144 | 145 | @staticmethod 146 | def colorize_label(label): 147 | cmap = create_pascal_label_colormap() 148 | return cmap[label.cpu()].astype(np.uint8) 149 | -------------------------------------------------------------------------------- /featup/datasets/DAVIS.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import os 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class DAVIS(Dataset): 8 | def __init__(self, root, video_name, transform=None): 9 | """ 10 | Args: 11 | root (string): Directory with all the videos. 12 | video_name (string): Name of the specific video. 13 | transform (callable, optional): Optional transform to be applied on a sample. 14 | """ 15 | self.root_dir = os.path.join(root, "DAVIS/JPEGImages/480p/", video_name) 16 | self.frames = os.listdir(self.root_dir) 17 | self.transform = transform 18 | 19 | def __len__(self): 20 | return len(self.frames) 21 | 22 | def __getitem__(self, idx): 23 | img_path = os.path.join(self.root_dir, self.frames[idx]) 24 | image = Image.open(img_path).convert("RGB") 25 | 26 | if self.transform: 27 | image = self.transform(image) 28 | 29 | return {"img": image, "img_path": img_path} 30 | 31 | 32 | if __name__ == "__main__": 33 | transform = transforms.Compose([ 34 | transforms.Resize((256, 256)), 35 | transforms.ToTensor() 36 | ]) 37 | 38 | davis_dataset = DAVIS(root='/pytorch-data', video_name="motocross-jump", transform=transform) 39 | 40 | frames = davis_dataset[0] 41 | 42 | print("here") 43 | -------------------------------------------------------------------------------- /featup/datasets/EmbeddingFile.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class EmbeddingFile(Dataset): 6 | """ 7 | modified from: https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder 8 | uses cached directory listing if available rather than walking directory 9 | Attributes: 10 | classes (list): List of the class names. 11 | class_to_idx (dict): Dict with items (class_name, class_index). 12 | samples (list): List of (sample path, class_index) tuples 13 | targets (list): The class_index value for each image in the dataset 14 | """ 15 | 16 | def __init__(self, file): 17 | super(Dataset, self).__init__() 18 | self.file = file 19 | loaded = np.load(file) 20 | self.feats = loaded["feats"] 21 | self.labels = loaded["labels"] 22 | 23 | def dim(self): 24 | return self.feats.shape[1] 25 | 26 | def num_classes(self): 27 | return self.labels.max() + 1 28 | 29 | def __getitem__(self, index): 30 | return self.feats[index], self.labels[index] 31 | 32 | def __len__(self): 33 | return len(self.labels) 34 | 35 | 36 | class EmbeddingAndImage(Dataset): 37 | def __init__(self, file, dataset): 38 | super(Dataset, self).__init__() 39 | self.file = file 40 | loaded = np.load(file) 41 | self.feats = loaded["feats"] 42 | self.labels = loaded["labels"] 43 | self.imgs = dataset 44 | 45 | def dim(self): 46 | return self.feats.shape[1] 47 | 48 | def num_classes(self): 49 | return self.labels.max() + 1 50 | 51 | def __getitem__(self, index): 52 | return self.feats[index], self.labels[index], self.imgs[index] 53 | 54 | def __len__(self): 55 | return len(self.labels) 56 | -------------------------------------------------------------------------------- /featup/datasets/HighResEmbs.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import sys 3 | from os.path import join 4 | 5 | import featup.downsamplers 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision.transforms as T 10 | from featup.featurizers.util import get_featurizer 11 | from featup.layers import ChannelNorm 12 | from featup.layers import ChannelNorm 13 | from featup.util import norm 14 | from sklearn.decomposition import PCA 15 | from torch.utils.data import Dataset, DataLoader 16 | from torch.utils.data import Subset 17 | from torch.utils.data import default_collate 18 | from tqdm import tqdm 19 | 20 | from util import get_dataset 21 | 22 | torch.multiprocessing.set_sharing_strategy('file_system') 23 | 24 | 25 | def clamp_mag(t, min_mag, max_mag): 26 | mags = mag(t) 27 | clamped_above = t * (max_mag / mags.clamp_min(.000001)).clamp_max(1.0) 28 | clamped_below = clamped_above * (min_mag / mags.clamp_min(.000001)).clamp_min(1.0) 29 | return clamped_below 30 | 31 | 32 | def pca(image_feats_list, dim=3, fit_pca=None): 33 | device = image_feats_list[0].device 34 | 35 | def flatten(tensor, target_size=None): 36 | if target_size is not None and fit_pca is None: 37 | F.interpolate(tensor, (target_size, target_size), mode="bilinear") 38 | B, C, H, W = tensor.shape 39 | return feats.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu() 40 | 41 | if len(image_feats_list) > 1 and fit_pca is None: 42 | target_size = image_feats_list[0].shape[2] 43 | else: 44 | target_size = None 45 | 46 | flattened_feats = [] 47 | for feats in image_feats_list: 48 | flattened_feats.append(flatten(feats, target_size)) 49 | x = torch.cat(flattened_feats, dim=0) 50 | 51 | if fit_pca is None: 52 | fit_pca = PCA(n_components=dim).fit(x) 53 | 54 | reduced_feats = [] 55 | for feats in image_feats_list: 56 | x_red = torch.from_numpy(fit_pca.transform(flatten(feats))) 57 | x_red -= x_red.min(dim=0, keepdim=True).values 58 | x_red /= x_red.max(dim=0, keepdim=True).values 59 | B, C, H, W = feats.shape 60 | reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device)) 61 | 62 | return reduced_feats, fit_pca 63 | 64 | 65 | def mag(t): 66 | return t.square().sum(1, keepdim=True).sqrt() 67 | 68 | 69 | def model_collate(batch): 70 | elem = batch[0] 71 | elem_type = type(elem) 72 | if isinstance(elem, torch.nn.Module): 73 | return batch 74 | elif isinstance(elem, collections.abc.Mapping): 75 | try: 76 | return elem_type({key: model_collate([d[key] for d in batch]) for key in elem}) 77 | except TypeError: 78 | # The mapping type may not support `__init__(iterable)`. 79 | return {key: model_collate([d[key] for d in batch]) for key in elem} 80 | else: 81 | return default_collate(batch) 82 | 83 | 84 | class HighResEmbHelper(Dataset): 85 | def __init__(self, 86 | root, 87 | output_root, 88 | dataset_name, 89 | emb_name, 90 | split, 91 | model_type, 92 | transform, 93 | target_transform, 94 | limit, 95 | include_labels): 96 | self.root = root 97 | self.emb_dir = join(output_root, "feats", emb_name, dataset_name, split, model_type) 98 | 99 | self.dataset = get_dataset( 100 | root, dataset_name, split, transform, target_transform, include_labels=include_labels) 101 | 102 | if split == 'train': 103 | self.dataset = Subset(self.dataset, generate_subset(len(self.dataset), 5000)) 104 | # TODO factor this limit out 105 | 106 | if limit is not None: 107 | self.dataset = Subset(self.dataset, range(0, limit)) 108 | 109 | def __len__(self): 110 | return len(self.dataset) 111 | 112 | def __getitem__(self, item): 113 | batch = self.dataset[item] 114 | output_location = join(self.emb_dir, "/".join(batch["img_path"].split("/")[-1:]).replace(".jpg", ".pth")) 115 | state_dicts = torch.load(output_location, map_location="cpu") 116 | from featup.train_implicit_upsampler import get_implicit_upsampler 117 | from featup.util import PCAUnprojector 118 | model = get_implicit_upsampler(**state_dicts["model_args"]) 119 | model.load_state_dict(state_dicts["model"]) 120 | unp_state_dict = state_dicts["unprojector"] 121 | unprojector = PCAUnprojector( 122 | None, 123 | unp_state_dict["components_"].shape[0], 124 | device="cpu", 125 | original_dim=unp_state_dict["components_"].shape[1], 126 | **unp_state_dict 127 | ) 128 | batch["model"] = {"model": model, "unprojector": unprojector} 129 | return batch 130 | 131 | 132 | def load_hr_emb(image, loaded_model, target_res): 133 | image = image.cuda() 134 | if isinstance(loaded_model["model"], list): 135 | hr_model = loaded_model["model"][0].cuda().eval() 136 | unprojector = loaded_model["unprojector"][0].eval() 137 | else: 138 | hr_model = loaded_model["model"].cuda().eval() 139 | unprojector = loaded_model["unprojector"].eval() 140 | 141 | with torch.no_grad(): 142 | original_image = F.interpolate( 143 | image, size=(target_res, target_res), mode='bilinear', antialias=True) 144 | hr_feats = hr_model(original_image) 145 | return unprojector(hr_feats.detach().cpu()) 146 | 147 | 148 | class HighResEmb(Dataset): 149 | def __init__(self, 150 | root, 151 | dataset_name, 152 | emb_name, 153 | split, 154 | output_root, 155 | model_type, 156 | transform, 157 | target_transform, 158 | target_res, 159 | limit, 160 | include_labels, 161 | ): 162 | self.root = root 163 | self.dataset = HighResEmbHelper( 164 | root=root, 165 | output_root=output_root, 166 | dataset_name=dataset_name, 167 | emb_name=emb_name, 168 | split=split, 169 | model_type=model_type, 170 | transform=transform, 171 | target_transform=target_transform, 172 | limit=limit, 173 | include_labels=include_labels) 174 | 175 | self.all_hr_feats = [] 176 | self.target_res = target_res 177 | loader = DataLoader(self.dataset, shuffle=False, batch_size=1, num_workers=12, collate_fn=model_collate) 178 | 179 | for img_num, batch in enumerate(tqdm(loader, "Loading hr embeddings")): 180 | with torch.no_grad(): 181 | self.all_hr_feats.append(load_hr_emb(batch["img"], batch["model"], target_res)) 182 | 183 | def __len__(self): 184 | return len(self.dataset) 185 | 186 | def __getitem__(self, item): 187 | batch = self.dataset.dataset[item] 188 | batch["hr_feat"] = self.all_hr_feats[item].squeeze(0) 189 | return batch 190 | 191 | 192 | def generate_subset(n, batch): 193 | np.random.seed(0) 194 | return np.random.permutation(n)[:batch] 195 | 196 | 197 | def load_some_hr_feats(model_type, 198 | activation_type, 199 | dataset_name, 200 | split, 201 | emb_name, 202 | root, 203 | output_root, 204 | input_size, 205 | samples_per_batch, 206 | num_batches, 207 | num_workers 208 | ): 209 | transform = T.Compose([ 210 | T.Resize(input_size), 211 | T.CenterCrop(input_size), 212 | T.ToTensor(), 213 | norm 214 | ]) 215 | 216 | shared_args = dict( 217 | root=root, 218 | dataset_name=dataset_name, 219 | emb_name=emb_name, 220 | output_root=output_root, 221 | model_type=model_type, 222 | transform=transform, 223 | target_transform=None, 224 | target_res=input_size, 225 | include_labels=False, 226 | limit=samples_per_batch * num_batches 227 | ) 228 | 229 | def get_data(model, ds): 230 | loader = DataLoader(ds, batch_size=samples_per_batch, num_workers=num_workers) 231 | all_batches = [] 232 | for batch in loader: 233 | batch["lr_feat"] = model(batch["img"].cuda()).cpu() 234 | all_batches.append(batch) 235 | 236 | big_batch = {} 237 | for k, t in all_batches[0].items(): 238 | if isinstance(t, torch.Tensor): 239 | big_batch[k] = torch.cat([b[k] for b in all_batches], dim=0) 240 | del loader 241 | return big_batch 242 | 243 | with torch.no_grad(): 244 | model, _, dim = get_featurizer(model_type, activation_type) 245 | model = torch.nn.Sequential(model, ChannelNorm(dim)) 246 | model = model.cuda() 247 | batch = get_data(model, HighResEmb(split=split, **shared_args)) 248 | del model 249 | 250 | return batch 251 | 252 | 253 | if __name__ == "__main__": 254 | loaded = load_some_hr_feats( 255 | "vit", 256 | "token", 257 | "cocostuff", 258 | "train", 259 | "3_12_2024", 260 | "/pytorch-data/", 261 | "../../../", 262 | 224, 263 | 50, 264 | 3, 265 | 0 266 | ) 267 | 268 | print(loaded) 269 | -------------------------------------------------------------------------------- /featup/datasets/JitteredImage.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def apply_jitter(img, max_pad, transform_params): 9 | h, w = img.shape[2:] 10 | 11 | padded = F.pad(img, [max_pad] * 4, mode="reflect") 12 | 13 | zoom = transform_params["zoom"].item() 14 | x = transform_params["x"].item() 15 | y = transform_params["y"].item() 16 | flip = transform_params["flip"].item() 17 | 18 | if zoom > 1.0: 19 | zoomed = F.interpolate(padded, scale_factor=zoom, mode="bilinear") 20 | else: 21 | zoomed = padded 22 | 23 | cropped = zoomed[:, :, x:h + x, y:w + y] 24 | 25 | if flip: 26 | return torch.flip(cropped, [3]) 27 | else: 28 | return cropped 29 | 30 | 31 | def sample_transform(use_flips, max_pad, max_zoom, h, w): 32 | if use_flips: 33 | flip = random.random() > .5 34 | else: 35 | flip = False 36 | 37 | apply_zoom = random.random() > .5 38 | if apply_zoom: 39 | zoom = random.random() * (max_zoom - 1) + 1 40 | else: 41 | zoom = 1.0 42 | 43 | valid_area_h = (int((h + max_pad * 2) * zoom) - h) + 1 44 | valid_area_w = (int((w + max_pad * 2) * zoom) - w) + 1 45 | 46 | return { 47 | "x": torch.tensor(torch.randint(0, valid_area_h, ()).item()), 48 | "y": torch.tensor(torch.randint(0, valid_area_w, ()).item()), 49 | "zoom": torch.tensor(zoom), 50 | "flip": torch.tensor(flip) 51 | } 52 | 53 | 54 | class JitteredImage(Dataset): 55 | 56 | def __init__(self, img, length, use_flips, max_zoom, max_pad): 57 | self.img = img 58 | self.length = length 59 | self.use_flips = use_flips 60 | self.max_zoom = max_zoom 61 | self.max_pad = max_pad 62 | 63 | def __len__(self): 64 | return self.length 65 | 66 | def __getitem__(self, item): 67 | h, w = self.img.shape[2:] 68 | transform_params = sample_transform(self.use_flips, self.max_pad, self.max_zoom, h, w) 69 | return apply_jitter(self.img, self.max_pad, transform_params).squeeze(0), transform_params 70 | -------------------------------------------------------------------------------- /featup/datasets/SampleImage.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class SampleImage(Dataset): 6 | def __init__(self, paths, transform, **kwargs): 7 | self.paths = paths 8 | self.transform = transform 9 | 10 | def __getitem__(self, idx): 11 | image_path = self.paths[idx] 12 | image = Image.open(image_path).convert('RGB') 13 | if self.transform is not None: 14 | image = self.transform(image) 15 | batch = { 16 | "img": image, 17 | "img_path": image_path 18 | } 19 | return batch 20 | 21 | def __len__(self): 22 | return len(self.paths) 23 | -------------------------------------------------------------------------------- /featup/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/featup/datasets/__init__.py -------------------------------------------------------------------------------- /featup/datasets/util.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from featup.datasets.ImageNetSubset import ImageNetSubset 3 | from featup.datasets.COCO import Coco 4 | from featup.datasets.DAVIS import DAVIS 5 | from featup.datasets.SampleImage import SampleImage 6 | 7 | 8 | class SlicedDataset(Dataset): 9 | def __init__(self, ds, start, end): 10 | self.ds = ds 11 | self.start = max(0, start) 12 | self.end = min(len(ds), end) 13 | 14 | def __getitem__(self, index): 15 | if index >= self.__len__(): 16 | raise StopIteration 17 | 18 | return self.ds[self.start + index] 19 | 20 | def __len__(self): 21 | return self.end - self.start 22 | 23 | 24 | 25 | class SingleImageDataset(Dataset): 26 | def __init__(self, i, ds, l=None): 27 | self.ds = ds 28 | self.i = i 29 | self.l = len(self.ds) if l is None else l 30 | 31 | def __len__(self): 32 | return self.l 33 | 34 | def __getitem__(self, item): 35 | return self.ds[self.i] 36 | 37 | 38 | def get_dataset(dataroot, name, split, transform, target_transform, include_labels): 39 | if name == 'imagenet': 40 | if split == 'val': 41 | imagenet_subset = f'datalists/val_paths_vit.txt' 42 | else: 43 | imagenet_subset = None 44 | 45 | return ImageNetSubset(dataroot, split, transform, target_transform, 46 | include_labels=include_labels, subset=imagenet_subset) 47 | elif name == 'cocostuff': 48 | return Coco(dataroot, split, transform, target_transform, include_labels=include_labels) 49 | elif name.startswith('davis_'): 50 | return DAVIS(dataroot, name.split("_")[-1], transform) 51 | elif name == "sample": 52 | return SampleImage( 53 | paths=["../sample-images/bird_left.jpg", 54 | "../sample-images/bird_right.jpg"], 55 | transform=transform 56 | ) 57 | else: 58 | raise ValueError(f"Unknown dataset {name}") 59 | -------------------------------------------------------------------------------- /featup/downsamplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from kornia.filters import gaussian_blur2d 4 | 5 | 6 | class SimpleDownsampler(torch.nn.Module): 7 | 8 | def get_kernel(self): 9 | k = self.kernel_params.unsqueeze(0).unsqueeze(0).abs() 10 | k /= k.sum() 11 | return k 12 | 13 | def __init__(self, kernel_size, final_size, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | self.kernel_size = kernel_size 16 | self.final_size = final_size 17 | self.kernel_params = torch.nn.Parameter(torch.ones(kernel_size, kernel_size)) 18 | 19 | def forward(self, imgs, guidance): 20 | b, c, h, w = imgs.shape 21 | input_imgs = imgs.reshape(b * c, 1, h, w) 22 | stride = (h - self.kernel_size) // (self.final_size - 1) 23 | 24 | return F.conv2d( 25 | input_imgs, 26 | self.get_kernel(), 27 | stride=stride 28 | ).reshape(b, c, self.final_size, self.final_size) 29 | 30 | 31 | class AttentionDownsampler(torch.nn.Module): 32 | 33 | def __init__(self, dim, kernel_size, final_size, blur_attn, *args, **kwargs): 34 | super().__init__(*args, **kwargs) 35 | self.kernel_size = kernel_size 36 | self.final_size = final_size 37 | self.in_dim = dim 38 | self.attention_net = torch.nn.Sequential( 39 | torch.nn.Dropout(p=.2), 40 | torch.nn.Linear(self.in_dim, 1) 41 | ) 42 | self.w = torch.nn.Parameter(torch.ones(kernel_size, kernel_size).cuda() 43 | + .01 * torch.randn(kernel_size, kernel_size).cuda()) 44 | self.b = torch.nn.Parameter(torch.zeros(kernel_size, kernel_size).cuda() 45 | + .01 * torch.randn(kernel_size, kernel_size).cuda()) 46 | self.blur_attn = blur_attn 47 | 48 | def forward_attention(self, feats, guidance): 49 | return self.attention_net(feats.permute(0, 2, 3, 1)).squeeze(-1).unsqueeze(1) 50 | 51 | def forward(self, hr_feats, guidance): 52 | b, c, h, w = hr_feats.shape 53 | 54 | if self.blur_attn: 55 | inputs = gaussian_blur2d(hr_feats, 5, (1.0, 1.0)) 56 | else: 57 | inputs = hr_feats 58 | 59 | stride = (h - self.kernel_size) // (self.final_size - 1) 60 | 61 | patches = torch.nn.Unfold(self.kernel_size, stride=stride)(inputs) \ 62 | .reshape( 63 | (b, self.in_dim, self.kernel_size * self.kernel_size, self.final_size, self.final_size * int(w / h))) \ 64 | .permute(0, 3, 4, 2, 1) 65 | 66 | patch_logits = self.attention_net(patches).squeeze(-1) 67 | 68 | b, h, w, p = patch_logits.shape 69 | dropout = torch.rand(b, h, w, 1, device=patch_logits.device) > 0.2 70 | 71 | w = self.w.flatten().reshape(1, 1, 1, -1) 72 | b = self.b.flatten().reshape(1, 1, 1, -1) 73 | 74 | patch_attn_logits = (patch_logits * dropout) * w + b 75 | patch_attention = F.softmax(patch_attn_logits, dim=-1) 76 | 77 | downsampled = torch.einsum("bhwpc,bhwp->bchw", patches, patch_attention) 78 | 79 | return downsampled[:, :c, :, :] 80 | -------------------------------------------------------------------------------- /featup/featurizers/CLIP.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import torch 3 | from torch import nn 4 | import os 5 | 6 | class CLIPFeaturizer(nn.Module): 7 | 8 | def __init__(self): 9 | super().__init__() 10 | self.model, self.preprocess = clip.load( 11 | "ViT-B/16", 12 | download_root=os.getenv('TORCH_HOME', os.path.join(os.path.expanduser('~'), '.cache', 'torch')) 13 | ) 14 | self.model.eval() 15 | 16 | def get_cls_token(self, img): 17 | return self.model.encode_image(img).to(torch.float32) 18 | 19 | def forward(self, img): 20 | features = self.model.get_visual_features(img, include_cls=False).to(torch.float32) 21 | return features 22 | 23 | 24 | if __name__ == "__main__": 25 | import torchvision.transforms as T 26 | from PIL import Image 27 | from shared import norm, crop_to_divisor 28 | 29 | device = "cuda" if torch.cuda.is_available() else "cpu" 30 | 31 | image = Image.open("../samples/lex1.jpg") 32 | load_size = 224 # * 3 33 | transform = T.Compose([ 34 | T.Resize(load_size, Image.BILINEAR), 35 | # T.CenterCrop(load_size), 36 | T.ToTensor(), 37 | lambda x: crop_to_divisor(x, 16), 38 | norm]) 39 | 40 | model = CLIPFeaturizer().cuda() 41 | 42 | results = model(transform(image).cuda().unsqueeze(0)) 43 | 44 | print(clip.available_models()) 45 | -------------------------------------------------------------------------------- /featup/featurizers/DeepLabV3.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class DeepLabV3Featurizer(nn.Module): 5 | def __init__(self, model): 6 | super().__init__() 7 | self.model = model 8 | 9 | def get_cls_token(self, img): 10 | return self.model.forward(img) 11 | 12 | def forward(self, img, layer_num=-1): 13 | return self.model.backbone(img)['out'] 14 | -------------------------------------------------------------------------------- /featup/featurizers/MaskCLIP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import os 4 | 5 | from featup.featurizers.maskclip import clip 6 | 7 | 8 | class MaskCLIPFeaturizer(nn.Module): 9 | 10 | def __init__(self): 11 | super().__init__() 12 | self.model, self.preprocess = clip.load( 13 | "ViT-B/16", 14 | download_root=os.getenv('TORCH_HOME', os.path.join(os.path.expanduser('~'), '.cache', 'torch')) 15 | ) 16 | self.model.eval() 17 | self.patch_size = self.model.visual.patch_size 18 | 19 | def forward(self, img): 20 | b, _, input_size_h, input_size_w = img.shape 21 | patch_h = input_size_h // self.patch_size 22 | patch_w = input_size_w // self.patch_size 23 | features = self.model.get_patch_encodings(img).to(torch.float32) 24 | return features.reshape(b, patch_h, patch_w, -1).permute(0, 3, 1, 2) 25 | 26 | 27 | if __name__ == "__main__": 28 | import torchvision.transforms as T 29 | from PIL import Image 30 | from featup.util import norm, unnorm, crop_to_divisor 31 | 32 | device = "cuda" if torch.cuda.is_available() else "cpu" 33 | 34 | image = Image.open("../samples/lex1.jpg") 35 | load_size = 224 # * 3 36 | transform = T.Compose([ 37 | T.Resize(load_size, Image.BILINEAR), 38 | # T.CenterCrop(load_size), 39 | T.ToTensor(), 40 | lambda x: crop_to_divisor(x, 16), 41 | norm]) 42 | 43 | model = MaskCLIPFeaturizer().cuda() 44 | 45 | results = model(transform(image).cuda().unsqueeze(0)) 46 | 47 | print(clip.available_models()) 48 | -------------------------------------------------------------------------------- /featup/featurizers/ResNet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class ResNetFeaturizer(nn.Module): 5 | def __init__(self, model): 6 | super().__init__() 7 | self.model = model 8 | 9 | def get_cls_token(self, img): 10 | return self.model.forward(img) 11 | 12 | def get_layer(self, img, layer_num): 13 | return self.model.get_layer(img, layer_num) 14 | 15 | def forward(self, img, layer_num=-1): 16 | return self.model.get_layer(img, layer_num) 17 | -------------------------------------------------------------------------------- /featup/featurizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/featup/featurizers/__init__.py -------------------------------------------------------------------------------- /featup/featurizers/dinov2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/featup/featurizers/dinov2/__init__.py -------------------------------------------------------------------------------- /featup/featurizers/dinov2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .dino_head import DINOHead 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | from .block import NestedTensorBlock 11 | from .attention import MemEffAttention 12 | -------------------------------------------------------------------------------- /featup/featurizers/dinov2/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | 14 | from torch import Tensor 15 | from torch import nn 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 22 | try: 23 | if XFORMERS_ENABLED: 24 | from xformers.ops import memory_efficient_attention, unbind 25 | 26 | XFORMERS_AVAILABLE = True 27 | warnings.warn("xFormers is available (Attention)") 28 | else: 29 | warnings.warn("xFormers is disabled (Attention)") 30 | raise ImportError 31 | except ImportError: 32 | XFORMERS_AVAILABLE = False 33 | warnings.warn("xFormers is not available (Attention)") 34 | 35 | 36 | class Attention(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int = 8, 41 | qkv_bias: bool = False, 42 | proj_bias: bool = True, 43 | attn_drop: float = 0.0, 44 | proj_drop: float = 0.0, 45 | ) -> None: 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | self.scale = head_dim**-0.5 50 | 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | 56 | def forward(self, x: Tensor) -> Tensor: 57 | B, N, C = x.shape 58 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 59 | 60 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 61 | attn = q @ k.transpose(-2, -1) 62 | 63 | attn = attn.softmax(dim=-1) 64 | attn = self.attn_drop(attn) 65 | 66 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 67 | x = self.proj(x) 68 | x = self.proj_drop(x) 69 | return x 70 | 71 | 72 | class MemEffAttention(Attention): 73 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 74 | if not XFORMERS_AVAILABLE: 75 | if attn_bias is not None: 76 | raise AssertionError("xFormers is required for using nested tensors") 77 | return super().forward(x) 78 | 79 | B, N, C = x.shape 80 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 81 | 82 | q, k, v = unbind(qkv, 2) 83 | 84 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 85 | x = x.reshape([B, N, C]) 86 | 87 | x = self.proj(x) 88 | x = self.proj_drop(x) 89 | return x 90 | -------------------------------------------------------------------------------- /featup/featurizers/dinov2/layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | import logging 11 | import os 12 | from typing import Callable, List, Any, Tuple, Dict 13 | import warnings 14 | 15 | import torch 16 | from torch import nn, Tensor 17 | 18 | from .attention import Attention, MemEffAttention 19 | from .drop_path import DropPath 20 | from .layer_scale import LayerScale 21 | from .mlp import Mlp 22 | 23 | 24 | logger = logging.getLogger("dinov2") 25 | 26 | 27 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 28 | try: 29 | if XFORMERS_ENABLED: 30 | from xformers.ops import fmha, scaled_index_add, index_select_cat 31 | 32 | XFORMERS_AVAILABLE = True 33 | warnings.warn("xFormers is available (Block)") 34 | else: 35 | warnings.warn("xFormers is disabled (Block)") 36 | raise ImportError 37 | except ImportError: 38 | XFORMERS_AVAILABLE = False 39 | 40 | warnings.warn("xFormers is not available (Block)") 41 | 42 | 43 | class Block(nn.Module): 44 | def __init__( 45 | self, 46 | dim: int, 47 | num_heads: int, 48 | mlp_ratio: float = 4.0, 49 | qkv_bias: bool = False, 50 | proj_bias: bool = True, 51 | ffn_bias: bool = True, 52 | drop: float = 0.0, 53 | attn_drop: float = 0.0, 54 | init_values=None, 55 | drop_path: float = 0.0, 56 | act_layer: Callable[..., nn.Module] = nn.GELU, 57 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 58 | attn_class: Callable[..., nn.Module] = Attention, 59 | ffn_layer: Callable[..., nn.Module] = Mlp, 60 | ) -> None: 61 | super().__init__() 62 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 63 | self.norm1 = norm_layer(dim) 64 | self.attn = attn_class( 65 | dim, 66 | num_heads=num_heads, 67 | qkv_bias=qkv_bias, 68 | proj_bias=proj_bias, 69 | attn_drop=attn_drop, 70 | proj_drop=drop, 71 | ) 72 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 73 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 74 | 75 | self.norm2 = norm_layer(dim) 76 | mlp_hidden_dim = int(dim * mlp_ratio) 77 | self.mlp = ffn_layer( 78 | in_features=dim, 79 | hidden_features=mlp_hidden_dim, 80 | act_layer=act_layer, 81 | drop=drop, 82 | bias=ffn_bias, 83 | ) 84 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 85 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 86 | 87 | self.sample_drop_ratio = drop_path 88 | 89 | def forward(self, x: Tensor) -> Tensor: 90 | def attn_residual_func(x: Tensor) -> Tensor: 91 | return self.ls1(self.attn(self.norm1(x))) 92 | 93 | def ffn_residual_func(x: Tensor) -> Tensor: 94 | return self.ls2(self.mlp(self.norm2(x))) 95 | 96 | if self.training and self.sample_drop_ratio > 0.1: 97 | # the overhead is compensated only for a drop path rate larger than 0.1 98 | x = drop_add_residual_stochastic_depth( 99 | x, 100 | residual_func=attn_residual_func, 101 | sample_drop_ratio=self.sample_drop_ratio, 102 | ) 103 | x = drop_add_residual_stochastic_depth( 104 | x, 105 | residual_func=ffn_residual_func, 106 | sample_drop_ratio=self.sample_drop_ratio, 107 | ) 108 | elif self.training and self.sample_drop_ratio > 0.0: 109 | x = x + self.drop_path1(attn_residual_func(x)) 110 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 111 | else: 112 | x = x + attn_residual_func(x) 113 | x = x + ffn_residual_func(x) 114 | return x 115 | 116 | 117 | def drop_add_residual_stochastic_depth( 118 | x: Tensor, 119 | residual_func: Callable[[Tensor], Tensor], 120 | sample_drop_ratio: float = 0.0, 121 | ) -> Tensor: 122 | # 1) extract subset using permutation 123 | b, n, d = x.shape 124 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 125 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 126 | x_subset = x[brange] 127 | 128 | # 2) apply residual_func to get residual 129 | residual = residual_func(x_subset) 130 | 131 | x_flat = x.flatten(1) 132 | residual = residual.flatten(1) 133 | 134 | residual_scale_factor = b / sample_subset_size 135 | 136 | # 3) add the residual 137 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 138 | return x_plus_residual.view_as(x) 139 | 140 | 141 | def get_branges_scales(x, sample_drop_ratio=0.0): 142 | b, n, d = x.shape 143 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 144 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 145 | residual_scale_factor = b / sample_subset_size 146 | return brange, residual_scale_factor 147 | 148 | 149 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 150 | if scaling_vector is None: 151 | x_flat = x.flatten(1) 152 | residual = residual.flatten(1) 153 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 154 | else: 155 | x_plus_residual = scaled_index_add( 156 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 157 | ) 158 | return x_plus_residual 159 | 160 | 161 | attn_bias_cache: Dict[Tuple, Any] = {} 162 | 163 | 164 | def get_attn_bias_and_cat(x_list, branges=None): 165 | """ 166 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 167 | """ 168 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 169 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 170 | if all_shapes not in attn_bias_cache.keys(): 171 | seqlens = [] 172 | for b, x in zip(batch_sizes, x_list): 173 | for _ in range(b): 174 | seqlens.append(x.shape[1]) 175 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 176 | attn_bias._batch_sizes = batch_sizes 177 | attn_bias_cache[all_shapes] = attn_bias 178 | 179 | if branges is not None: 180 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 181 | else: 182 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 183 | cat_tensors = torch.cat(tensors_bs1, dim=1) 184 | 185 | return attn_bias_cache[all_shapes], cat_tensors 186 | 187 | 188 | def drop_add_residual_stochastic_depth_list( 189 | x_list: List[Tensor], 190 | residual_func: Callable[[Tensor, Any], Tensor], 191 | sample_drop_ratio: float = 0.0, 192 | scaling_vector=None, 193 | ) -> Tensor: 194 | # 1) generate random set of indices for dropping samples in the batch 195 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 196 | branges = [s[0] for s in branges_scales] 197 | residual_scale_factors = [s[1] for s in branges_scales] 198 | 199 | # 2) get attention bias and index+concat the tensors 200 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 201 | 202 | # 3) apply residual_func to get residual, and split the result 203 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 204 | 205 | outputs = [] 206 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 207 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 208 | return outputs 209 | 210 | 211 | class NestedTensorBlock(Block): 212 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 213 | """ 214 | x_list contains a list of tensors to nest together and run 215 | """ 216 | assert isinstance(self.attn, MemEffAttention) 217 | 218 | if self.training and self.sample_drop_ratio > 0.0: 219 | 220 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 221 | return self.attn(self.norm1(x), attn_bias=attn_bias) 222 | 223 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 224 | return self.mlp(self.norm2(x)) 225 | 226 | x_list = drop_add_residual_stochastic_depth_list( 227 | x_list, 228 | residual_func=attn_residual_func, 229 | sample_drop_ratio=self.sample_drop_ratio, 230 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 231 | ) 232 | x_list = drop_add_residual_stochastic_depth_list( 233 | x_list, 234 | residual_func=ffn_residual_func, 235 | sample_drop_ratio=self.sample_drop_ratio, 236 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 237 | ) 238 | return x_list 239 | else: 240 | 241 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 242 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 243 | 244 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 245 | return self.ls2(self.mlp(self.norm2(x))) 246 | 247 | attn_bias, x = get_attn_bias_and_cat(x_list) 248 | x = x + attn_residual_func(x, attn_bias=attn_bias) 249 | x = x + ffn_residual_func(x) 250 | return attn_bias.split(x) 251 | 252 | def forward(self, x_or_x_list): 253 | if isinstance(x_or_x_list, Tensor): 254 | return super().forward(x_or_x_list) 255 | elif isinstance(x_or_x_list, list): 256 | if not XFORMERS_AVAILABLE: 257 | raise AssertionError("xFormers is required for using nested tensors") 258 | return self.forward_nested(x_or_x_list) 259 | else: 260 | raise AssertionError 261 | -------------------------------------------------------------------------------- /featup/featurizers/dinov2/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.init import trunc_normal_ 9 | from torch.nn.utils import weight_norm 10 | 11 | 12 | class DINOHead(nn.Module): 13 | def __init__( 14 | self, 15 | in_dim, 16 | out_dim, 17 | use_bn=False, 18 | nlayers=3, 19 | hidden_dim=2048, 20 | bottleneck_dim=256, 21 | mlp_bias=True, 22 | ): 23 | super().__init__() 24 | nlayers = max(nlayers, 1) 25 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 26 | self.apply(self._init_weights) 27 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 28 | self.last_layer.weight_g.data.fill_(1) 29 | 30 | def _init_weights(self, m): 31 | if isinstance(m, nn.Linear): 32 | trunc_normal_(m.weight, std=0.02) 33 | if isinstance(m, nn.Linear) and m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | 36 | def forward(self, x): 37 | x = self.mlp(x) 38 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 39 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 40 | x = self.last_layer(x) 41 | return x 42 | 43 | 44 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 45 | if nlayers == 1: 46 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 47 | else: 48 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 49 | if use_bn: 50 | layers.append(nn.BatchNorm1d(hidden_dim)) 51 | layers.append(nn.GELU()) 52 | for _ in range(nlayers - 2): 53 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 54 | if use_bn: 55 | layers.append(nn.BatchNorm1d(hidden_dim)) 56 | layers.append(nn.GELU()) 57 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 58 | return nn.Sequential(*layers) 59 | -------------------------------------------------------------------------------- /featup/featurizers/dinov2/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0: 21 | random_tensor.div_(keep_prob) 22 | output = x * random_tensor 23 | return output 24 | 25 | 26 | class DropPath(nn.Module): 27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 28 | 29 | def __init__(self, drop_prob=None): 30 | super(DropPath, self).__init__() 31 | self.drop_prob = drop_prob 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training) 35 | -------------------------------------------------------------------------------- /featup/featurizers/dinov2/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor 12 | from torch import nn 13 | 14 | 15 | class LayerScale(nn.Module): 16 | def __init__( 17 | self, 18 | dim: int, 19 | init_values: Union[float, Tensor] = 1e-5, 20 | inplace: bool = False, 21 | ) -> None: 22 | super().__init__() 23 | self.inplace = inplace 24 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 25 | 26 | def forward(self, x: Tensor) -> Tensor: 27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 28 | -------------------------------------------------------------------------------- /featup/featurizers/dinov2/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__( 18 | self, 19 | in_features: int, 20 | hidden_features: Optional[int] = None, 21 | out_features: Optional[int] = None, 22 | act_layer: Callable[..., nn.Module] = nn.GELU, 23 | drop: float = 0.0, 24 | bias: bool = True, 25 | ) -> None: 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | -------------------------------------------------------------------------------- /featup/featurizers/dinov2/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | from torch import Tensor 13 | import torch.nn as nn 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) 50 | patch_grid_size = ( 51 | image_HW[0] // patch_HW[0], 52 | image_HW[1] // patch_HW[1], 53 | ) 54 | 55 | self.img_size = image_HW 56 | self.patch_size = patch_HW 57 | self.patches_resolution = patch_grid_size 58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 59 | 60 | self.in_chans = in_chans 61 | self.embed_dim = embed_dim 62 | 63 | self.flatten_embedding = flatten_embedding 64 | 65 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 66 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | _, _, H, W = x.shape 70 | patch_H, patch_W = self.patch_size 71 | 72 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 73 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 74 | 75 | x = self.proj(x) # B C H W 76 | H, W = x.size(2), x.size(3) 77 | x = x.flatten(2).transpose(1, 2) # B HW C 78 | x = self.norm(x) 79 | if not self.flatten_embedding: 80 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 81 | return x 82 | 83 | def flops(self) -> float: 84 | Ho, Wo = self.patches_resolution 85 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 86 | if self.norm is not None: 87 | flops += Ho * Wo * self.embed_dim 88 | return flops 89 | -------------------------------------------------------------------------------- /featup/featurizers/dinov2/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | from typing import Callable, Optional 8 | import warnings 9 | 10 | from torch import Tensor, nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | 37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 38 | try: 39 | if XFORMERS_ENABLED: 40 | from xformers.ops import SwiGLU 41 | 42 | XFORMERS_AVAILABLE = True 43 | warnings.warn("xFormers is available (SwiGLU)") 44 | else: 45 | warnings.warn("xFormers is disabled (SwiGLU)") 46 | raise ImportError 47 | except ImportError: 48 | SwiGLU = SwiGLUFFN 49 | XFORMERS_AVAILABLE = False 50 | 51 | warnings.warn("xFormers is not available (SwiGLU)") 52 | 53 | 54 | class SwiGLUFFNFused(SwiGLU): 55 | def __init__( 56 | self, 57 | in_features: int, 58 | hidden_features: Optional[int] = None, 59 | out_features: Optional[int] = None, 60 | act_layer: Callable[..., nn.Module] = None, 61 | drop: float = 0.0, 62 | bias: bool = True, 63 | ) -> None: 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or in_features 66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 67 | super().__init__( 68 | in_features=in_features, 69 | hidden_features=hidden_features, 70 | out_features=out_features, 71 | bias=bias, 72 | ) 73 | -------------------------------------------------------------------------------- /featup/featurizers/maskclip/README.md: -------------------------------------------------------------------------------- 1 | # CLIP 2 | Modified version of [CLIP](https://github.com/openai/CLIP) with support for dense patch-level feature extraction 3 | (based on [MaskCLIP](https://arxiv.org/abs/2112.01071) parametrization) and interpolation of the positional encoding. 4 | -------------------------------------------------------------------------------- /featup/featurizers/maskclip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | 3 | """ 4 | Modified from https://github.com/openai/CLIP 5 | """ 6 | -------------------------------------------------------------------------------- /featup/featurizers/maskclip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/featup/featurizers/maskclip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /featup/featurizers/maskclip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | 19 | BICUBIC = InterpolationMode.BICUBIC 20 | except ImportError: 21 | BICUBIC = Image.BICUBIC 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 35 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 36 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 37 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 38 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 39 | } 40 | 41 | 42 | def _download(url: str, root: str): 43 | os.makedirs(root, exist_ok=True) 44 | filename = os.path.basename(url) 45 | 46 | expected_sha256 = url.split("/")[-2] 47 | download_target = os.path.join(root, filename) 48 | 49 | if os.path.exists(download_target) and not os.path.isfile(download_target): 50 | raise RuntimeError(f"{download_target} exists and is not a regular file") 51 | 52 | if os.path.isfile(download_target): 53 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 54 | return download_target 55 | else: 56 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 57 | 58 | print(f"Downloading CLIP model from {url}") 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, 61 | unit_divisor=1024) as loop: 62 | while True: 63 | buffer = source.read(8192) 64 | if not buffer: 65 | break 66 | 67 | output.write(buffer) 68 | loop.update(len(buffer)) 69 | 70 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 71 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 72 | 73 | return download_target 74 | 75 | 76 | def _convert_image_to_rgb(image): 77 | return image.convert("RGB") 78 | 79 | 80 | def _transform(n_px): 81 | return Compose([ 82 | Resize(n_px, interpolation=BICUBIC), 83 | CenterCrop(n_px), 84 | _convert_image_to_rgb, 85 | ToTensor(), 86 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 87 | ]) 88 | 89 | 90 | def available_models() -> List[str]: 91 | """Returns the names of available CLIP models""" 92 | return list(_MODELS.keys()) 93 | 94 | 95 | TORCH_HUB_ROOT = os.path.expandvars(os.getenv("$TORCH_HUB_ROOT", "$HOME/.torch_hub")) 96 | 97 | 98 | def load( 99 | name: str, 100 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 101 | jit: bool = False, 102 | download_root: str = None 103 | ): 104 | """Load a CLIP model 105 | 106 | Parameters 107 | ---------- 108 | name : str 109 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 110 | 111 | device : Union[str, torch.device] 112 | The device to put the loaded model 113 | 114 | jit : bool 115 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 116 | 117 | download_root: str 118 | path to download the model files; by default, it uses "~/.torch_hub/clip" 119 | 120 | Returns 121 | ------- 122 | model : torch.nn.Module 123 | The CLIP model 124 | 125 | preprocess : Callable[[PIL.Image], torch.Tensor] 126 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 127 | """ 128 | if name in _MODELS: 129 | model_path = _download(_MODELS[name], download_root or TORCH_HUB_ROOT) 130 | elif os.path.isfile(name): 131 | model_path = name 132 | else: 133 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 134 | 135 | with open(model_path, 'rb') as opened_file: 136 | try: 137 | # loading JIT archive 138 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 139 | state_dict = None 140 | except RuntimeError: 141 | # loading saved state dict 142 | if jit: 143 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 144 | jit = False 145 | state_dict = torch.load(opened_file, map_location="cpu") 146 | 147 | if not jit: 148 | model = build_model(state_dict or model.state_dict()).to(device) 149 | if str(device) == "cpu": 150 | model.float() 151 | return model, _transform(model.visual.input_resolution) 152 | 153 | # patch the device names 154 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 155 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 156 | 157 | def patch_device(module): 158 | try: 159 | graphs = [module.graph] if hasattr(module, "graph") else [] 160 | except RuntimeError: 161 | graphs = [] 162 | 163 | if hasattr(module, "forward1"): 164 | graphs.append(module.forward1.graph) 165 | 166 | for graph in graphs: 167 | for node in graph.findAllNodes("prim::Constant"): 168 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 169 | node.copyAttributes(device_node) 170 | 171 | model.apply(patch_device) 172 | patch_device(model.encode_image) 173 | patch_device(model.encode_text) 174 | 175 | # patch dtype to float32 on CPU 176 | if str(device) == "cpu": 177 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 178 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 179 | float_node = float_input.node() 180 | 181 | def patch_float(module): 182 | try: 183 | graphs = [module.graph] if hasattr(module, "graph") else [] 184 | except RuntimeError: 185 | graphs = [] 186 | 187 | if hasattr(module, "forward1"): 188 | graphs.append(module.forward1.graph) 189 | 190 | for graph in graphs: 191 | for node in graph.findAllNodes("aten::to"): 192 | inputs = list(node.inputs()) 193 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 194 | if inputs[i].node()["value"] == 5: 195 | inputs[i].node().copyAttributes(float_node) 196 | 197 | model.apply(patch_float) 198 | patch_float(model.encode_image) 199 | patch_float(model.encode_text) 200 | 201 | model.float() 202 | 203 | return model, _transform(model.input_resolution.item()) 204 | 205 | 206 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[ 207 | torch.IntTensor, torch.LongTensor]: 208 | """ 209 | Returns the tokenized representation of given input string(s) 210 | 211 | Parameters 212 | ---------- 213 | texts : Union[str, List[str]] 214 | An input string or a list of input strings to tokenize 215 | 216 | context_length : int 217 | The context length to use; all CLIP models use 77 as the context length 218 | 219 | truncate: bool 220 | Whether to truncate the text in case its encoding is longer than the context length 221 | 222 | Returns 223 | ------- 224 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 225 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 226 | """ 227 | if isinstance(texts, str): 228 | texts = [texts] 229 | 230 | sot_token = _tokenizer.encoder["<|startoftext|>"] 231 | eot_token = _tokenizer.encoder["<|endoftext|>"] 232 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 233 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 234 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 235 | else: 236 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 237 | 238 | for i, tokens in enumerate(all_tokens): 239 | if len(tokens) > context_length: 240 | if truncate: 241 | tokens = tokens[:context_length] 242 | tokens[-1] = eot_token 243 | else: 244 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 245 | result[i, :len(tokens)] = torch.tensor(tokens) 246 | 247 | return result 248 | -------------------------------------------------------------------------------- /featup/featurizers/maskclip/interpolate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def interpolate_positional_embedding( 6 | positional_embedding: torch.Tensor, x: torch.Tensor, patch_size: int, w: int, h: int 7 | ): 8 | """ 9 | Interpolate the positional encoding for CLIP to the number of patches in the image given width and height. 10 | Modified from DINO ViT `interpolate_pos_encoding` method. 11 | https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L174 12 | """ 13 | assert positional_embedding.ndim == 2, "pos_encoding must be 2D" 14 | 15 | # Number of patches in input 16 | num_patches = x.shape[1] - 1 17 | # Original number of patches for square images 18 | num_og_patches = positional_embedding.shape[0] - 1 19 | 20 | if num_patches == num_og_patches and w == h: 21 | # No interpolation needed 22 | return positional_embedding.to(x.dtype) 23 | 24 | dim = x.shape[-1] 25 | class_pos_embed = positional_embedding[:1] # (1, dim) 26 | patch_pos_embed = positional_embedding[1:] # (num_og_patches, dim) 27 | 28 | # Compute number of tokens 29 | w0 = w // patch_size 30 | h0 = h // patch_size 31 | assert w0 * h0 == num_patches, "Number of patches does not match" 32 | 33 | # Add a small number to avoid floating point error in the interpolation 34 | # see discussion at https://github.com/facebookresearch/dino/issues/8 35 | w0, h0 = w0 + 0.1, h0 + 0.1 36 | 37 | # Interpolate 38 | patch_per_ax = int(np.sqrt(num_og_patches)) 39 | patch_pos_embed_interp = torch.nn.functional.interpolate( 40 | patch_pos_embed.reshape(1, patch_per_ax, patch_per_ax, dim).permute(0, 3, 1, 2), 41 | # (1, dim, patch_per_ax, patch_per_ax) 42 | scale_factor=(w0 / patch_per_ax, h0 / patch_per_ax), 43 | mode="bicubic", 44 | align_corners=False, 45 | recompute_scale_factor=False, 46 | ) # (1, dim, w0, h0) 47 | assert ( 48 | int(w0) == patch_pos_embed_interp.shape[-2] and int(h0) == patch_pos_embed_interp.shape[-1] 49 | ), "Interpolation error." 50 | 51 | patch_pos_embed_interp = patch_pos_embed_interp.permute(0, 2, 3, 1).reshape(-1, dim) # (w0 * h0, dim) 52 | # Concat class token embedding and interpolated patch embeddings 53 | pos_embed_interp = torch.cat([class_pos_embed, patch_pos_embed_interp], dim=0) # (w0 * h0 + 1, dim) 54 | return pos_embed_interp.to(x.dtype) 55 | -------------------------------------------------------------------------------- /featup/featurizers/maskclip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from collections.abc import Sequence 5 | from functools import lru_cache 6 | 7 | import ftfy 8 | import regex as re 9 | 10 | 11 | @lru_cache() 12 | def default_bpe(): 13 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 14 | 15 | 16 | @lru_cache() 17 | def bytes_to_unicode(): 18 | """ 19 | Returns list of utf-8 byte and a corresponding list of unicode strings. 20 | The reversible bpe codes work on unicode strings. 21 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 22 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 23 | This is a signficant percentage of your normal, say, 32K bpe vocab. 24 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 25 | And avoids mapping to whitespace/control characters the bpe code barfs on. 26 | """ 27 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 28 | cs = bs[:] 29 | n = 0 30 | for b in range(2**8): 31 | if b not in bs: 32 | bs.append(b) 33 | cs.append(2**8+n) 34 | n += 1 35 | cs = [chr(n) for n in cs] 36 | return dict(zip(bs, cs)) 37 | 38 | 39 | def get_pairs(word): 40 | """Return set of symbol pairs in a word. 41 | Word is represented as tuple of symbols (symbols being variable-length strings). 42 | """ 43 | pairs = set() 44 | prev_char = word[0] 45 | for char in word[1:]: 46 | pairs.add((prev_char, char)) 47 | prev_char = char 48 | return pairs 49 | 50 | 51 | def basic_clean(text): 52 | # note: pretty hacky but it is okay! 53 | # ge: bad.this is used by the cli_multi_label.py script 54 | if not isinstance(text, str): 55 | text = ', '.join(text) 56 | 57 | text = ftfy.fix_text(text) 58 | text = html.unescape(html.unescape(text)) 59 | return text.strip() 60 | 61 | 62 | def whitespace_clean(text): 63 | text = re.sub(r'\s+', ' ', text) 64 | text = text.strip() 65 | return text 66 | 67 | 68 | class SimpleTokenizer(object): 69 | def __init__(self, bpe_path: str = default_bpe()): 70 | self.byte_encoder = bytes_to_unicode() 71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 73 | merges = merges[1:49152-256-2+1] 74 | merges = [tuple(merge.split()) for merge in merges] 75 | vocab = list(bytes_to_unicode().values()) 76 | vocab = vocab + [v+'' for v in vocab] 77 | for merge in merges: 78 | vocab.append(''.join(merge)) 79 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 80 | self.encoder = dict(zip(vocab, range(len(vocab)))) 81 | self.decoder = {v: k for k, v in self.encoder.items()} 82 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 83 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 84 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 85 | 86 | def bpe(self, token): 87 | if token in self.cache: 88 | return self.cache[token] 89 | word = tuple(token[:-1]) + ( token[-1] + '',) 90 | pairs = get_pairs(word) 91 | 92 | if not pairs: 93 | return token+'' 94 | 95 | while True: 96 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 97 | if bigram not in self.bpe_ranks: 98 | break 99 | first, second = bigram 100 | new_word = [] 101 | i = 0 102 | while i < len(word): 103 | try: 104 | j = word.index(first, i) 105 | new_word.extend(word[i:j]) 106 | i = j 107 | except: 108 | new_word.extend(word[i:]) 109 | break 110 | 111 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 112 | new_word.append(first+second) 113 | i += 2 114 | else: 115 | new_word.append(word[i]) 116 | i += 1 117 | new_word = tuple(new_word) 118 | word = new_word 119 | if len(word) == 1: 120 | break 121 | else: 122 | pairs = get_pairs(word) 123 | word = ' '.join(word) 124 | self.cache[token] = word 125 | return word 126 | 127 | def encode(self, text): 128 | bpe_tokens = [] 129 | text = whitespace_clean(basic_clean(text)).lower() 130 | for token in re.findall(self.pat, text): 131 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 132 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 133 | return bpe_tokens 134 | 135 | def decode(self, tokens): 136 | text = ''.join([self.decoder[token] for token in tokens]) 137 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 138 | return text 139 | -------------------------------------------------------------------------------- /featup/featurizers/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/featup/featurizers/modules/__init__.py -------------------------------------------------------------------------------- /featup/featurizers/modules/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | import math 6 | 7 | __all__ = ['forward_hook', 'AdaptiveAvgPool2d', 'Add', 'AvgPool2d', 'BatchNorm2d', 'Clone', 'Conv2d', 'ConvTranspose2d', 8 | 'Dropout', 'Identity', 'LeakyReLU', 'Linear', 'MaxPool2d', 'Multiply', 'ReLU', 'Sequential', 'safe_divide', 9 | 'ZeroPad2d', 'LayerNorm', 'GELU', 'einsum', 'Softmax'] 10 | 11 | 12 | def safe_divide(a, b): 13 | return a / (b + b.eq(0).type(b.type()) * 1e-9) * b.ne(0).type(b.type()) 14 | 15 | 16 | def forward_hook(self, input, output): 17 | if type(input[0]) in (list, tuple): 18 | self.X = [] 19 | for i in input[0]: 20 | x = i.detach() 21 | x.requires_grad = True 22 | self.X.append(x) 23 | else: 24 | self.X = input[0].detach() 25 | self.X.requires_grad = True 26 | 27 | self.Y = output 28 | 29 | 30 | class RelProp(nn.Module): 31 | def __init__(self): 32 | super(RelProp, self).__init__() 33 | # if not self.training: 34 | self.register_forward_hook(forward_hook) 35 | 36 | def gradprop(self, Z, X, S): 37 | C = torch.autograd.grad(Z, X, S, retain_graph=True) 38 | return C 39 | 40 | def relprop(self, R, alpha=1): 41 | return R 42 | 43 | 44 | class RelPropSimple(RelProp): 45 | def relprop(self, R, alpha=1): 46 | Z = self.forward(self.X) 47 | S = safe_divide(R, Z) 48 | C = self.gradprop(Z, self.X, S) 49 | 50 | if torch.is_tensor(self.X) == False: 51 | outputs = [] 52 | outputs.append(self.X[0] * C[0]) 53 | outputs.append(self.X[1] * C[1]) 54 | else: 55 | outputs = self.X * C[0] 56 | return outputs 57 | 58 | 59 | class Identity(nn.Identity, RelProp): 60 | pass 61 | 62 | 63 | class ReLU(nn.ReLU, RelProp): 64 | pass 65 | 66 | 67 | class GELU(nn.GELU, RelProp): 68 | pass 69 | 70 | class LeakyReLU(nn.LeakyReLU, RelProp): 71 | pass 72 | 73 | class Softmax(nn.Softmax, RelProp): 74 | pass 75 | 76 | class einsum(RelPropSimple): 77 | def __init__(self, equation): 78 | super().__init__() 79 | self.equation = equation 80 | def forward(self, *operands): 81 | return torch.einsum(self.equation, *operands) 82 | 83 | class Dropout(nn.Dropout, RelProp): 84 | pass 85 | 86 | 87 | class MaxPool2d(nn.MaxPool2d, RelPropSimple): 88 | pass 89 | 90 | class LayerNorm(nn.LayerNorm, RelProp): 91 | pass 92 | 93 | class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelProp): 94 | def relprop(self, R, alpha=1): 95 | px = torch.clamp(self.X, min=0) 96 | 97 | def f(x1): 98 | Z1 = F.adaptive_avg_pool2d(x1, self.output_size) 99 | S1 = safe_divide(R, Z1) 100 | C1 = x1 * self.gradprop(Z1, x1, S1)[0] 101 | return C1 102 | 103 | activator_relevances = f(px) 104 | out = activator_relevances 105 | return out 106 | 107 | 108 | class ZeroPad2d(nn.ZeroPad2d, RelPropSimple): 109 | def relprop(self, R, alpha=1): 110 | Z = self.forward(self.X) 111 | S = safe_divide(R, Z) 112 | C = self.gradprop(Z, self.X, S) 113 | outputs = self.X * C[0] 114 | return outputs 115 | 116 | 117 | class AvgPool2d(nn.AvgPool2d, RelPropSimple): 118 | pass 119 | 120 | 121 | class Add(RelPropSimple): 122 | def forward(self, inputs): 123 | return torch.add(*inputs) 124 | 125 | def relprop(self, R, alpha): 126 | Z = self.forward(self.X) 127 | S = safe_divide(R, Z) 128 | C = self.gradprop(Z, self.X, S) 129 | 130 | a = self.X[0] * C[0] 131 | b = self.X[1] * C[1] 132 | 133 | a_sum = a.sum() 134 | b_sum = b.sum() 135 | 136 | a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum() 137 | b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum() 138 | 139 | a = a * safe_divide(a_fact, a.sum()) 140 | b = b * safe_divide(b_fact, b.sum()) 141 | 142 | outputs = [a, b] 143 | 144 | return outputs 145 | 146 | 147 | class Clone(RelProp): 148 | def forward(self, input, num): 149 | self.__setattr__('num', num) 150 | outputs = [] 151 | for _ in range(num): 152 | outputs.append(input) 153 | 154 | return outputs 155 | 156 | def relprop(self, R, alpha = 1): 157 | Z = [] 158 | for _ in range(self.num): 159 | Z.append(self.X) 160 | S = [safe_divide(r, z) for r, z in zip(R, Z)] 161 | C = self.gradprop(Z, self.X, S)[0] 162 | 163 | R = self.X * C 164 | 165 | return R 166 | 167 | 168 | class Multiply(RelPropSimple): 169 | def forward(self, inputs): 170 | return torch.mul(*inputs) 171 | 172 | def relprop(self, R, alpha=1): 173 | x0 = torch.clamp(self.X[0], min=0) 174 | x1 = torch.clamp(self.X[1], min=0) 175 | x = [x0, x1] 176 | Z = self.forward(x) 177 | S = safe_divide(R, Z) 178 | C = self.gradprop(Z, x, S) 179 | outputs = [] 180 | outputs.append(x[0] * C[0]) 181 | outputs.append(x[1] * C[1]) 182 | return outputs 183 | 184 | class Sequential(nn.Sequential): 185 | def relprop(self, R, alpha=1): 186 | for m in reversed(self._modules.values()): 187 | R = m.relprop(R, alpha) 188 | return R 189 | 190 | 191 | 192 | class BatchNorm2d(nn.BatchNorm2d, RelProp): 193 | def relprop(self, R, alpha=1): 194 | X = self.X 195 | beta = 1 - alpha 196 | weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / ( 197 | (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5)) 198 | Z = X * weight + 1e-9 199 | S = R / Z 200 | Ca = S * weight 201 | R = self.X * (Ca) 202 | return R 203 | 204 | 205 | class Linear(nn.Linear, RelProp): 206 | def relprop(self, R, alpha=1): 207 | beta = alpha - 1 208 | pw = torch.clamp(self.weight, min=0) 209 | nw = torch.clamp(self.weight, max=0) 210 | px = torch.clamp(self.X, min=0) 211 | nx = torch.clamp(self.X, max=0) 212 | 213 | # def f(w1, w2, x1, x2): 214 | # Z1 = F.linear(x1, w1) 215 | # Z2 = F.linear(x2, w2) 216 | # S1 = safe_divide(R, Z1) 217 | # S2 = safe_divide(R, Z2) 218 | # C1 = x1 * self.gradprop(Z1, x1, S1)[0] 219 | # C2 = x2 * self.gradprop(Z2, x2, S2)[0] 220 | # return C1 #+ C2 221 | 222 | def f(w1, w2, x1, x2): 223 | Z1 = F.linear(x1, w1) 224 | Z2 = F.linear(x2, w2) 225 | Z = Z1 + Z2 226 | S = safe_divide(R, Z) 227 | C1 = x1 * self.gradprop(Z1, x1, S)[0] 228 | C2 = x2 * self.gradprop(Z2, x2, S)[0] 229 | return C1 + C2 230 | 231 | activator_relevances = f(pw, nw, px, nx) 232 | inhibitor_relevances = f(nw, pw, px, nx) 233 | 234 | out = alpha * activator_relevances - beta * inhibitor_relevances 235 | 236 | return out 237 | 238 | 239 | 240 | class Conv2d(nn.Conv2d, RelProp): 241 | 242 | def relprop(self, R, alpha=1): 243 | if self.X.shape[1] == 3: 244 | pw = torch.clamp(self.weight, min=0) 245 | nw = torch.clamp(self.weight, max=0) 246 | X = self.X 247 | L = self.X * 0 + \ 248 | torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, 249 | keepdim=True)[0] 250 | H = self.X * 0 + \ 251 | torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, 252 | keepdim=True)[0] 253 | Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \ 254 | torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \ 255 | torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9 256 | 257 | S = R / Za 258 | C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw) 259 | R = C 260 | else: 261 | beta = alpha - 1 262 | pw = torch.clamp(self.weight, min=0) 263 | nw = torch.clamp(self.weight, max=0) 264 | px = torch.clamp(self.X, min=0) 265 | nx = torch.clamp(self.X, max=0) 266 | 267 | def f(w1, w2, x1, x2): 268 | Z1 = F.conv2d(x1, w1, bias=self.bias, stride=self.stride, padding=self.padding, groups=self.groups) 269 | Z2 = F.conv2d(x2, w2, bias=self.bias, stride=self.stride, padding=self.padding, groups=self.groups) 270 | Z = Z1 + Z2 271 | S = safe_divide(R, Z) 272 | C1 = x1 * self.gradprop(Z1, x1, S)[0] 273 | C2 = x2 * self.gradprop(Z2, x2, S)[0] 274 | return C1 + C2 275 | 276 | activator_relevances = f(pw, nw, px, nx) 277 | inhibitor_relevances = f(nw, pw, px, nx) 278 | 279 | R = alpha * activator_relevances - beta * inhibitor_relevances 280 | return R 281 | 282 | 283 | 284 | class ConvTranspose2d(nn.ConvTranspose2d, RelProp): 285 | def relprop(self, R, alpha=1): 286 | pw = torch.clamp(self.weight, min=0) 287 | px = torch.clamp(self.X, min=0) 288 | 289 | def f(w1, x1): 290 | Z1 = F.conv_transpose2d(x1, w1, bias=None, stride=self.stride, padding=self.padding, 291 | output_padding=self.output_padding) 292 | S1 = safe_divide(R, Z1) 293 | C1 = x1 * self.gradprop(Z1, x1, S1)[0] 294 | return C1 295 | 296 | activator_relevances = f(pw, px) 297 | R = activator_relevances 298 | return R 299 | 300 | 301 | 302 | if __name__ == '__main__': 303 | convt = ConvTranspose2d(100, 50, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False).cuda() 304 | 305 | rand = torch.rand((1, 100, 224, 224)).cuda() 306 | out = convt(rand) 307 | rel = convt.relprop(out) 308 | 309 | print(out.shape) 310 | -------------------------------------------------------------------------------- /featup/featurizers/modules/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | from featup.featurizers.modules.layers import * 6 | import torch 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152'] 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | """3x3 convolution with padding""" 22 | return Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | def conv1x1(in_planes, out_planes, stride=1): 27 | """1x1 convolution""" 28 | return Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None): 35 | super(BasicBlock, self).__init__() 36 | self.clone = Clone() 37 | 38 | self.conv1 = conv3x3(inplanes, planes, stride) 39 | self.bn1 = BatchNorm2d(planes) 40 | self.conv2 = conv3x3(planes, planes) 41 | self.bn2 = BatchNorm2d(planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | self.relu1 = ReLU(inplace=True) 46 | self.relu2 = ReLU(inplace=True) 47 | 48 | self.add = Add() 49 | 50 | self.register_forward_hook(forward_hook) 51 | 52 | def forward(self, x): 53 | x1, x2 = self.clone(x, 2) 54 | 55 | out = self.conv1(x1) 56 | out = self.bn1(out) 57 | out = self.relu1(out) 58 | 59 | out = self.conv2(out) 60 | out = self.bn2(out) 61 | 62 | if self.downsample is not None: 63 | x2 = self.downsample(x2) 64 | 65 | out = self.add([out, x2]) 66 | out = self.relu2(out) 67 | 68 | return out 69 | 70 | def relprop(self, R, alpha): 71 | out = self.relu2.relprop(R, alpha) 72 | out, x2 = self.add.relprop(out, alpha) 73 | 74 | if self.downsample is not None: 75 | x2 = self.downsample.relprop(x2, alpha) 76 | 77 | out = self.bn2.relprop(out, alpha) 78 | out = self.conv2.relprop(out, alpha) 79 | 80 | out = self.relu1.relprop(out, alpha) 81 | out = self.bn1.relprop(out, alpha) 82 | x1 = self.conv1.relprop(out, alpha) 83 | 84 | return self.clone.relprop([x1, x2], alpha) 85 | 86 | 87 | class Bottleneck(nn.Module): 88 | expansion = 4 89 | 90 | def __init__(self, inplanes, planes, stride=1, downsample=None): 91 | super(Bottleneck, self).__init__() 92 | 93 | self.conv1 = conv1x1(inplanes, planes) 94 | self.bn1 = BatchNorm2d(planes) 95 | self.conv2 = conv3x3(planes, planes, stride) 96 | self.bn2 = BatchNorm2d(planes) 97 | self.conv3 = conv1x1(planes, planes * self.expansion) 98 | self.bn3 = BatchNorm2d(planes * self.expansion) 99 | self.downsample = downsample 100 | self.stride = stride 101 | 102 | self.relu1 = ReLU(inplace=True) 103 | self.relu2 = ReLU(inplace=True) 104 | self.relu3 = ReLU(inplace=True) 105 | 106 | self.add = Add() 107 | 108 | self.register_forward_hook(forward_hook) 109 | 110 | def forward(self, x): 111 | 112 | out = self.conv1(x) 113 | out = self.bn1(out) 114 | out = self.relu1(out) 115 | 116 | out = self.conv2(out) 117 | out = self.bn2(out) 118 | out = self.relu2(out) 119 | 120 | out = self.conv3(out) 121 | out = self.bn3(out) 122 | 123 | if self.downsample is not None: 124 | x = self.downsample(x) 125 | 126 | out = self.add([out, x]) 127 | out = self.relu3(out) 128 | 129 | return out 130 | 131 | def relprop(self, R, alpha): 132 | out = self.relu3.relprop(R, alpha) 133 | 134 | out, x = self.add.relprop(out, alpha) 135 | 136 | if self.downsample is not None: 137 | x = self.downsample.relprop(x, alpha) 138 | 139 | out = self.bn3.relprop(out, alpha) 140 | out = self.conv3.relprop(out, alpha) 141 | 142 | out = self.relu2.relprop(out, alpha) 143 | out = self.bn2.relprop(out, alpha) 144 | out = self.conv2.relprop(out, alpha) 145 | 146 | out = self.relu1.relprop(out, alpha) 147 | out = self.bn1.relprop(out, alpha) 148 | x1 = self.conv1.relprop(out, alpha) 149 | 150 | return x1 + x 151 | 152 | 153 | class ResNet(nn.Module): 154 | 155 | def __init__(self, block, layers, num_classes=1000, long=False, zero_init_residual=False): 156 | super(ResNet, self).__init__() 157 | self.inplanes = 64 158 | self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 159 | self.bn1 = BatchNorm2d(64) 160 | self.relu = ReLU(inplace=True) 161 | self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1) 162 | self.layer1 = self._make_layer(block, 64, layers[0]) 163 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 164 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 165 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 166 | self.avgpool = AdaptiveAvgPool2d((1, 1)) 167 | self.fc = Linear(512 * block.expansion, num_classes) 168 | self.long = long 169 | self.num_classes = num_classes 170 | 171 | for m in self.modules(): 172 | if isinstance(m, nn.Conv2d): 173 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 174 | elif isinstance(m, nn.BatchNorm2d): 175 | nn.init.constant_(m.weight, 1) 176 | nn.init.constant_(m.bias, 0) 177 | 178 | # Zero-initialize the last BN in each residual branch, 179 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 180 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 181 | if zero_init_residual: 182 | for m in self.modules(): 183 | if isinstance(m, Bottleneck): 184 | nn.init.constant_(m.bn3.weight, 0) 185 | elif isinstance(m, BasicBlock): 186 | nn.init.constant_(m.bn2.weight, 0) 187 | 188 | def _make_layer(self, block, planes, blocks, stride=1): 189 | downsample = None 190 | if stride != 1 or self.inplanes != planes * block.expansion: 191 | downsample = Sequential( 192 | conv1x1(self.inplanes, planes * block.expansion, stride), 193 | BatchNorm2d(planes * block.expansion), 194 | ) 195 | 196 | layers = [] 197 | layers.append(block(self.inplanes, planes, stride, downsample)) 198 | self.inplanes = planes * block.expansion 199 | for _ in range(1, blocks): 200 | layers.append(block(self.inplanes, planes)) 201 | 202 | return Sequential(*layers) 203 | 204 | def CLRP(self, x): 205 | maxindex = torch.argmax(x, dim=1) 206 | R = torch.ones(x.shape, device=x.device) 207 | R /= -self.num_classes 208 | for i in range(R.size(0)): 209 | R[i, maxindex[i]] = 1 210 | return R 211 | 212 | def forward(self, img): 213 | x = self.conv1(img) 214 | x = self.bn1(x) 215 | x = self.relu(x) 216 | x = self.maxpool(x) 217 | layer1 = self.layer1(x) 218 | layer2 = self.layer2(layer1) 219 | layer3 = self.layer3(layer2) 220 | layer4 = self.layer4(layer3) 221 | 222 | x = self.avgpool(layer4) 223 | x = x.view(x.size(0), -1) 224 | return self.fc(x) 225 | 226 | def get_layer(self, img, layer_num): 227 | x = self.conv1(img) 228 | x = self.bn1(x) 229 | x = self.relu(x) 230 | x = self.maxpool(x) 231 | layer1 = self.layer1(x) 232 | if layer_num == 1: 233 | return layer1 234 | layer2 = self.layer2(layer1) 235 | if layer_num == 2: 236 | return layer2 237 | layer3 = self.layer3(layer2) 238 | if layer_num == 3: 239 | return layer3 240 | layer4 = self.layer4(layer3) 241 | if layer_num == 4 or layer_num == -1: 242 | return layer4 243 | if isinstance(layer_num, tuple): 244 | return [[layer1, layer2, layer3, layer4][i-1] for i in layer_num] 245 | 246 | raise ValueError(f"Unknown layer num: {layer_num}") 247 | 248 | def relevance_cam(self, large_img, layer_num, upsampler): 249 | small_img = F.interpolate(large_img, size=(224, 224), mode='bilinear') 250 | layer1, layer2, layer3, layer4 = self.get_layer(small_img, (1, 2, 3, 4)) 251 | x = self.avgpool(layer4) 252 | x = x.view(x.size(0), -1) 253 | z = self.fc(x) 254 | 255 | R = self.CLRP(z) 256 | R = self.fc.relprop(R, 1) 257 | R = R.reshape_as(self.avgpool.Y) 258 | R4 = self.avgpool.relprop(R, 1) 259 | 260 | if layer_num == 4: 261 | r_weight4 = torch.mean(R4, dim=(2, 3), keepdim=True) 262 | r_cam4 = upsampler(large_img, source=layer4) * r_weight4 263 | r_cam4 = torch.sum(r_cam4, dim=(1), keepdim=True) 264 | return r_cam4 265 | elif layer_num == 3: 266 | R3 = self.layer4.relprop(R4, 1) 267 | r_weight3 = torch.mean(R3, dim=(2, 3), keepdim=True) 268 | r_cam3 = upsampler(large_img, source=layer3) * r_weight3 269 | r_cam3 = torch.sum(r_cam3, dim=(1), keepdim=True) 270 | return r_cam3 271 | elif layer_num == 2: 272 | R3 = self.layer4.relprop(R4, 1) 273 | R2 = self.layer3.relprop(R3, 1) 274 | r_weight2 = torch.mean(R2, dim=(2, 3), keepdim=True) 275 | r_cam2 = upsampler(large_img, source=layer2) * r_weight2 276 | r_cam2 = torch.sum(r_cam2, dim=(1), keepdim=True) 277 | return r_cam2 278 | else: 279 | raise ValueError(f"Unknown layer_num: {layer_num}") 280 | 281 | 282 | def resnet18(pretrained=False, **kwargs): 283 | """Constructs a ResNet-18 model. 284 | 285 | Args: 286 | pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | """ 288 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 289 | if pretrained: 290 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 291 | return model 292 | 293 | 294 | def resnet34(pretrained=False, **kwargs): 295 | """Constructs a ResNet-34 model. 296 | 297 | Args: 298 | pretrained (bool): If True, returns a model pre-trained on ImageNet 299 | """ 300 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 301 | if pretrained: 302 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 303 | return model 304 | 305 | 306 | def resnet50(pretrained=False, long=False, **kwargs): 307 | """Constructs a ResNet-50 model. 308 | 309 | Args: 310 | pretrained (bool): If True, returns a model pre-trained on ImageNet 311 | """ 312 | model = ResNet(Bottleneck, [3, 4, 6, 3], long=long, **kwargs) 313 | if pretrained: 314 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 315 | return model 316 | 317 | 318 | def resnet101(pretrained=False, **kwargs): 319 | """Constructs a ResNet-101 model. 320 | 321 | Args: 322 | pretrained (bool): If True, returns a model pre-trained on ImageNet 323 | """ 324 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 325 | if pretrained: 326 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 327 | return model 328 | 329 | 330 | def resnet152(pretrained=False, **kwargs): 331 | """Constructs a ResNet-152 model. 332 | 333 | Args: 334 | pretrained (bool): If True, returns a model pre-trained on ImageNet 335 | """ 336 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 337 | if pretrained: 338 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 339 | return model 340 | -------------------------------------------------------------------------------- /featup/featurizers/modules/vgg.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.model_zoo as model_zoo 5 | import torch 6 | from featup.featurizers.modules.layers import * 7 | 8 | __all__ = [ 9 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 10 | 'vgg19_bn', 'vgg19', 11 | ] 12 | 13 | 14 | model_urls = { 15 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 16 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 17 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 18 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 19 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 20 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 21 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 22 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 23 | } 24 | 25 | class VGG_spread(nn.Module): 26 | 27 | def __init__(self, features, num_classes=1000, init_weights=True): 28 | super(VGG_spread, self).__init__() 29 | self.features = features 30 | self.avgpool = AdaptiveAvgPool2d((7, 7)) 31 | self.classifier = Sequential( 32 | Linear(512 * 7 * 7, 4096), 33 | ReLU(True), 34 | Dropout(), 35 | Linear(4096, 4096), 36 | ReLU(True), 37 | Dropout(), 38 | Linear(4096, num_classes), 39 | ) 40 | if init_weights: 41 | self._initialize_weights() 42 | 43 | def forward(self, x): 44 | for layer in self.features: 45 | x = layer(x) 46 | x = self.avgpool(x) 47 | x = x.view(x.size(0), -1) 48 | x = self.classifier(x) 49 | return x 50 | 51 | def relprop(self, R, alpha): 52 | x = self.classifier.relprop(R, alpha) 53 | x = x.reshape_as(next(reversed(self.features._modules.values())).Y) 54 | x = self.avgpool.relprop(x, alpha) 55 | x = self.features.relprop(x, alpha) 56 | return x 57 | 58 | def m_relprop(self, R, pred, alpha): 59 | x = self.classifier.m_relprop(R, pred, alpha) 60 | if torch.is_tensor(x) == False: 61 | for i in range(len(x)): 62 | x[i] = x[i].reshape_as(next(reversed(self.features._modules.values())).Y) 63 | else: 64 | x = x.reshape_as(next(reversed(self.features._modules.values())).Y) 65 | x = self.avgpool.m_relprop(x, pred, alpha) 66 | x = self.features.m_relprop(x, pred, alpha) 67 | return x 68 | 69 | def RAP_relprop(self, R): 70 | x1 = self.classifier.RAP_relprop(R) 71 | if torch.is_tensor(x1) == False: 72 | for i in range(len(x1)): 73 | x1[i] = x1[i].reshape_as(next(reversed(self.features._modules.values())).Y) 74 | else: 75 | x1 = x1.reshape_as(next(reversed(self.features._modules.values())).Y) 76 | x1 = self.avgpool.RAP_relprop(x1) 77 | x1 = self.features.RAP_relprop(x1) 78 | return x1 79 | 80 | def _initialize_weights(self): 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 84 | if m.bias is not None: 85 | nn.init.constant_(m.bias, 0) 86 | elif isinstance(m, nn.BatchNorm2d): 87 | nn.init.constant_(m.weight, 1) 88 | nn.init.constant_(m.bias, 0) 89 | elif isinstance(m, nn.Linear): 90 | nn.init.normal_(m.weight, 0, 0.01) 91 | nn.init.constant_(m.bias, 0) 92 | 93 | 94 | class VGG(nn.Module): 95 | 96 | def __init__(self, features, num_classes=1000, init_weights=True): 97 | super(VGG, self).__init__() 98 | self.features = features 99 | self.avgpool = AdaptiveAvgPool2d((7, 7)) 100 | self.classifier = Sequential( 101 | Linear(512 * 7 * 7, 4096), 102 | ReLU(True), 103 | Dropout(), 104 | Linear(4096, 4096), 105 | ReLU(True), 106 | Dropout(), 107 | Linear(4096, num_classes), 108 | ) 109 | self.num_classes = num_classes 110 | if init_weights: 111 | self._initialize_weights() 112 | 113 | def CLRP(self, x, maxindex = [None]): 114 | if maxindex == [None]: 115 | maxindex = torch.argmax(x, dim=1) 116 | R = torch.ones(x.shape, x.device) 117 | R /= -self.num_classes 118 | for i in range(R.size(0)): 119 | R[i, maxindex[i]] = 1 120 | return R 121 | 122 | def upsample(self, source, guidance_unscaled, upsampler, scale): 123 | _, _, H, W = source.shape 124 | guidance = F.interpolate(guidance_unscaled, size=(H * scale, W * scale), mode='bilinear') 125 | return upsampler(source, guidance) 126 | 127 | def forward(self, x,mode='output', target_class = [None], upsampler=None, scale=1): 128 | inp = copy.deepcopy(x) 129 | for i, layer in enumerate(self.features): 130 | x = layer(x) 131 | if mode.lstrip('-').isnumeric(): 132 | if int(mode) == i: 133 | target_layer = x 134 | 135 | x = self.avgpool(x) 136 | x = x.view(x.size(0), -1) 137 | x = self.classifier(x) 138 | 139 | if mode == 'output': 140 | return x 141 | 142 | R = self.CLRP(x, target_class) 143 | R = self.classifier.relprop(R) 144 | R = R.reshape_as(next(reversed(self.features._modules.values())).Y) 145 | R = self.avgpool.relprop(R) 146 | 147 | for i in range(len(self.features)-1, int(mode), -1): 148 | R = self.features[i].relprop(R) 149 | 150 | if upsampler is not None: 151 | target_layer = self.upsample(target_layer, inp, upsampler, scale) 152 | 153 | r_weight = torch.mean(R, dim=(2, 3), keepdim=True) 154 | r_cam = target_layer * r_weight 155 | r_cam = torch.sum(r_cam, dim=(1), keepdim=True) 156 | return r_cam, x 157 | 158 | 159 | 160 | def relprop(self, R, alpha, flag=-1): 161 | x = self.classifier.relprop(R, alpha) 162 | x = x.reshape_as(next(reversed(self.features._modules.values())).Y) 163 | x = self.avgpool.relprop(x, alpha) 164 | # x = self.features.relprop(x, alpha) 165 | for i in range(43, flag, -1): 166 | x = self.features[i].relprop(x, alpha) 167 | return x 168 | 169 | def m_relprop(self, R, pred, alpha): 170 | x = self.classifier.m_relprop(R, pred, alpha) 171 | if torch.is_tensor(x) == False: 172 | for i in range(len(x)): 173 | x[i] = x[i].reshape_as(next(reversed(self.features._modules.values())).Y) 174 | else: 175 | x = x.reshape_as(next(reversed(self.features._modules.values())).Y) 176 | x = self.avgpool.m_relprop(x, pred, alpha) 177 | x = self.features.m_relprop(x, pred, alpha) 178 | return x 179 | 180 | def RAP_relprop(self, R): 181 | x1 = self.classifier.RAP_relprop(R) 182 | if torch.is_tensor(x1) == False: 183 | for i in range(len(x1)): 184 | x1[i] = x1[i].reshape_as(next(reversed(self.features._modules.values())).Y) 185 | else: 186 | x1 = x1.reshape_as(next(reversed(self.features._modules.values())).Y) 187 | x1 = self.avgpool.RAP_relprop(x1) 188 | x1 = self.features.RAP_relprop(x1) 189 | 190 | return x1 191 | def _initialize_weights(self): 192 | for m in self.modules(): 193 | if isinstance(m, nn.Conv2d): 194 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 195 | if m.bias is not None: 196 | nn.init.constant_(m.bias, 0) 197 | elif isinstance(m, nn.BatchNorm2d): 198 | nn.init.constant_(m.weight, 1) 199 | nn.init.constant_(m.bias, 0) 200 | elif isinstance(m, nn.Linear): 201 | nn.init.normal_(m.weight, 0, 0.01) 202 | nn.init.constant_(m.bias, 0) 203 | 204 | def make_layers(cfg, batch_norm=False): 205 | layers = [] 206 | in_channels = 3 207 | 208 | for v in cfg: 209 | if v == 'M': 210 | layers += [MaxPool2d(kernel_size=2, stride=2)] 211 | else: 212 | conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1) 213 | if batch_norm: 214 | layers += [conv2d, BatchNorm2d(v), ReLU(inplace=True)] 215 | else: 216 | layers += [conv2d, ReLU(inplace=True)] 217 | in_channels = v 218 | 219 | return Sequential(*layers) 220 | 221 | def make_layers_list(cfg, batch_norm=False): 222 | layers = [] 223 | in_channels = 3 224 | for v in cfg: 225 | if v == 'M': 226 | layers += [MaxPool2d(kernel_size=2, stride=2)] 227 | else: 228 | conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1) 229 | if batch_norm: 230 | layers += [conv2d, BatchNorm2d(v), ReLU(inplace=True)] 231 | else: 232 | layers += [conv2d, ReLU(inplace=True)] 233 | in_channels = v 234 | return layers 235 | 236 | 237 | cfg = { 238 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 239 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 240 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 241 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 242 | } 243 | 244 | 245 | def vgg11(pretrained=False, **kwargs): 246 | """VGG 11-layer model (configuration "A") 247 | 248 | Args: 249 | pretrained (bool): If True, returns a model pre-trained on ImageNet 250 | """ 251 | if pretrained: 252 | kwargs['init_weights'] = False 253 | model = VGG(make_layers(cfg['A']), **kwargs) 254 | if pretrained: 255 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 256 | return model 257 | 258 | 259 | def vgg11_bn(pretrained=False, **kwargs): 260 | """VGG 11-layer model (configuration "A") with batch normalization 261 | 262 | Args: 263 | pretrained (bool): If True, returns a model pre-trained on ImageNet 264 | """ 265 | if pretrained: 266 | kwargs['init_weights'] = False 267 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 268 | if pretrained: 269 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 270 | return model 271 | 272 | 273 | def vgg13(pretrained=False, **kwargs): 274 | """VGG 13-layer model (configuration "B") 275 | 276 | Args: 277 | pretrained (bool): If True, returns a model pre-trained on ImageNet 278 | """ 279 | if pretrained: 280 | kwargs['init_weights'] = False 281 | model = VGG(make_layers(cfg['B']), **kwargs) 282 | if pretrained: 283 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 284 | return model 285 | 286 | 287 | def vgg13_bn(pretrained=False, **kwargs): 288 | """VGG 13-layer model (configuration "B") with batch normalization 289 | 290 | Args: 291 | pretrained (bool): If True, returns a model pre-trained on ImageNet 292 | """ 293 | if pretrained: 294 | kwargs['init_weights'] = False 295 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 296 | if pretrained: 297 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 298 | return model 299 | 300 | 301 | def vgg16(pretrained=False, **kwargs): 302 | """VGG 16-layer model (configuration "D") 303 | 304 | Args: 305 | pretrained (bool): If True, returns a model pre-trained on ImageNet 306 | """ 307 | if pretrained: 308 | kwargs['init_weights'] = False 309 | model = VGG(make_layers(cfg['D']), **kwargs) 310 | if pretrained: 311 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 312 | return model 313 | 314 | def vgg16_spread(pretrained=False, **kwargs): 315 | """VGG 16-layer model (configuration "D") 316 | 317 | Args: 318 | pretrained (bool): If True, returns a model pre-trained on ImageNet 319 | """ 320 | if pretrained: 321 | kwargs['init_weights'] = False 322 | model = VGG_spread(make_layers_list(cfg['D']), **kwargs) 323 | if pretrained: 324 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 325 | return model 326 | 327 | def vgg16_bn(pretrained=False, **kwargs): 328 | """VGG 16-layer model (configuration "D") with batch normalization 329 | 330 | Args: 331 | pretrained (bool): If True, returns a model pre-trained on ImageNet 332 | """ 333 | if pretrained: 334 | kwargs['init_weights'] = False 335 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 336 | if pretrained: 337 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 338 | return model 339 | 340 | 341 | def vgg19(pretrained=False, **kwargs): 342 | """VGG 19-layer model (configuration "E") 343 | 344 | Args: 345 | pretrained (bool): If True, returns a model pre-trained on ImageNet 346 | """ 347 | if pretrained: 348 | kwargs['init_weights'] = False 349 | model = VGG(make_layers(cfg['E']), **kwargs) 350 | if pretrained: 351 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 352 | return model 353 | 354 | 355 | def vgg19_bn(pretrained=False, **kwargs): 356 | """VGG 19-layer model (configuration 'E') with batch normalization 357 | 358 | Args: 359 | pretrained (bool): If True, returns a model pre-trained on ImageNet 360 | """ 361 | if pretrained: 362 | kwargs['init_weights'] = False 363 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 364 | if pretrained: 365 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 366 | return model 367 | -------------------------------------------------------------------------------- /featup/featurizers/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def get_featurizer(name, activation_type="key", **kwargs): 4 | name = name.lower() 5 | if name == "vit": 6 | from .DINO import DINOFeaturizer 7 | patch_size = 16 8 | model = DINOFeaturizer("vit_small_patch16_224", patch_size, activation_type) 9 | dim = 384 10 | elif name == "midas": 11 | from .MIDAS import MIDASFeaturizer 12 | patch_size = 16 13 | model = MIDASFeaturizer(output_root=kwargs["output_root"]) 14 | dim = 768 15 | elif name == "dino16": 16 | from .DINO import DINOFeaturizer 17 | patch_size = 16 18 | model = DINOFeaturizer("dino_vits16", patch_size, activation_type) 19 | dim = 384 20 | elif name == "dino8": 21 | from .DINO import DINOFeaturizer 22 | patch_size = 8 23 | model = DINOFeaturizer("dino_vits8", patch_size, activation_type) 24 | dim = 384 25 | elif name == "dinov2": 26 | from .DINOv2 import DINOv2Featurizer 27 | patch_size = 14 28 | model = DINOv2Featurizer("dinov2_vits14", patch_size, activation_type) 29 | dim = 384 30 | elif name == "clip": 31 | from .CLIP import CLIPFeaturizer 32 | patch_size = 16 33 | model = CLIPFeaturizer() 34 | dim = 512 35 | elif name == "maskclip": 36 | from .MaskCLIP import MaskCLIPFeaturizer 37 | patch_size = 16 38 | model = MaskCLIPFeaturizer() 39 | dim = 512 40 | elif name == "mae": 41 | from .MAE import MAEFeaturizer 42 | patch_size = 16 43 | model = MAEFeaturizer(**kwargs) 44 | dim = 1024 45 | elif name == "mocov3": 46 | from .MOCOv3 import MOCOv3Featurizer 47 | patch_size = 16 48 | model = MOCOv3Featurizer() 49 | dim = 384 50 | elif name == "msn": 51 | from .MSN import MSNFeaturizer 52 | patch_size = 16 53 | model = MSNFeaturizer() 54 | dim = 384 55 | elif name == "pixels": 56 | patch_size = 1 57 | model = lambda x: x 58 | dim = 3 59 | elif name == "resnet50": 60 | from .modules.resnet import resnet50 61 | from .ResNet import ResNetFeaturizer 62 | model = ResNetFeaturizer(resnet50(pretrained=True)) 63 | patch_size = 1 64 | dim = 2048 65 | elif name == "deeplab": 66 | from .DeepLabV3 import DeepLabV3Featurizer 67 | model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True) 68 | model = DeepLabV3Featurizer(model) 69 | patch_size = 1 70 | dim = 2048 71 | else: 72 | raise ValueError("unknown model: {}".format(name)) 73 | return model, patch_size, dim 74 | -------------------------------------------------------------------------------- /featup/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def id_conv(dim, strength=.9): 5 | conv = torch.nn.Conv2d(dim, dim, 1, padding="same") 6 | start_w = conv.weight.data 7 | conv.weight.data = torch.nn.Parameter( 8 | torch.eye(dim, device=start_w.device).unsqueeze(-1).unsqueeze(-1) * strength + start_w * (1 - strength)) 9 | conv.bias.data = torch.nn.Parameter(conv.bias.data * (1 - strength)) 10 | return conv 11 | 12 | 13 | class ImplicitFeaturizer(torch.nn.Module): 14 | 15 | def __init__(self, color_feats=True, n_freqs=10, learn_bias=False, time_feats=False, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.color_feats = color_feats 18 | self.time_feats = time_feats 19 | self.n_freqs = n_freqs 20 | self.learn_bias = learn_bias 21 | 22 | self.dim_multiplier = 2 23 | 24 | if self.color_feats: 25 | self.dim_multiplier += 3 26 | 27 | if self.time_feats: 28 | self.dim_multiplier += 1 29 | 30 | if self.learn_bias: 31 | self.biases = torch.nn.Parameter(torch.randn(2, self.dim_multiplier, n_freqs).to(torch.float32)) 32 | 33 | def forward(self, original_image): 34 | b, c, h, w = original_image.shape 35 | grid_h = torch.linspace(-1, 1, h, device=original_image.device) 36 | grid_w = torch.linspace(-1, 1, w, device=original_image.device) 37 | feats = torch.cat([t.unsqueeze(0) for t in torch.meshgrid([grid_h, grid_w])]).unsqueeze(0) 38 | feats = torch.broadcast_to(feats, (b, feats.shape[1], h, w)) 39 | 40 | if self.color_feats: 41 | feat_list = [feats, original_image] 42 | else: 43 | feat_list = [feats] 44 | 45 | feats = torch.cat(feat_list, dim=1).unsqueeze(1) 46 | freqs = torch.exp(torch.linspace(-2, 10, self.n_freqs, device=original_image.device)) \ 47 | .reshape(1, self.n_freqs, 1, 1, 1) 48 | feats = (feats * freqs) 49 | 50 | if self.learn_bias: 51 | sin_feats = feats + self.biases[0].reshape(1, self.n_freqs, self.dim_multiplier, 1, 1) 52 | cos_feats = feats + self.biases[1].reshape(1, self.n_freqs, self.dim_multiplier, 1, 1) 53 | else: 54 | sin_feats = feats 55 | cos_feats = feats 56 | 57 | sin_feats = sin_feats.reshape(b, self.n_freqs * self.dim_multiplier, h, w) 58 | cos_feats = cos_feats.reshape(b, self.n_freqs * self.dim_multiplier, h, w) 59 | 60 | if self.color_feats: 61 | all_feats = [torch.sin(sin_feats), torch.cos(cos_feats), original_image] 62 | else: 63 | all_feats = [torch.sin(sin_feats), torch.cos(cos_feats)] 64 | 65 | return torch.cat(all_feats, dim=1) 66 | 67 | 68 | class MinMaxScaler(torch.nn.Module): 69 | 70 | def __init__(self): 71 | super().__init__() 72 | 73 | def forward(self, x): 74 | c = x.shape[1] 75 | flat_x = x.permute(1, 0, 2, 3).reshape(c, -1) 76 | flat_x_min = flat_x.min(dim=-1).values.reshape(1, c, 1, 1) 77 | flat_x_scale = flat_x.max(dim=-1).values.reshape(1, c, 1, 1) - flat_x_min 78 | return ((x - flat_x_min) / flat_x_scale.clamp_min(0.0001)) - .5 79 | 80 | 81 | class ChannelNorm(torch.nn.Module): 82 | 83 | def __init__(self, dim, *args, **kwargs): 84 | super().__init__(*args, **kwargs) 85 | self.norm = torch.nn.LayerNorm(dim) 86 | 87 | def forward(self, x): 88 | new_x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 89 | return new_x 90 | -------------------------------------------------------------------------------- /featup/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def entropy(t): 6 | return -(t * torch.log(t.clamp_min(.0000001))).sum(dim=[-1, -2, -3]).mean() 7 | 8 | 9 | def total_variation(img): 10 | b, c, h, w = img.size() 11 | return ((img[:, :, 1:, :] - img[:, :, :-1, :]).square().sum() + 12 | (img[:, :, :, 1:] - img[:, :, :, :-1]).square().sum()) / (b * c * h * w) 13 | 14 | 15 | class SampledCRFLoss(torch.nn.Module): 16 | 17 | def __init__(self, n_samples, alpha, beta, gamma, w1, w2, shift): 18 | super(SampledCRFLoss, self).__init__() 19 | self.alpha = alpha 20 | self.beta = beta 21 | self.gamma = gamma 22 | self.w1 = w1 23 | self.w2 = w2 24 | self.n_samples = n_samples 25 | self.shift = shift 26 | 27 | def forward(self, guidance, features): 28 | device = features.device 29 | assert (guidance.shape[0] == features.shape[0]) 30 | assert (guidance.shape[2:] == features.shape[2:]) 31 | h = guidance.shape[2] 32 | w = guidance.shape[3] 33 | 34 | coords = torch.cat([ 35 | torch.randint(0, h, size=[1, self.n_samples], device=device), 36 | torch.randint(0, w, size=[1, self.n_samples], device=device)], 0) 37 | norm_coords = coords / torch.tensor([h, w], device=guidance.device).unsqueeze(-1) 38 | 39 | selected_guidance = guidance[:, :, coords[0, :], coords[1, :]] 40 | 41 | coord_diff = (norm_coords.unsqueeze(-1) - norm_coords.unsqueeze(-2)).square().sum(0).unsqueeze(0) 42 | guidance_diff = (selected_guidance.unsqueeze(-1) - selected_guidance.unsqueeze(-2)).square().sum(1) 43 | 44 | sim_kernel = self.w1 * torch.exp(- coord_diff / (2 * self.alpha) - guidance_diff / (2 * self.beta)) + \ 45 | self.w2 * torch.exp(- coord_diff / (2 * self.gamma)) - self.shift 46 | 47 | # selected_clusters = F.normalize(features[:, :, coords[0, :], coords[1, :]], dim=1) 48 | # cluster_sims = torch.einsum("bcn,bcm->bnm", selected_clusters, selected_clusters) 49 | selected_feats = features[:, :, coords[0, :], coords[1, :]] 50 | feat_diff = (selected_feats.unsqueeze(-1) - selected_feats.unsqueeze(-2)).square().sum(1) 51 | 52 | return (feat_diff * sim_kernel).mean() 53 | 54 | 55 | class TVLoss(torch.nn.Module): 56 | 57 | def __init__(self): 58 | super(TVLoss, self).__init__() 59 | 60 | def forward(self, img): 61 | b, c, h, w = img.size() 62 | return ((img[:, :, 1:, :] - img[:, :, :-1, :]).square().sum() + 63 | (img[:, :, :, 1:] - img[:, :, :, :-1]).square().sum()) / (b * c * h * w) 64 | 65 | 66 | def compute_scale_and_shift(prediction, target, mask): 67 | # system matrix: A = [[a_00, a_01], [a_10, a_11]] 68 | a_00 = torch.sum(mask * prediction * prediction, (1, 2)) 69 | a_01 = torch.sum(mask * prediction, (1, 2)) 70 | a_11 = torch.sum(mask, (1, 2)) 71 | 72 | # right hand side: b = [b_0, b_1] 73 | b_0 = torch.sum(mask * prediction * target, (1, 2)) 74 | b_1 = torch.sum(mask * target, (1, 2)) 75 | 76 | # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b 77 | x_0 = torch.zeros_like(b_0) 78 | x_1 = torch.zeros_like(b_1) 79 | 80 | det = a_00 * a_11 - a_01 * a_01 81 | valid = det.nonzero() 82 | 83 | x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] 84 | x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] 85 | 86 | return x_0, x_1 87 | 88 | 89 | def reduction_batch_based(image_loss, M): 90 | # average of all valid pixels of the batch 91 | 92 | # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0) 93 | divisor = torch.sum(M) 94 | 95 | if divisor == 0: 96 | return 0 97 | else: 98 | return torch.sum(image_loss) / divisor 99 | 100 | 101 | def reduction_image_based(image_loss, M): 102 | # mean of average of valid pixels of an image 103 | 104 | # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0) 105 | valid = M.nonzero() 106 | 107 | image_loss[valid] = image_loss[valid] / M[valid] 108 | 109 | return torch.mean(image_loss) 110 | 111 | 112 | def mse_loss(prediction, target, mask, reduction=reduction_batch_based): 113 | M = torch.sum(mask, (1, 2)) 114 | res = prediction - target 115 | image_loss = torch.sum(mask * res * res, (1, 2)) 116 | 117 | return reduction(image_loss, 2 * M) 118 | 119 | 120 | def gradient_loss(prediction, target, mask, reduction=reduction_batch_based): 121 | M = torch.sum(mask, (1, 2)) 122 | 123 | diff = prediction - target 124 | diff = torch.mul(mask, diff) 125 | 126 | grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) 127 | mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) 128 | grad_x = torch.mul(mask_x, grad_x) 129 | 130 | grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) 131 | mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) 132 | grad_y = torch.mul(mask_y, grad_y) 133 | 134 | image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2)) 135 | 136 | return reduction(image_loss, M) 137 | 138 | 139 | class MSELoss(nn.Module): 140 | def __init__(self, reduction='batch-based'): 141 | super().__init__() 142 | 143 | if reduction == 'batch-based': 144 | self.__reduction = reduction_batch_based 145 | else: 146 | self.__reduction = reduction_image_based 147 | 148 | def forward(self, prediction, target, mask): 149 | return mse_loss(prediction, target, mask, reduction=self.__reduction) 150 | 151 | 152 | class GradientLoss(nn.Module): 153 | def __init__(self, scales=4, reduction='batch-based'): 154 | super().__init__() 155 | 156 | if reduction == 'batch-based': 157 | self.__reduction = reduction_batch_based 158 | else: 159 | self.__reduction = reduction_image_based 160 | 161 | self.__scales = scales 162 | 163 | def forward(self, prediction, target, mask): 164 | total = 0 165 | 166 | for scale in range(self.__scales): 167 | step = pow(2, scale) 168 | 169 | total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step], 170 | mask[:, ::step, ::step], reduction=self.__reduction) 171 | 172 | return total 173 | 174 | 175 | class ScaleAndShiftInvariantLoss(nn.Module): 176 | def __init__(self, alpha=0.5, scales=4, reduction='batch-based'): 177 | super().__init__() 178 | 179 | self.__data_loss = MSELoss(reduction=reduction) 180 | self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction) 181 | self.__alpha = alpha 182 | 183 | self.__prediction_ssi = None 184 | 185 | def forward(self, prediction, target, mask): 186 | scale, shift = compute_scale_and_shift(prediction, target, mask) 187 | self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) 188 | 189 | total = self.__data_loss(self.__prediction_ssi, target, mask) 190 | if self.__alpha > 0: 191 | total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask) 192 | 193 | return total 194 | 195 | def __get_prediction_ssi(self): 196 | return self.__prediction_ssi 197 | 198 | prediction_ssi = property(__get_prediction_ssi) 199 | -------------------------------------------------------------------------------- /featup/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from featup.util import pca, remove_axes 3 | from featup.featurizers.maskclip.clip import tokenize 4 | from pytorch_lightning import seed_everything 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | @torch.no_grad() 10 | def plot_feats(image, lr, hr): 11 | assert len(image.shape) == len(lr.shape) == len(hr.shape) == 3 12 | seed_everything(0) 13 | [lr_feats_pca, hr_feats_pca], _ = pca([lr.unsqueeze(0), hr.unsqueeze(0)]) 14 | fig, ax = plt.subplots(1, 3, figsize=(15, 5)) 15 | ax[0].imshow(image.permute(1, 2, 0).detach().cpu()) 16 | ax[0].set_title("Image") 17 | ax[1].imshow(lr_feats_pca[0].permute(1, 2, 0).detach().cpu()) 18 | ax[1].set_title("Original Features") 19 | ax[2].imshow(hr_feats_pca[0].permute(1, 2, 0).detach().cpu()) 20 | ax[2].set_title("Upsampled Features") 21 | remove_axes(ax) 22 | plt.show() 23 | 24 | 25 | @torch.no_grad() 26 | def plot_lang_heatmaps(model, image, lr_feats, hr_feats, text_query): 27 | assert len(image.shape) == len(lr_feats.shape) == len(hr_feats.shape) == 3 28 | fig, ax = plt.subplots(1, 3, figsize=(15, 5)) 29 | cmap = plt.get_cmap("turbo") 30 | 31 | # encode query 32 | text = tokenize(text_query).to(lr_feats.device) 33 | text_feats = model.model.encode_text(text).squeeze().to(torch.float32) 34 | assert len(text_feats.shape) == 1 35 | 36 | lr_sims = torch.einsum( 37 | "chw,c->hw", F.normalize(lr_feats.to(torch.float32), dim=0), F.normalize(text_feats, dim=0)) 38 | hr_sims = torch.einsum( 39 | "chw,c->hw", F.normalize(hr_feats.to(torch.float32), dim=0), F.normalize(text_feats, dim=0)) 40 | 41 | lr_sims_norm = (lr_sims - lr_sims.min()) / (lr_sims.max() - lr_sims.min()) 42 | hr_sims_norm = (hr_sims - hr_sims.min()) / (hr_sims.max() - hr_sims.min()) 43 | lr_heatmap = cmap(lr_sims_norm.cpu().numpy()) 44 | hr_heatmap = cmap(hr_sims_norm.cpu().numpy()) 45 | 46 | ax[0].imshow(image.permute(1, 2, 0).detach().cpu()) 47 | ax[0].set_title("Image") 48 | ax[1].imshow(lr_heatmap) 49 | ax[1].set_title(f"Original Similarity to \"{text_query}\"") 50 | ax[2].imshow(hr_heatmap) 51 | ax[2].set_title(f"Upsampled Similarity to \"{text_query}\"") 52 | remove_axes(ax) 53 | 54 | return plt.show() 55 | -------------------------------------------------------------------------------- /featup/train_jbu_upsampler.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | 4 | import hydra 5 | import pytorch_lightning as pl 6 | import torch 7 | import torchvision.transforms as T 8 | from omegaconf import DictConfig 9 | from omegaconf import OmegaConf 10 | from pytorch_lightning import Trainer 11 | from pytorch_lightning import seed_everything 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from pytorch_lightning.loggers import TensorBoardLogger 14 | from torch.utils.data import DataLoader 15 | from torchvision.transforms import InterpolationMode 16 | from os.path import join 17 | 18 | from featup.datasets.JitteredImage import apply_jitter, sample_transform 19 | from featup.datasets.util import get_dataset, SingleImageDataset 20 | from featup.downsamplers import SimpleDownsampler, AttentionDownsampler 21 | from featup.featurizers.util import get_featurizer 22 | from featup.layers import ChannelNorm 23 | from featup.losses import TVLoss, SampledCRFLoss, entropy 24 | from featup.upsamplers import get_upsampler 25 | from featup.util import pca, RollingAvg, unnorm, norm, prep_image 26 | 27 | torch.multiprocessing.set_sharing_strategy('file_system') 28 | 29 | 30 | class ScaleNet(torch.nn.Module): 31 | 32 | def __init__(self, dim): 33 | super().__init__() 34 | self.dim = dim 35 | self.net = torch.nn.Conv2d(dim, 1, 1) 36 | with torch.no_grad(): 37 | self.net.weight.copy_(self.net.weight * .1) 38 | self.net.bias.copy_(self.net.bias * .1) 39 | 40 | def forward(self, x): 41 | return torch.exp(self.net(x) + .1).clamp_min(.0001) 42 | 43 | 44 | class JBUFeatUp(pl.LightningModule): 45 | def __init__(self, 46 | model_type, 47 | activation_type, 48 | n_jitters, 49 | max_pad, 50 | max_zoom, 51 | kernel_size, 52 | final_size, 53 | lr, 54 | random_projection, 55 | predicted_uncertainty, 56 | crf_weight, 57 | filter_ent_weight, 58 | tv_weight, 59 | upsampler, 60 | downsampler, 61 | chkpt_dir, 62 | ): 63 | super().__init__() 64 | self.model_type = model_type 65 | self.activation_type = activation_type 66 | self.n_jitters = n_jitters 67 | self.max_pad = max_pad 68 | self.max_zoom = max_zoom 69 | self.kernel_size = kernel_size 70 | self.final_size = final_size 71 | self.lr = lr 72 | self.random_projection = random_projection 73 | self.predicted_uncertainty = predicted_uncertainty 74 | self.crf_weight = crf_weight 75 | self.filter_ent_weight = filter_ent_weight 76 | self.tv_weight = tv_weight 77 | self.chkpt_dir = chkpt_dir 78 | 79 | self.model, self.patch_size, self.dim = get_featurizer(model_type, activation_type, num_classes=1000) 80 | for p in self.model.parameters(): 81 | p.requires_grad = False 82 | self.model = torch.nn.Sequential(self.model, ChannelNorm(self.dim)) 83 | self.upsampler = get_upsampler(upsampler, self.dim) 84 | 85 | if downsampler == 'simple': 86 | self.downsampler = SimpleDownsampler(self.kernel_size, self.final_size) 87 | elif downsampler == 'attention': 88 | self.downsampler = AttentionDownsampler(self.dim, self.kernel_size, self.final_size, blur_attn=True) 89 | else: 90 | raise ValueError(f"Unknown downsampler {downsampler}") 91 | 92 | if self.predicted_uncertainty: 93 | self.scale_net = ScaleNet(self.dim) 94 | 95 | self.avg = RollingAvg(20) 96 | 97 | self.crf = SampledCRFLoss( 98 | alpha=.1, 99 | beta=.15, 100 | gamma=.005, 101 | w1=10.0, 102 | w2=3.0, 103 | shift=0.00, 104 | n_samples=1000) 105 | self.tv = TVLoss() 106 | 107 | self.automatic_optimization = False 108 | 109 | def forward(self, x): 110 | return self.upsampler(self.model(x)) 111 | 112 | def project(self, feats, proj): 113 | if proj is None: 114 | return feats 115 | else: 116 | return torch.einsum("bchw,bcd->bdhw", feats, proj) 117 | 118 | def training_step(self, batch, batch_idx): 119 | opt = self.optimizers() 120 | opt.zero_grad() 121 | 122 | with torch.no_grad(): 123 | if type(batch) == dict: 124 | img = batch['img'] 125 | else: 126 | img, _ = batch 127 | lr_feats = self.model(img) 128 | 129 | full_rec_loss = 0.0 130 | full_crf_loss = 0.0 131 | full_entropy_loss = 0.0 132 | full_tv_loss = 0.0 133 | full_total_loss = 0.0 134 | for i in range(self.n_jitters): 135 | hr_feats = self.upsampler(lr_feats, img) 136 | 137 | if hr_feats.shape[2] != img.shape[2]: 138 | hr_feats = torch.nn.functional.interpolate(hr_feats, img.shape[2:], mode="bilinear") 139 | 140 | with torch.no_grad(): 141 | transform_params = sample_transform( 142 | True, self.max_pad, self.max_zoom, img.shape[2], img.shape[3]) 143 | jit_img = apply_jitter(img, self.max_pad, transform_params) 144 | lr_jit_feats = self.model(jit_img) 145 | 146 | if self.random_projection is not None: 147 | proj = torch.randn(lr_feats.shape[0], 148 | lr_feats.shape[1], 149 | self.random_projection, device=lr_feats.device) 150 | proj /= proj.square().sum(1, keepdim=True).sqrt() 151 | else: 152 | proj = None 153 | 154 | hr_jit_feats = apply_jitter(hr_feats, self.max_pad, transform_params) 155 | proj_hr_feats = self.project(hr_jit_feats, proj) 156 | 157 | down_jit_feats = self.project(self.downsampler(hr_jit_feats, jit_img), proj) 158 | 159 | if self.predicted_uncertainty: 160 | scales = self.scale_net(lr_jit_feats) 161 | scale_factor = (1 / (2 * scales ** 2)) 162 | mse = (down_jit_feats - self.project(lr_jit_feats, proj)).square() 163 | rec_loss = (scale_factor * mse + scales.log()).mean() / self.n_jitters 164 | else: 165 | rec_loss = (self.project(lr_jit_feats, proj) - down_jit_feats).square().mean() / self.n_jitters 166 | 167 | full_rec_loss += rec_loss.item() 168 | 169 | if self.crf_weight > 0 and i == 0: 170 | crf_loss = self.crf(img, proj_hr_feats) 171 | full_crf_loss += crf_loss.item() 172 | else: 173 | crf_loss = 0.0 174 | 175 | if self.filter_ent_weight > 0.0: 176 | entropy_loss = entropy(self.downsampler.get_kernel()) 177 | full_entropy_loss += entropy_loss.item() 178 | else: 179 | entropy_loss = 0 180 | 181 | if self.tv_weight > 0 and i == 0: 182 | tv_loss = self.tv(proj_hr_feats.square().sum(1, keepdim=True)) 183 | full_tv_loss += tv_loss.item() 184 | else: 185 | tv_loss = 0.0 186 | 187 | loss = rec_loss + self.crf_weight * crf_loss + self.tv_weight * tv_loss - self.filter_ent_weight * entropy_loss 188 | full_total_loss += loss.item() 189 | self.manual_backward(loss) 190 | 191 | self.avg.add("loss/crf", full_crf_loss) 192 | self.avg.add("loss/ent", full_entropy_loss) 193 | self.avg.add("loss/tv", full_tv_loss) 194 | self.avg.add("loss/rec", full_rec_loss) 195 | self.avg.add("loss/total", full_total_loss) 196 | 197 | if self.global_step % 100 == 0: 198 | self.trainer.save_checkpoint(self.chkpt_dir[:-5] + '/' + self.chkpt_dir[:-5] + f'_{self.global_step}.ckpt') 199 | 200 | self.avg.logall(self.log) 201 | if self.global_step < 10: 202 | self.clip_gradients(opt, gradient_clip_val=.0001, gradient_clip_algorithm="norm") 203 | 204 | opt.step() 205 | 206 | return None 207 | 208 | def validation_step(self, batch, batch_idx): 209 | with torch.no_grad(): 210 | if self.trainer.is_global_zero and batch_idx == 0: 211 | 212 | if type(batch) == dict: 213 | img = batch['img'] 214 | else: 215 | img, _ = batch 216 | lr_feats = self.model(img) 217 | 218 | hr_feats = self.upsampler(lr_feats, img) 219 | 220 | if hr_feats.shape[2] != img.shape[2]: 221 | hr_feats = torch.nn.functional.interpolate(hr_feats, img.shape[2:], mode="bilinear") 222 | 223 | transform_params = sample_transform( 224 | True, self.max_pad, self.max_zoom, img.shape[2], img.shape[3]) 225 | jit_img = apply_jitter(img, self.max_pad, transform_params) 226 | lr_jit_feats = self.model(jit_img) 227 | 228 | if self.random_projection is not None: 229 | proj = torch.randn(lr_feats.shape[0], 230 | lr_feats.shape[1], 231 | self.random_projection, device=lr_feats.device) 232 | proj /= proj.square().sum(1, keepdim=True).sqrt() 233 | else: 234 | proj = None 235 | 236 | scales = self.scale_net(lr_jit_feats) 237 | 238 | writer = self.logger.experiment 239 | 240 | hr_jit_feats = apply_jitter(hr_feats, self.max_pad, transform_params) 241 | down_jit_feats = self.downsampler(hr_jit_feats, jit_img) 242 | 243 | [red_lr_feats], fit_pca = pca([lr_feats[0].unsqueeze(0)]) 244 | [red_hr_feats], _ = pca([hr_feats[0].unsqueeze(0)], fit_pca=fit_pca) 245 | [red_lr_jit_feats], _ = pca([lr_jit_feats[0].unsqueeze(0)], fit_pca=fit_pca) 246 | [red_hr_jit_feats], _ = pca([hr_jit_feats[0].unsqueeze(0)], fit_pca=fit_pca) 247 | [red_down_jit_feats], _ = pca([down_jit_feats[0].unsqueeze(0)], fit_pca=fit_pca) 248 | 249 | writer.add_image("viz/image", unnorm(img[0].unsqueeze(0))[0], self.global_step) 250 | writer.add_image("viz/lr_feats", red_lr_feats[0], self.global_step) 251 | writer.add_image("viz/hr_feats", red_hr_feats[0], self.global_step) 252 | writer.add_image("jit_viz/jit_image", unnorm(jit_img[0].unsqueeze(0))[0], self.global_step) 253 | writer.add_image("jit_viz/lr_jit_feats", red_lr_jit_feats[0], self.global_step) 254 | writer.add_image("jit_viz/hr_jit_feats", red_hr_jit_feats[0], self.global_step) 255 | writer.add_image("jit_viz/down_jit_feats", red_down_jit_feats[0], self.global_step) 256 | 257 | norm_scales = scales[0] 258 | norm_scales /= scales.max() 259 | writer.add_image("scales", norm_scales, self.global_step) 260 | writer.add_histogram("scales hist", scales, self.global_step) 261 | 262 | if isinstance(self.downsampler, SimpleDownsampler): 263 | writer.add_image( 264 | "down/filter", 265 | prep_image(self.downsampler.get_kernel().squeeze(), subtract_min=False), 266 | self.global_step) 267 | 268 | if isinstance(self.downsampler, AttentionDownsampler): 269 | writer.add_image( 270 | "down/att", 271 | prep_image(self.downsampler.forward_attention(hr_feats, None)[0]), 272 | self.global_step) 273 | writer.add_image( 274 | "down/w", 275 | prep_image(self.downsampler.w.clone().squeeze()), 276 | self.global_step) 277 | writer.add_image( 278 | "down/b", 279 | prep_image(self.downsampler.b.clone().squeeze()), 280 | self.global_step) 281 | 282 | writer.flush() 283 | 284 | def configure_optimizers(self): 285 | all_params = [] 286 | all_params.extend(list(self.downsampler.parameters())) 287 | all_params.extend(list(self.upsampler.parameters())) 288 | 289 | if self.predicted_uncertainty: 290 | all_params.extend(list(self.scale_net.parameters())) 291 | 292 | return torch.optim.NAdam(all_params, lr=self.lr) 293 | 294 | 295 | @hydra.main(config_path="configs", config_name="jbu_upsampler.yaml") 296 | def my_app(cfg: DictConfig) -> None: 297 | print(OmegaConf.to_yaml(cfg)) 298 | print(cfg.output_root) 299 | seed_everything(seed=0, workers=True) 300 | 301 | load_size = 224 302 | 303 | if cfg.model_type == "dinov2": 304 | final_size = 16 305 | kernel_size = 14 306 | else: 307 | final_size = 14 308 | kernel_size = 16 309 | 310 | name = (f"{cfg.model_type}_{cfg.upsampler_type}_" 311 | f"{cfg.dataset}_{cfg.downsampler_type}_" 312 | f"crf_{cfg.crf_weight}_tv_{cfg.tv_weight}" 313 | f"_ent_{cfg.filter_ent_weight}") 314 | 315 | log_dir = join(cfg.output_root, f"logs/jbu/{name}") 316 | chkpt_dir = join(cfg.output_root, f"checkpoints/jbu/{name}.ckpt") 317 | os.makedirs(log_dir, exist_ok=True) 318 | 319 | model = JBUFeatUp( 320 | model_type=cfg.model_type, 321 | activation_type=cfg.activation_type, 322 | n_jitters=cfg.n_jitters, 323 | max_pad=cfg.max_pad, 324 | max_zoom=cfg.max_zoom, 325 | kernel_size=kernel_size, 326 | final_size=final_size, 327 | lr=cfg.lr, 328 | random_projection=cfg.random_projection, 329 | predicted_uncertainty=cfg.outlier_detection, 330 | crf_weight=cfg.crf_weight, 331 | filter_ent_weight=cfg.filter_ent_weight, 332 | tv_weight=cfg.tv_weight, 333 | upsampler=cfg.upsampler_type, 334 | downsampler=cfg.downsampler_type, 335 | chkpt_dir=chkpt_dir 336 | ) 337 | 338 | transform = T.Compose([ 339 | T.Resize(load_size, InterpolationMode.BILINEAR), 340 | T.CenterCrop(load_size), 341 | T.ToTensor(), 342 | norm]) 343 | 344 | dataset = get_dataset( 345 | cfg.pytorch_data_dir, 346 | cfg.dataset, 347 | "train", 348 | transform=transform, 349 | target_transform=None, 350 | include_labels=False) 351 | 352 | loader = DataLoader( 353 | dataset, cfg.batch_size, shuffle=True, num_workers=cfg.num_workers) 354 | val_loader = DataLoader( 355 | SingleImageDataset(0, dataset, 1), 1, shuffle=False, num_workers=cfg.num_workers) 356 | 357 | tb_logger = TensorBoardLogger(log_dir, default_hp_metric=False) 358 | callbacks = [ModelCheckpoint(chkpt_dir[:-5], every_n_epochs=1)] 359 | 360 | trainer = Trainer( 361 | accelerator='gpu', 362 | strategy="ddp", 363 | devices=cfg.num_gpus, 364 | max_epochs=cfg.epochs, 365 | logger=tb_logger, 366 | val_check_interval=100, 367 | log_every_n_steps=10, 368 | callbacks=callbacks, 369 | reload_dataloaders_every_n_epochs=1, 370 | ) 371 | 372 | gc.collect() 373 | torch.cuda.empty_cache() 374 | gc.collect() 375 | 376 | trainer.fit(model, loader, val_loader) 377 | trainer.save_checkpoint(chkpt_dir) 378 | 379 | 380 | if __name__ == "__main__": 381 | my_app() 382 | -------------------------------------------------------------------------------- /featup/train_probes.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | import hydra 4 | import matplotlib.pyplot as plt 5 | import pytorch_lightning as pl 6 | import torch 7 | import torch.nn.functional as F 8 | from omegaconf import DictConfig 9 | from omegaconf import OmegaConf 10 | from pytorch_lightning import Trainer 11 | from pytorch_lightning import seed_everything 12 | from pytorch_lightning.loggers import TensorBoardLogger 13 | from pytorch_lightning.utilities.seed import seed_everything 14 | from torch.utils.data import DataLoader 15 | from torchmetrics.classification import Accuracy, JaccardIndex 16 | 17 | from featup.datasets.COCO import Coco 18 | from featup.datasets.EmbeddingFile import EmbeddingFile 19 | from featup.losses import ScaleAndShiftInvariantLoss 20 | from featup.util import pca 21 | from featup.util import remove_axes 22 | 23 | 24 | def tensor_correlation(a, b): 25 | return torch.einsum("nchw,ncij->nhwij", a, b) 26 | 27 | 28 | def sample(t: torch.Tensor, coords: torch.Tensor): 29 | return F.grid_sample(t, coords.permute(0, 2, 1, 3), padding_mode='border', align_corners=True) 30 | 31 | 32 | class LitPrototypeEvaluator(pl.LightningModule): 33 | def __init__(self, task, n_dim): 34 | super().__init__() 35 | self.task = task 36 | self.n_dim = n_dim 37 | 38 | if self.task == 'seg': 39 | n_classes = 27 40 | elif self.task == 'depth': 41 | n_classes = 1 42 | 43 | self.midas = torch.hub.load('intel-isl/MiDaS', 'MiDaS_small').cuda() 44 | self.midas.eval() 45 | self.midas_loss = ScaleAndShiftInvariantLoss() 46 | 47 | self.mse = 0 48 | self.ssil = 0 49 | self.steps = 0 50 | 51 | self.prototypes_buff = self.register_buffer("prototypes", torch.zeros(n_classes, n_dim)) 52 | self.classifier = torch.nn.Conv2d(n_dim, n_classes, 1) 53 | 54 | self.prot_acc_metric = Accuracy(num_classes=n_classes, task="multiclass") 55 | self.prot_acc_buff = self.register_buffer("prot_acc", torch.tensor(0.0)) 56 | self.prot_iou_metric = JaccardIndex(num_classes=n_classes, task="multiclass") 57 | self.prot_iou_buff = self.register_buffer("prot_iou", torch.tensor(0.0)) 58 | 59 | self.linear_acc_metric = Accuracy(num_classes=n_classes, task="multiclass") 60 | self.linear_acc_buff = self.register_buffer("linear_acc", torch.tensor(0.0)) 61 | self.linear_iou_metric = JaccardIndex(num_classes=n_classes, task="multiclass") 62 | self.linear_iou_buff = self.register_buffer("linear_iou", torch.tensor(0.0)) 63 | 64 | self.ce = torch.nn.CrossEntropyLoss() 65 | 66 | def get_prototypes(self, feats): 67 | b, c, h, w = feats.shape 68 | k = self.prototypes.shape[0] 69 | matches = torch.einsum("kc,bchw->kbhw", F.normalize(self.prototypes, dim=1), F.normalize(feats, dim=1)) \ 70 | .reshape(k, -1).argmax(0) 71 | return self.prototypes[matches].reshape(b, h, w, c).permute(0, 3, 1, 2) 72 | 73 | def training_step(self, batch, batch_idx): 74 | feats, label = batch 75 | b, c, h, w = feats.shape 76 | 77 | small_labels = F.interpolate( 78 | label.unsqueeze(1).to(torch.float32), 79 | size=(feats.shape[2], feats.shape[3])).to(torch.int64) 80 | 81 | linear_preds = self.classifier(feats) 82 | 83 | if self.task == 'seg': 84 | flat_labels = small_labels.permute(0, 2, 3, 1).reshape(b * h * w) 85 | flat_linear_preds = linear_preds.permute(0, 2, 3, 1).reshape(b * h * w, -1) 86 | 87 | selected = flat_labels > -1 88 | linear_loss = self.ce( 89 | flat_linear_preds[selected], 90 | flat_labels[selected]) 91 | loss = linear_loss 92 | self.log("linear_loss", linear_loss) 93 | self.log("loss", loss) 94 | 95 | for l in range(self.n_classes): 96 | self.prototypes[l] += feats.permute(0, 2, 3, 1).reshape(b * h * w, -1)[flat_labels == l].sum(dim=0) 97 | 98 | if self.global_step % 10 == 1 and self.trainer.is_global_zero: 99 | with torch.no_grad(): 100 | prots = self.get_prototypes(feats) 101 | prot_loss = -(F.normalize(feats, dim=1) * F.normalize(prots, dim=1)).sum(1).mean() 102 | self.logger.experiment.add_scalar("prot_loss", prot_loss, self.global_step) 103 | 104 | elif self.task == 'depth': 105 | loss = self.midas_loss(linear_preds.squeeze(), small_labels.squeeze(), 106 | torch.ones_like(linear_preds.squeeze())) 107 | self.log('loss', loss) 108 | 109 | if self.global_step % 200 == 0 and self.trainer.is_global_zero: 110 | n_images = 5 111 | fig, axes = plt.subplots(4, n_images, figsize=(4 * n_images, 5 * 5)) 112 | 113 | prot_preds = torch.einsum("bchw,kc->bkhw", 114 | F.normalize(feats, dim=1), 115 | F.normalize(self.prototypes, dim=1, eps=1e-10)) 116 | 117 | colorize = Coco.colorize_label if self.task == 'seg' else lambda x: x.detach().cpu() 118 | for i in range(n_images): 119 | feats_pca = pca([feats])[0][0][i] 120 | axes[0, i].imshow(feats_pca) 121 | axes[1, i].imshow(colorize(label[i])) 122 | if self.task == 'depth': 123 | axes[2, i].imshow(colorize(linear_preds[i][0])) 124 | axes[3, i].imshow(colorize(prot_preds[i][0])) 125 | elif self.task == 'seg': 126 | axes[2, i].imshow(colorize(linear_preds.argmax(1)[i])) 127 | axes[3, i].imshow(colorize(prot_preds.argmax(1)[i])) 128 | 129 | plt.tight_layout() 130 | remove_axes(axes) 131 | self.logger.experiment.add_figure('predictions', fig, self.global_step) 132 | 133 | return loss 134 | 135 | def validation_step(self, batch, batch_idx): 136 | with torch.no_grad(): 137 | feats, label = batch 138 | 139 | if self.task == 'seg': 140 | label = F.interpolate( 141 | label.to(torch.float32).unsqueeze(1), size=(224, 224)).to(torch.int64).squeeze(1) 142 | 143 | prot_preds = torch.einsum( 144 | "bchw,kc->bkhw", 145 | F.normalize(feats, dim=1), 146 | F.normalize(self.prototypes, dim=1, eps=1e-10)).argmax(1, keepdim=True) 147 | linear_preds = self.classifier(feats).argmax(1, keepdim=True) 148 | 149 | b, h, w = label.shape 150 | flat_labels = label.flatten() 151 | selected = flat_labels > -1 152 | flat_labels = flat_labels[selected] 153 | 154 | flat_prot_preds = F.interpolate( 155 | prot_preds.to(torch.float32), (h, w)).to(torch.int64).flatten()[selected] 156 | self.prot_acc_metric.update(flat_prot_preds, flat_labels) 157 | self.prot_iou_metric.update(flat_prot_preds, flat_labels) 158 | 159 | flat_linear_preds = F.interpolate( 160 | linear_preds.to(torch.float32), (h, w)).to(torch.int64).flatten()[selected] 161 | self.linear_acc_metric.update(flat_linear_preds, flat_labels) 162 | self.linear_iou_metric.update(flat_linear_preds, flat_labels) 163 | 164 | elif self.task == 'depth': 165 | linear_preds = self.classifier(feats) 166 | small_labels = F.interpolate( 167 | label.unsqueeze(1).to(torch.float32), 168 | size=(feats.shape[2], feats.shape[3])).to(torch.int64) 169 | mse = (small_labels - linear_preds).pow(2).mean() 170 | midas_l = self.midas_loss(linear_preds.squeeze(), small_labels.squeeze(), 171 | torch.ones_like(linear_preds.squeeze())) 172 | self.mse += mse.item() 173 | self.ssil += midas_l.item() 174 | 175 | self.steps += 1 176 | 177 | return None 178 | 179 | def validation_epoch_end(self, outputs): 180 | self.prot_acc = self.prot_acc_metric.compute() 181 | self.prot_iou = self.prot_iou_metric.compute() 182 | self.linear_acc = self.linear_acc_metric.compute() 183 | self.linear_iou = self.linear_iou_metric.compute() 184 | 185 | def configure_optimizers(self): 186 | return torch.optim.Adam(self.classifier.parameters(), lr=5e-3) 187 | 188 | 189 | @hydra.main(config_path="configs", config_name="train_probe.yaml") 190 | def my_app(cfg: DictConfig) -> None: 191 | print(OmegaConf.to_yaml(cfg)) 192 | print(cfg.output_root) 193 | seed_everything(seed=0, workers=True) 194 | 195 | log_dir = f"../probes/{cfg.task}-probe" 196 | chkpt_dir = f"../probes/{cfg.task}-probe-{cfg.model_type}.ckpt" 197 | 198 | emb_root = join(cfg.pytorch_data_dir, "cocostuff", "embedding", cfg.model_type) 199 | 200 | train_dataset = EmbeddingFile(join(emb_root, "train", f"embeddings_{cfg.activation_type}.npz")) 201 | train_loader = DataLoader(train_dataset, cfg.batch_size, shuffle=True, num_workers=cfg.num_workers) 202 | 203 | val_dataset = EmbeddingFile(join(emb_root, "val", f"embeddings_{cfg.activation_type}.npz")) 204 | val_loader = DataLoader(val_dataset, cfg.batch_size, shuffle=True, num_workers=cfg.num_workers) 205 | 206 | evaluator = LitPrototypeEvaluator(cfg.task, train_dataset.dim()) 207 | tb_logger = TensorBoardLogger(log_dir, default_hp_metric=False) 208 | 209 | trainer = Trainer( 210 | accelerator='gpu', 211 | devices=1, 212 | max_epochs=cfg.epochs, 213 | logger=tb_logger, 214 | log_every_n_steps=100, 215 | reload_dataloaders_every_n_epochs=1, 216 | check_val_every_n_epoch=10, 217 | ) 218 | 219 | trainer.fit(evaluator, train_loader, val_loader) 220 | 221 | trainer.save_checkpoint(chkpt_dir) 222 | 223 | result = { 224 | "Prototype Accuracy": float(evaluator.prot_acc), 225 | "Prototype mIoU": float(evaluator.prot_iou), 226 | "Linear Accuracy": float(evaluator.linear_acc), 227 | "Linear mIoU": float(evaluator.linear_iou), 228 | "Model": cfg.model_type 229 | } 230 | print(result) 231 | 232 | 233 | if __name__ == "__main__": 234 | my_app() 235 | -------------------------------------------------------------------------------- /featup/upsamplers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from featup.adaptive_conv_cuda.adaptive_conv import AdaptiveConv 8 | 9 | 10 | class SimpleImplicitFeaturizer(torch.nn.Module): 11 | 12 | def __init__(self, n_freqs=20): 13 | super().__init__() 14 | self.n_freqs = n_freqs 15 | self.dim_multiplier = 2 16 | 17 | def forward(self, original_image): 18 | b, c, h, w = original_image.shape 19 | grid_h = torch.linspace(-1, 1, h, device=original_image.device) 20 | grid_w = torch.linspace(-1, 1, w, device=original_image.device) 21 | feats = torch.cat([t.unsqueeze(0) for t in torch.meshgrid([grid_h, grid_w])]).unsqueeze(0) 22 | feats = torch.broadcast_to(feats, (b, feats.shape[1], h, w)) 23 | 24 | feat_list = [feats] 25 | feats = torch.cat(feat_list, dim=1).unsqueeze(1) 26 | freqs = torch.exp(torch.linspace(-2, 10, self.n_freqs, device=original_image.device)) \ 27 | .reshape(1, self.n_freqs, 1, 1, 1) 28 | feats = (feats * freqs) 29 | 30 | feats = feats.reshape(b, self.n_freqs * self.dim_multiplier, h, w) 31 | 32 | all_feats = [torch.sin(feats), torch.cos(feats), original_image] 33 | 34 | return torch.cat(all_feats, dim=1) 35 | 36 | 37 | class IFA(torch.nn.Module): 38 | 39 | def __init__(self, feat_dim, num_scales=20): 40 | super().__init__() 41 | self.scales = 2 * torch.exp(torch.tensor(torch.arange(1, num_scales + 1))) 42 | self.feat_dim = feat_dim 43 | self.sin_feats = SimpleImplicitFeaturizer() 44 | self.mlp = nn.Sequential( 45 | nn.Conv2d(feat_dim + (num_scales * 4) + 2, feat_dim, 1), 46 | nn.BatchNorm2d(feat_dim), 47 | nn.LeakyReLU(), 48 | nn.Conv2d(feat_dim, feat_dim, 1), 49 | ) 50 | 51 | def forward(self, source, guidance): 52 | b, c, h, w = source.shape 53 | up_source = F.interpolate(source, (h * 2, w * 2), mode="nearest") 54 | assert h == w 55 | lr_cord = torch.linspace(0, h, steps=h, device=source.device) 56 | hr_cord = torch.linspace(0, h, steps=2 * h, device=source.device) 57 | lr_coords = torch.cat([x.unsqueeze(0) for x in torch.meshgrid(lr_cord, lr_cord)], dim=0).unsqueeze(0) 58 | hr_coords = torch.cat([x.unsqueeze(0) for x in torch.meshgrid(hr_cord, hr_cord)], dim=0).unsqueeze(0) 59 | up_lr_coords = F.interpolate(lr_coords, (h * 2, w * 2), mode="nearest") 60 | coord_diff = up_lr_coords - hr_coords 61 | coord_diff_feats = self.sin_feats(coord_diff) 62 | c2 = coord_diff_feats.shape[1] 63 | bcast_coord_feats = torch.broadcast_to(coord_diff_feats, (b, c2, h * 2, w * 2)) 64 | return self.mlp(torch.cat([up_source, bcast_coord_feats], dim=1)) # + up_source 65 | 66 | 67 | class SAPAModule(nn.Module): 68 | def __init__(self, dim_y, dim_x=None, 69 | up_factor=2, up_kernel_size=5, embedding_dim=64, 70 | qkv_bias=True, norm=nn.LayerNorm): 71 | super().__init__() 72 | dim_x = dim_x if dim_x is not None else dim_y 73 | 74 | self.up_factor = up_factor 75 | self.up_kernel_size = up_kernel_size 76 | self.embedding_dim = embedding_dim 77 | 78 | self.norm_y = norm(dim_y) 79 | self.norm_x = norm(dim_x) 80 | 81 | self.q = nn.Linear(dim_y, embedding_dim, bias=qkv_bias) 82 | self.k = nn.Linear(dim_x, embedding_dim, bias=qkv_bias) 83 | 84 | self.apply(self._init_weights) 85 | 86 | def forward(self, y, x): 87 | y = y.permute(0, 2, 3, 1).contiguous() 88 | x = x.permute(0, 2, 3, 1).contiguous() 89 | y = self.norm_y(y) 90 | x_ = self.norm_x(x) 91 | 92 | q = self.q(y) 93 | k = self.k(x_) 94 | 95 | return self.attention(q, k, x).permute(0, 3, 1, 2).contiguous() 96 | 97 | def attention(self, q, k, v): 98 | from sapa import sim, atn 99 | 100 | attn = F.softmax(sim(q, k, self.up_kernel_size, self.up_factor), dim=-1) 101 | return atn(attn, v, self.up_kernel_size, self.up_factor) 102 | 103 | def _init_weights(self, m): 104 | from timm.models.layers import trunc_normal_ 105 | 106 | if isinstance(m, nn.Linear): 107 | trunc_normal_(m.weight, std=.02) 108 | if isinstance(m, nn.Linear) and m.bias is not None: 109 | nn.init.constant_(m.bias, 0) 110 | elif isinstance(m, nn.LayerNorm): 111 | nn.init.constant_(m.bias, 0) 112 | nn.init.constant_(m.weight, 1.0) 113 | elif isinstance(m, nn.Conv2d): 114 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 115 | fan_out //= m.groups 116 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 117 | if m.bias is not None: 118 | m.bias.data.zero_() 119 | 120 | 121 | class SAPAUpsampler(torch.nn.Module): 122 | def __init__(self, dim_x, *args, **kwargs): 123 | super().__init__(*args, **kwargs) 124 | self.up1 = SAPAModule(dim_x=dim_x, dim_y=3) 125 | self.up2 = SAPAModule(dim_x=dim_x, dim_y=3) 126 | self.up3 = SAPAModule(dim_x=dim_x, dim_y=3) 127 | self.up4 = SAPAModule(dim_x=dim_x, dim_y=3) 128 | 129 | def adapt_guidance(self, source, guidance): 130 | _, _, h, w = source.shape 131 | small_guidance = F.adaptive_avg_pool2d(guidance, (h * 2, w * 2)) 132 | return small_guidance 133 | 134 | def forward(self, source, guidance): 135 | source_2 = self.up1(self.adapt_guidance(source, guidance), source) 136 | source_4 = self.up2(self.adapt_guidance(source_2, guidance), source_2) 137 | source_8 = self.up3(self.adapt_guidance(source_4, guidance), source_4) 138 | source_16 = self.up4(self.adapt_guidance(source_8, guidance), source_8) 139 | return source_16 140 | 141 | 142 | class CarafeUpsampler(torch.nn.Module): 143 | 144 | def __init__(self, dim, kernel_size, *args, **kwargs): 145 | super().__init__(*args, **kwargs) 146 | from mmcv.ops import CARAFEPack 147 | self.up1 = CARAFEPack(dim, up_kernel=3, up_group=1, scale_factor=2) 148 | self.up2 = CARAFEPack(dim, up_kernel=3, up_group=1, scale_factor=2) 149 | self.up3 = CARAFEPack(dim, up_kernel=3, up_group=1, scale_factor=2) 150 | self.up4 = CARAFEPack(dim, up_kernel=3, up_group=1, scale_factor=2) 151 | 152 | def forward(self, source, guidance): 153 | source_2 = self.up1(source) 154 | source_4 = self.up2(source_2) 155 | source_8 = self.up3(source_4) 156 | source_16 = self.up4(source_8) 157 | return source_16 158 | 159 | 160 | class LayeredResizeConv(torch.nn.Module): 161 | 162 | def __init__(self, dim, kernel_size, *args, **kwargs): 163 | super().__init__(*args, **kwargs) 164 | self.conv1 = torch.nn.Conv2d(dim + 3, dim, kernel_size, padding="same") 165 | self.conv2 = torch.nn.Conv2d(dim + 3, dim, kernel_size, padding="same") 166 | self.conv3 = torch.nn.Conv2d(dim + 3, dim, kernel_size, padding="same") 167 | self.conv4 = torch.nn.Conv2d(dim + 3, dim, kernel_size, padding="same") 168 | 169 | def apply_conv(self, source, guidance, conv, activation): 170 | big_source = F.interpolate(source, scale_factor=2, mode="bilinear") 171 | _, _, h, w = big_source.shape 172 | small_guidance = F.interpolate(guidance, (h, w), mode="bilinear") 173 | output = activation(conv(torch.cat([big_source, small_guidance], dim=1))) 174 | return big_source + output 175 | 176 | def forward(self, source, guidance): 177 | source_2 = self.apply_conv(source, guidance, self.conv1, F.relu) 178 | source_4 = self.apply_conv(source_2, guidance, self.conv2, F.relu) 179 | source_8 = self.apply_conv(source_4, guidance, self.conv3, F.relu) 180 | source_16 = self.apply_conv(source_8, guidance, self.conv4, lambda x: x) 181 | return source_16 182 | 183 | 184 | class JBULearnedRange(torch.nn.Module): 185 | 186 | def __init__(self, guidance_dim, feat_dim, key_dim, scale=2, radius=3): 187 | super().__init__() 188 | self.scale = scale 189 | self.radius = radius 190 | self.diameter = self.radius * 2 + 1 191 | 192 | self.guidance_dim = guidance_dim 193 | self.key_dim = key_dim 194 | self.feat_dim = feat_dim 195 | 196 | self.range_temp = nn.Parameter(torch.tensor(0.0)) 197 | self.range_proj = torch.nn.Sequential( 198 | torch.nn.Conv2d(guidance_dim, key_dim, 1, 1), 199 | torch.nn.GELU(), 200 | torch.nn.Dropout2d(.1), 201 | torch.nn.Conv2d(key_dim, key_dim, 1, 1), 202 | ) 203 | 204 | self.fixup_proj = torch.nn.Sequential( 205 | torch.nn.Conv2d(guidance_dim + self.diameter ** 2, self.diameter ** 2, 1, 1), 206 | torch.nn.GELU(), 207 | torch.nn.Dropout2d(.1), 208 | torch.nn.Conv2d(self.diameter ** 2, self.diameter ** 2, 1, 1), 209 | ) 210 | 211 | self.sigma_spatial = nn.Parameter(torch.tensor(1.0)) 212 | 213 | def get_range_kernel(self, x): 214 | GB, GC, GH, GW = x.shape 215 | proj_x = self.range_proj(x) 216 | proj_x_padded = F.pad(proj_x, pad=[self.radius] * 4, mode='reflect') 217 | queries = torch.nn.Unfold(self.diameter)(proj_x_padded) \ 218 | .reshape((GB, self.key_dim, self.diameter * self.diameter, GH, GW)) \ 219 | .permute(0, 1, 3, 4, 2) 220 | pos_temp = self.range_temp.exp().clamp_min(1e-4).clamp_max(1e4) 221 | return F.softmax(pos_temp * torch.einsum("bchwp,bchw->bphw", queries, proj_x), dim=1) 222 | 223 | def get_spatial_kernel(self, device): 224 | dist_range = torch.linspace(-1, 1, self.diameter, device=device) 225 | x, y = torch.meshgrid(dist_range, dist_range) 226 | patch = torch.cat([x.unsqueeze(0), y.unsqueeze(0)], dim=0) 227 | return torch.exp(- patch.square().sum(0) / (2 * self.sigma_spatial ** 2)) \ 228 | .reshape(1, self.diameter * self.diameter, 1, 1) 229 | 230 | def forward(self, source, guidance): 231 | GB, GC, GH, GW = guidance.shape 232 | SB, SC, SH, SQ = source.shape 233 | assert (SB == GB) 234 | 235 | spatial_kernel = self.get_spatial_kernel(source.device) 236 | range_kernel = self.get_range_kernel(guidance) 237 | 238 | combined_kernel = range_kernel * spatial_kernel 239 | combined_kernel /= combined_kernel.sum(1, keepdim=True).clamp(1e-7) 240 | 241 | combined_kernel += .1 * self.fixup_proj(torch.cat([combined_kernel, guidance], dim=1)) 242 | combined_kernel = combined_kernel.permute(0, 2, 3, 1) \ 243 | .reshape(GB, GH, GW, self.diameter, self.diameter) 244 | 245 | hr_source = torch.nn.Upsample((GH, GW), mode='bicubic', align_corners=False)(source) 246 | hr_source_padded = F.pad(hr_source, pad=[self.radius] * 4, mode='reflect') 247 | 248 | # (B C, H+Pad, W+Pad) x (B, H, W, KH, KW) -> BCHW 249 | result = AdaptiveConv.apply(hr_source_padded, combined_kernel) 250 | return result 251 | 252 | 253 | class JBUStack(torch.nn.Module): 254 | 255 | def __init__(self, feat_dim, *args, **kwargs): 256 | super().__init__(*args, **kwargs) 257 | self.up1 = JBULearnedRange(3, feat_dim, 32, radius=3) 258 | self.up2 = JBULearnedRange(3, feat_dim, 32, radius=3) 259 | self.up3 = JBULearnedRange(3, feat_dim, 32, radius=3) 260 | self.up4 = JBULearnedRange(3, feat_dim, 32, radius=3) 261 | self.fixup_proj = torch.nn.Sequential( 262 | torch.nn.Dropout2d(0.2), 263 | torch.nn.Conv2d(feat_dim, feat_dim, kernel_size=1)) 264 | 265 | def upsample(self, source, guidance, up): 266 | _, _, h, w = source.shape 267 | small_guidance = F.adaptive_avg_pool2d(guidance, (h * 2, w * 2)) 268 | upsampled = up(source, small_guidance) 269 | return upsampled 270 | 271 | def forward(self, source, guidance): 272 | source_2 = self.upsample(source, guidance, self.up1) 273 | source_4 = self.upsample(source_2, guidance, self.up2) 274 | source_8 = self.upsample(source_4, guidance, self.up3) 275 | source_16 = self.upsample(source_8, guidance, self.up4) 276 | return self.fixup_proj(source_16) * 0.1 + source_16 277 | 278 | 279 | class Bilinear(torch.nn.Module): 280 | 281 | def __init__(self, *args, **kwargs): 282 | super().__init__(*args, **kwargs) 283 | 284 | def forward(self, feats, img): 285 | _, _, h, w = img.shape 286 | return F.interpolate(feats, (h, w), mode="bilinear") 287 | 288 | 289 | def get_upsampler(upsampler, dim): 290 | if upsampler == 'bilinear': 291 | return Bilinear() 292 | elif upsampler == 'jbu_stack': 293 | return JBUStack(dim) 294 | elif upsampler == 'resize_conv': 295 | return LayeredResizeConv(dim, 1) 296 | elif upsampler == 'carafe': 297 | return CarafeUpsampler(dim, 1) 298 | elif upsampler == 'sapa': 299 | return SAPAUpsampler(dim_x=dim) 300 | elif upsampler == 'ifa': 301 | return IFA(dim) 302 | else: 303 | raise ValueError(f"Unknown upsampler {upsampler}") 304 | -------------------------------------------------------------------------------- /featup/util.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torchvision.transforms as T 4 | import numpy as np 5 | from sklearn.decomposition import PCA 6 | import torch.nn.functional as F 7 | from collections import defaultdict, deque 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class RollingAvg: 13 | 14 | def __init__(self, length): 15 | self.length = length 16 | self.metrics = defaultdict(lambda: deque(maxlen=self.length)) 17 | 18 | def add(self, name, metric): 19 | self.metrics[name].append(metric) 20 | 21 | def get(self, name): 22 | return torch.tensor(list(self.metrics[name])).mean() 23 | 24 | def logall(self, log_func): 25 | for k in self.metrics.keys(): 26 | log_func(k, self.get(k)) 27 | 28 | 29 | def _remove_axes(ax): 30 | ax.xaxis.set_major_formatter(plt.NullFormatter()) 31 | ax.yaxis.set_major_formatter(plt.NullFormatter()) 32 | ax.set_xticks([]) 33 | ax.set_yticks([]) 34 | 35 | 36 | def remove_axes(axes): 37 | if len(axes.shape) == 2: 38 | for ax1 in axes: 39 | for ax in ax1: 40 | _remove_axes(ax) 41 | else: 42 | for ax in axes: 43 | _remove_axes(ax) 44 | 45 | 46 | class UnNormalize(object): 47 | def __init__(self, mean, std): 48 | self.mean = mean 49 | self.std = std 50 | 51 | def __call__(self, image): 52 | image2 = torch.clone(image) 53 | if len(image2.shape) == 4: 54 | # batched 55 | image2 = image2.permute(1, 0, 2, 3) 56 | for t, m, s in zip(image2, self.mean, self.std): 57 | t.mul_(s).add_(m) 58 | return image2.permute(1, 0, 2, 3) 59 | 60 | 61 | norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 62 | unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 63 | 64 | midas_norm = T.Normalize([0.5] * 3, [0.5] * 3) 65 | midas_unnorm = UnNormalize([0.5] * 3, [0.5] * 3) 66 | 67 | 68 | class ToTargetTensor(object): 69 | def __call__(self, target): 70 | return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0) 71 | 72 | 73 | def show_heatmap(ax, 74 | image, 75 | heatmap, 76 | cmap="bwr", 77 | color=False, 78 | center=False, 79 | show_negative=False, 80 | cax=None, 81 | vmax=None): 82 | frame = [] 83 | 84 | if color: 85 | frame.append(ax.imshow(image)) 86 | else: 87 | bw = np.dot(np.array(image)[..., :3] / 255, [0.2989, 0.5870, 0.1140]) 88 | bw = np.ones_like(image) * np.expand_dims(bw, -1) 89 | frame.append(ax.imshow(bw)) 90 | 91 | if center: 92 | heatmap -= heatmap.mean() 93 | 94 | if not show_negative: 95 | heatmap = heatmap.clamp_min(0) 96 | 97 | heatmap = F.interpolate(heatmap.unsqueeze(0).unsqueeze(0), (image.shape[0], image.shape[1])) \ 98 | .squeeze(0).squeeze(0) 99 | 100 | if vmax is None: 101 | vmax = np.abs(heatmap).max() 102 | 103 | hm = ax.imshow(heatmap, alpha=.5, cmap=cmap, vmax=vmax, vmin=-vmax) 104 | if cax is not None: 105 | plt.colorbar(hm, cax=cax, orientation='vertical') 106 | 107 | frame.extend([hm]) 108 | return frame 109 | 110 | 111 | def implicit_feats(original_image, input_size, color_feats): 112 | n_freqs = 20 113 | grid = torch.linspace(-1, 1, input_size, device=original_image.device) 114 | feats = torch.cat([t.unsqueeze(0) for t in torch.meshgrid([grid, grid])]).unsqueeze(0) 115 | 116 | if color_feats: 117 | feat_list = [feats, original_image] 118 | dim_multiplier = 5 119 | else: 120 | feat_list = [feats] 121 | dim_multiplier = 2 122 | 123 | feats = torch.cat(feat_list, dim=1) 124 | freqs = torch.exp(torch.linspace(-2, 10, n_freqs, device=original_image.device)) \ 125 | .reshape(n_freqs, 1, 1, 1) 126 | feats = (feats * freqs).reshape(1, n_freqs * dim_multiplier, input_size, input_size) 127 | 128 | if color_feats: 129 | all_feats = [torch.sin(feats), torch.cos(feats), original_image] 130 | else: 131 | all_feats = [torch.sin(feats), torch.cos(feats)] 132 | return torch.cat(all_feats, dim=1) 133 | 134 | 135 | def load_hr_emb(original_image, model_path, color_feats=True): 136 | model = torch.load(model_path, map_location="cpu") 137 | hr_model = model["model"].cuda().eval() 138 | unprojector = model["unprojector"].cuda().eval() 139 | 140 | with torch.no_grad(): 141 | h, w = original_image.shape[2:] 142 | assert h == w 143 | feats = implicit_feats(original_image, h, color_feats).cuda() 144 | hr_feats = hr_model(feats) 145 | hr_feats = unprojector(hr_feats.detach().cpu()) 146 | 147 | return hr_feats 148 | 149 | 150 | def generate_subset(n, batch): 151 | np.random.seed(0) 152 | return np.random.permutation(n)[:batch] 153 | 154 | 155 | class TorchPCA(object): 156 | 157 | def __init__(self, n_components): 158 | self.n_components = n_components 159 | 160 | def fit(self, X): 161 | self.mean_ = X.mean(dim=0) 162 | unbiased = X - self.mean_.unsqueeze(0) 163 | U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4) 164 | self.components_ = V.T 165 | self.singular_values_ = S 166 | return self 167 | 168 | def transform(self, X): 169 | t0 = X - self.mean_.unsqueeze(0) 170 | projected = t0 @ self.components_.T 171 | return projected 172 | 173 | 174 | def pca(image_feats_list, dim=3, fit_pca=None, use_torch_pca=True, max_samples=None): 175 | device = image_feats_list[0].device 176 | 177 | def flatten(tensor, target_size=None): 178 | if target_size is not None and fit_pca is None: 179 | tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear") 180 | B, C, H, W = tensor.shape 181 | return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu() 182 | 183 | if len(image_feats_list) > 1 and fit_pca is None: 184 | target_size = image_feats_list[0].shape[2] 185 | else: 186 | target_size = None 187 | 188 | flattened_feats = [] 189 | for feats in image_feats_list: 190 | flattened_feats.append(flatten(feats, target_size)) 191 | x = torch.cat(flattened_feats, dim=0) 192 | 193 | # Subsample the data if max_samples is set and the number of samples exceeds max_samples 194 | if max_samples is not None and x.shape[0] > max_samples: 195 | indices = torch.randperm(x.shape[0])[:max_samples] 196 | x = x[indices] 197 | 198 | if fit_pca is None: 199 | if use_torch_pca: 200 | fit_pca = TorchPCA(n_components=dim).fit(x) 201 | else: 202 | fit_pca = PCA(n_components=dim).fit(x) 203 | 204 | reduced_feats = [] 205 | for feats in image_feats_list: 206 | x_red = fit_pca.transform(flatten(feats)) 207 | if isinstance(x_red, np.ndarray): 208 | x_red = torch.from_numpy(x_red) 209 | x_red -= x_red.min(dim=0, keepdim=True).values 210 | x_red /= x_red.max(dim=0, keepdim=True).values 211 | B, C, H, W = feats.shape 212 | reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device)) 213 | 214 | return reduced_feats, fit_pca 215 | 216 | 217 | class PCAUnprojector(nn.Module): 218 | 219 | def __init__(self, feats, dim, device, use_torch_pca=False, **kwargs): 220 | super().__init__() 221 | self.dim = dim 222 | 223 | if feats is not None: 224 | self.original_dim = feats.shape[1] 225 | else: 226 | self.original_dim = kwargs["original_dim"] 227 | 228 | if self.dim != self.original_dim: 229 | if feats is not None: 230 | sklearn_pca = pca([feats], dim=dim, use_torch_pca=use_torch_pca)[1] 231 | 232 | # Register tensors as buffers 233 | self.register_buffer('components_', 234 | torch.tensor(sklearn_pca.components_, device=device, dtype=feats.dtype)) 235 | self.register_buffer('singular_values_', 236 | torch.tensor(sklearn_pca.singular_values_, device=device, dtype=feats.dtype)) 237 | self.register_buffer('mean_', torch.tensor(sklearn_pca.mean_, device=device, dtype=feats.dtype)) 238 | else: 239 | self.register_buffer('components_', kwargs["components_"].t()) 240 | self.register_buffer('singular_values_', kwargs["singular_values_"]) 241 | self.register_buffer('mean_', kwargs["mean_"]) 242 | 243 | else: 244 | print("PCAUnprojector will not transform data") 245 | 246 | def forward(self, red_feats): 247 | if self.dim == self.original_dim: 248 | return red_feats 249 | else: 250 | b, c, h, w = red_feats.shape 251 | red_feats_reshaped = red_feats.permute(0, 2, 3, 1).reshape(b * h * w, c) 252 | unprojected = (red_feats_reshaped @ self.components_) + self.mean_.unsqueeze(0) 253 | return unprojected.reshape(b, h, w, self.original_dim).permute(0, 3, 1, 2) 254 | 255 | def project(self, feats): 256 | if self.dim == self.original_dim: 257 | return feats 258 | else: 259 | b, c, h, w = feats.shape 260 | feats_reshaped = feats.permute(0, 2, 3, 1).reshape(b * h * w, c) 261 | t0 = feats_reshaped - self.mean_.unsqueeze(0).to(feats.device) 262 | projected = t0 @ self.components_.t().to(feats.device) 263 | return projected.reshape(b, h, w, self.dim).permute(0, 3, 1, 2) 264 | 265 | 266 | def prep_image(t, subtract_min=True): 267 | if subtract_min: 268 | t -= t.min() 269 | t /= t.max() 270 | t = (t * 255).clamp(0, 255).to(torch.uint8) 271 | 272 | if len(t.shape) == 2: 273 | t = t.unsqueeze(0) 274 | 275 | return t 276 | -------------------------------------------------------------------------------- /gradio_app.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torchvision.transforms as T 4 | from PIL import Image 5 | import gradio as gr 6 | from featup.util import norm, unnorm, pca, remove_axes 7 | from pytorch_lightning import seed_everything 8 | import os 9 | import requests 10 | import os 11 | import csv 12 | 13 | def plot_feats(image, lr, hr): 14 | assert len(image.shape) == len(lr.shape) == len(hr.shape) == 3 15 | seed_everything(0) 16 | [lr_feats_pca, hr_feats_pca], _ = pca([lr.unsqueeze(0), hr.unsqueeze(0)], dim=9) 17 | fig, ax = plt.subplots(3, 3, figsize=(15, 15)) 18 | ax[0, 0].imshow(image.permute(1, 2, 0).detach().cpu()) 19 | ax[1, 0].imshow(image.permute(1, 2, 0).detach().cpu()) 20 | ax[2, 0].imshow(image.permute(1, 2, 0).detach().cpu()) 21 | 22 | ax[0, 0].set_title("Image", fontsize=22) 23 | ax[0, 1].set_title("Original", fontsize=22) 24 | ax[0, 2].set_title("Upsampled Features", fontsize=22) 25 | 26 | ax[0, 1].imshow(lr_feats_pca[0, :3].permute(1, 2, 0).detach().cpu()) 27 | ax[0, 0].set_ylabel("PCA Components 1-3", fontsize=22) 28 | ax[0, 2].imshow(hr_feats_pca[0, :3].permute(1, 2, 0).detach().cpu()) 29 | 30 | ax[1, 1].imshow(lr_feats_pca[0, 3:6].permute(1, 2, 0).detach().cpu()) 31 | ax[1, 0].set_ylabel("PCA Components 4-6", fontsize=22) 32 | ax[1, 2].imshow(hr_feats_pca[0, 3:6].permute(1, 2, 0).detach().cpu()) 33 | 34 | ax[2, 1].imshow(lr_feats_pca[0, 6:9].permute(1, 2, 0).detach().cpu()) 35 | ax[2, 0].set_ylabel("PCA Components 7-9", fontsize=22) 36 | ax[2, 2].imshow(hr_feats_pca[0, 6:9].permute(1, 2, 0).detach().cpu()) 37 | 38 | remove_axes(ax) 39 | plt.tight_layout() 40 | plt.close(fig) # Close plt to avoid additional empty plots 41 | return fig 42 | 43 | 44 | if __name__ == "__main__": 45 | 46 | def download_image(url, save_path): 47 | response = requests.get(url) 48 | with open(save_path, 'wb') as file: 49 | file.write(response.content) 50 | 51 | base_url = "https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/sample_images/" 52 | sample_images_urls = { 53 | "skate.jpg": base_url + "skate.jpg", 54 | "car.jpg": base_url + "car.jpg", 55 | "plant.png": base_url + "plant.png", 56 | } 57 | 58 | sample_images_dir = "/tmp/sample_images" 59 | 60 | # Ensure the directory for sample images exists 61 | os.makedirs(sample_images_dir, exist_ok=True) 62 | 63 | # Download each sample image 64 | for filename, url in sample_images_urls.items(): 65 | save_path = os.path.join(sample_images_dir, filename) 66 | # Download the image if it doesn't already exist 67 | if not os.path.exists(save_path): 68 | print(f"Downloading {filename}...") 69 | download_image(url, save_path) 70 | else: 71 | print(f"{filename} already exists. Skipping download.") 72 | 73 | os.environ['TORCH_HOME'] = '/tmp/.cache' 74 | os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache' 75 | csv.field_size_limit(100000000) 76 | options = ['dino16', 'vit', 'dinov2', 'clip', 'resnet50'] 77 | 78 | image_input = gr.Image(label="Choose an image to featurize", 79 | height=480, 80 | type="pil", 81 | image_mode='RGB', 82 | sources=['upload', 'webcam', 'clipboard'] 83 | ) 84 | model_option = gr.Radio(options, value="dino16", label='Choose a backbone to upsample') 85 | 86 | models = {o: torch.hub.load("mhamilton723/FeatUp", o) for o in options} 87 | 88 | 89 | def upsample_features(image, model_option): 90 | # Image preprocessing 91 | input_size = 224 92 | transform = T.Compose([ 93 | T.Resize(input_size), 94 | T.CenterCrop((input_size, input_size)), 95 | T.ToTensor(), 96 | norm 97 | ]) 98 | image_tensor = transform(image).unsqueeze(0).cuda() 99 | 100 | # Load the selected model 101 | upsampler = models[model_option].cuda() 102 | hr_feats = upsampler(image_tensor) 103 | lr_feats = upsampler.model(image_tensor) 104 | upsampler.cpu() 105 | 106 | return plot_feats(unnorm(image_tensor)[0], lr_feats[0], hr_feats[0]) 107 | 108 | 109 | demo = gr.Interface(fn=upsample_features, 110 | inputs=[image_input, model_option], 111 | outputs="plot", 112 | title="Feature Upsampling Demo", 113 | description="This demo allows you to upsample features of an image using selected models.", 114 | examples=[ 115 | ["/tmp/sample_images/skate.jpg", "dino16"], 116 | ["/tmp/sample_images/car.jpg", "dinov2"], 117 | ["/tmp/sample_images/plant.png", "dino16"], 118 | ] 119 | 120 | ) 121 | 122 | demo.launch(server_name="0.0.0.0", server_port=7860, debug=True) 123 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # hubconf.py 2 | import torch 3 | from featup.featurizers.util import get_featurizer 4 | from featup.layers import ChannelNorm 5 | from featup.upsamplers import get_upsampler 6 | from torch.nn import Module 7 | 8 | dependencies = ['torch', 'torchvision', 'PIL', 'featup'] # List any dependencies here 9 | 10 | 11 | class UpsampledBackbone(Module): 12 | 13 | def __init__(self, model_name, use_norm): 14 | super().__init__() 15 | model, patch_size, self.dim = get_featurizer(model_name, "token", num_classes=1000) 16 | if use_norm: 17 | self.model = torch.nn.Sequential(model, ChannelNorm(self.dim)) 18 | else: 19 | self.model = model 20 | self.upsampler = get_upsampler("jbu_stack", self.dim) 21 | 22 | def forward(self, image): 23 | return self.upsampler(self.model(image), image) 24 | 25 | 26 | def _load_backbone(pretrained, use_norm, model_name): 27 | """ 28 | The function that will be called by Torch Hub users to instantiate your model. 29 | Args: 30 | pretrained (bool): If True, returns a model pre-loaded with weights. 31 | Returns: 32 | An instance of your model. 33 | """ 34 | model = UpsampledBackbone(model_name, use_norm) 35 | if pretrained: 36 | # Define how you load your pretrained weights here 37 | # For example: 38 | if use_norm: 39 | exp_dir = "" 40 | else: 41 | exp_dir = "no_norm/" 42 | 43 | checkpoint_url = f"https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/pretrained/{exp_dir}{model_name}_jbu_stack_cocostuff.ckpt" 44 | state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["state_dict"] 45 | state_dict = {k: v for k, v in state_dict.items() if "scale_net" not in k and "downsampler" not in k} 46 | model.load_state_dict(state_dict, strict=False) 47 | return model 48 | 49 | 50 | def vit(pretrained=True, use_norm=True): 51 | return _load_backbone(pretrained, use_norm, "vit") 52 | 53 | 54 | def dino16(pretrained=True, use_norm=True): 55 | return _load_backbone(pretrained, use_norm, "dino16") 56 | 57 | 58 | def clip(pretrained=True, use_norm=True): 59 | return _load_backbone(pretrained, use_norm, "clip") 60 | 61 | 62 | def dinov2(pretrained=True, use_norm=True): 63 | return _load_backbone(pretrained, use_norm, "dinov2") 64 | 65 | 66 | def resnet50(pretrained=True, use_norm=True): 67 | return _load_backbone(pretrained, use_norm, "resnet50") 68 | 69 | def maskclip(pretrained=True, use_norm=True): 70 | assert not use_norm, "MaskCLIP only supports unnormed model" 71 | return _load_backbone(pretrained, use_norm, "maskclip") 72 | -------------------------------------------------------------------------------- /manifest.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | recursive-include featup/adaptive_conv_cuda *.cpp *.cu *.h 3 | recursive-include featup/configs *.yaml 4 | recursive-include featup/featurizers *.py 5 | recursive-include sample-images * 6 | -------------------------------------------------------------------------------- /sample-images/bird_full.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/sample-images/bird_full.jpg -------------------------------------------------------------------------------- /sample-images/bird_left.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/sample-images/bird_left.jpg -------------------------------------------------------------------------------- /sample-images/bird_right.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/sample-images/bird_right.jpg -------------------------------------------------------------------------------- /sample-images/cones2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/sample-images/cones2.jpg -------------------------------------------------------------------------------- /sample-images/cones3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/sample-images/cones3.jpg -------------------------------------------------------------------------------- /sample-images/dog_man_1_crop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/sample-images/dog_man_1_crop.jpg -------------------------------------------------------------------------------- /sample-images/plant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/sample-images/plant.png -------------------------------------------------------------------------------- /sample-images/skate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/sample-images/skate.jpg -------------------------------------------------------------------------------- /sample-images/teaser_wide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhamilton723/FeatUp/6b5a6c0e91f75e69194807128dcbc39c3084a30d/sample-images/teaser_wide.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension 3 | 4 | setup( 5 | name='featup', 6 | version='0.1.2', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'kornia', 11 | 'omegaconf', 12 | 'pytorch-lightning', 13 | 'torchvision', 14 | 'tqdm', 15 | 'torchmetrics', 16 | 'scikit-learn', 17 | 'numpy', 18 | 'matplotlib', 19 | 'timm==0.4.12', 20 | ], 21 | author='Mark Hamilton, Stephanie Fu', 22 | author_email='markth@mit.edu, fus@berkeley.edu', 23 | description='Official code for "FeatUp: A Model-Agnostic Frameworkfor Features at Any Resolution" ICLR 2024', 24 | long_description=open('README.md').read(), 25 | long_description_content_type='text/markdown', 26 | url='https://github.com/mhamilton723/FeatUp', 27 | classifiers=[ 28 | 'Programming Language :: Python :: 3', 29 | 'License :: OSI Approved :: MIT License', 30 | 'Operating System :: OS Independent', 31 | ], 32 | python_requires='>=3.6', 33 | ext_modules=[ 34 | CUDAExtension( 35 | 'adaptive_conv_cuda_impl', 36 | [ 37 | 'featup/adaptive_conv_cuda/adaptive_conv_cuda.cpp', 38 | 'featup/adaptive_conv_cuda/adaptive_conv_kernel.cu', 39 | ]), 40 | CppExtension( 41 | 'adaptive_conv_cpp_impl', 42 | ['featup/adaptive_conv_cuda/adaptive_conv.cpp'], 43 | undef_macros=["NDEBUG"]), 44 | ], 45 | cmdclass={ 46 | 'build_ext': BuildExtension 47 | } 48 | ) 49 | -------------------------------------------------------------------------------- /simple_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from featup.adaptive_conv_cuda.adaptive_conv import AdaptiveConv 3 | 4 | a = torch.zeros((1, 3, 10, 10)).cuda() 5 | b = torch.zeros((1, 8, 8, 3, 3)).cuda() 6 | 7 | AdaptiveConv.apply(a, b) --------------------------------------------------------------------------------