├── pic ├── read ├── gif │ ├── read │ ├── 1.gif │ ├── 2.gif │ ├── 3.gif │ ├── 4.gif │ ├── 5.gif │ ├── 6.gif │ ├── 7.gif │ ├── 8.gif │ ├── 9.gif │ ├── 10.gif │ ├── 11.gif │ ├── 12.gif │ ├── 13.gif │ ├── 14.gif │ ├── 15.gif │ ├── 16.gif │ ├── 17.gif │ ├── 18.gif │ ├── 19.gif │ ├── 20.gif │ ├── 21.gif │ ├── 22.gif │ ├── 23.gif │ ├── 24.gif │ └── 25.gif └── framework_10.png ├── configs ├── read └── sign_language.yml ├── CLIP-main ├── read ├── clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── simple_tokenizer.py │ └── clip.py ├── MANIFEST.in ├── requirements.txt ├── CLIP.png ├── setup.py ├── data │ ├── rendered-sst2.md │ ├── yfcc100m.md │ └── country211.md ├── LICENSE ├── hubconf.py ├── model-card.md └── README.md ├── models ├── fvd │ ├── __init__.py │ ├── convert_tf_pretrained.py │ └── fvd.py ├── better │ ├── op │ │ ├── __init__.py │ │ ├── upfirdn2d.cpp │ │ ├── fused_act.py │ │ └── upfirdn2d.py │ ├── __init__.py │ ├── utils.py │ ├── normalization.py │ ├── up_or_down_sampling.py │ └── layers3d.py ├── base_model.py ├── pndm.py ├── ema.py ├── mha_flash.py ├── eval_models.py ├── gaussianDf.py ├── pretrained_networks.py ├── networks_basic.py └── dist_model.py ├── runners └── __init__.py ├── requirements.txt ├── losses ├── main.py ├── TextToBert.py ├── __init__.py └── dsm.py ├── LICENSE ├── datasets ├── v2p.py ├── vision.py ├── sign_language.py └── utils.py ├── README.md ├── evaluation ├── pr.py ├── nearest_neighbor.py └── fid_score_OLD.py └── load_model_from_ckpt.py /pic/read: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /configs/read: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pic/gif/read: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /CLIP-main/read: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/fvd/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/better/op/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /CLIP-main/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /runners/__init__.py: -------------------------------------------------------------------------------- 1 | from runners.ncsn_runner import * 2 | -------------------------------------------------------------------------------- /CLIP-main/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include clip/bpe_simple_vocab_16e6.txt.gz 2 | -------------------------------------------------------------------------------- /CLIP-main/requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | torch 5 | torchvision 6 | -------------------------------------------------------------------------------- 
/pic/gif/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/1.gif -------------------------------------------------------------------------------- /pic/gif/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/2.gif -------------------------------------------------------------------------------- /pic/gif/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/3.gif -------------------------------------------------------------------------------- /pic/gif/4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/4.gif -------------------------------------------------------------------------------- /pic/gif/5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/5.gif -------------------------------------------------------------------------------- /pic/gif/6.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/6.gif -------------------------------------------------------------------------------- /pic/gif/7.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/7.gif -------------------------------------------------------------------------------- /pic/gif/8.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/8.gif -------------------------------------------------------------------------------- /pic/gif/9.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/9.gif -------------------------------------------------------------------------------- /pic/gif/10.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/10.gif -------------------------------------------------------------------------------- /pic/gif/11.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/11.gif -------------------------------------------------------------------------------- /pic/gif/12.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/12.gif -------------------------------------------------------------------------------- /pic/gif/13.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/13.gif -------------------------------------------------------------------------------- /pic/gif/14.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/14.gif -------------------------------------------------------------------------------- 
/pic/gif/15.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/15.gif -------------------------------------------------------------------------------- /pic/gif/16.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/16.gif -------------------------------------------------------------------------------- /pic/gif/17.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/17.gif -------------------------------------------------------------------------------- /pic/gif/18.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/18.gif -------------------------------------------------------------------------------- /pic/gif/19.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/19.gif -------------------------------------------------------------------------------- /pic/gif/20.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/20.gif -------------------------------------------------------------------------------- /pic/gif/21.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/21.gif -------------------------------------------------------------------------------- /pic/gif/22.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/22.gif -------------------------------------------------------------------------------- /pic/gif/23.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/23.gif -------------------------------------------------------------------------------- /pic/gif/24.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/24.gif -------------------------------------------------------------------------------- /pic/gif/25.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/gif/25.gif -------------------------------------------------------------------------------- /CLIP-main/CLIP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/CLIP-main/CLIP.png -------------------------------------------------------------------------------- /pic/framework_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/pic/framework_10.png -------------------------------------------------------------------------------- /CLIP-main/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mingtiannihao/SignGen/HEAD/CLIP-main/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 2 | numpy 3 | PyYAML 4 | imageio 5 | imageio-ffmpeg 6 | matplotlib 7 | opencv-python 8 | scikit-image 9 | tqdm 10 | h5py 11 | progressbar 12 | psutil 13 | ninja 14 | -------------------------------------------------------------------------------- /CLIP-main/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="clip", 8 | py_modules=["clip"], 9 | version="1.0", 10 | description="", 11 | author="OpenAI", 12 | packages=find_packages(exclude=["tests*"]), 13 | install_requires=[ 14 | str(r) 15 | for r in pkg_resources.parse_requirements( 16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 17 | ) 18 | ], 19 | include_package_data=True, 20 | extras_require={'dev': ['pytest']}, 21 | ) 22 | -------------------------------------------------------------------------------- /CLIP-main/data/rendered-sst2.md: -------------------------------------------------------------------------------- 1 | # The Rendered SST2 Dataset 2 | 3 | In the paper, we used an image classification dataset called Rendered SST2 to evaluate the model's capability on optical character recognition. To do so, we rendered the sentences in the [Stanford Sentiment Treebank v2](https://nlp.stanford.edu/sentiment/treebank.html) dataset and used those as the input to the CLIP image encoder. 4 | 5 | The following command will download a 131MB archive containing the images and extract it into a subdirectory `rendered-sst2`: 6 | 7 | ```bash 8 | wget https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz 9 | tar zxvf rendered-sst2.tgz 10 | ``` 11 | 12 | -------------------------------------------------------------------------------- /models/better/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /CLIP-main/data/yfcc100m.md: -------------------------------------------------------------------------------- 1 | # The YFCC100M Subset 2 | 3 | In the paper, we performed a dataset ablation using a subset of the YFCC100M dataset and showed that the performance remained largely similar. 4 | 5 | The subset contains 14,829,396 images, about 15% of the full dataset, which have been filtered to only keep those with natural language titles and/or descriptions in English.
6 | 7 | We provide the list of (line number, photo identifier, photo hash) for each image contained in this subset. These correspond to the first three columns in the dataset's metadata TSV file. 8 | 9 | ```bash 10 | wget https://openaipublic.azureedge.net/clip/data/yfcc100m_subset_data.tsv.bz2 11 | bunzip2 yfcc100m_subset_data.tsv.bz2 12 | ``` 13 | 14 | Use of the underlying media files is subject to the Creative Commons licenses chosen by their creators/uploaders. For more information about the YFCC100M dataset, visit [the official website](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/). -------------------------------------------------------------------------------- /CLIP-main/data/country211.md: -------------------------------------------------------------------------------- 1 | # The Country211 Dataset 2 | 3 | In the paper, we used an image classification dataset called Country211 to evaluate the model's capability on geolocation. To do so, we filtered the YFCC100M dataset to images that have GPS coordinates corresponding to an [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes) and created a balanced dataset by sampling 150 train images, 50 validation images, and 100 test images for each country. 4 | 5 | The following command will download an 11GB archive containing the images and extract it into a subdirectory `country211`: 6 | 7 | ```bash 8 | wget https://openaipublic.azureedge.net/clip/data/country211.tgz 9 | tar zxvf country211.tgz 10 | ``` 11 | 12 | These images are a subset of the YFCC100M dataset. Use of the underlying media files is subject to the Creative Commons licenses chosen by their creators/uploaders. For more information about the YFCC100M dataset, visit [the official website](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/).
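After extracting the archive, it can be useful to sanity-check the download. The snippet below is a minimal sketch and assumes the archive unpacks into `country211/{train,valid,test}/<ISO code>/` sub-folders; if the layout differs, adjust the paths accordingly.

```bash
# Rough sanity check of the extracted Country211 data:
# count the country sub-folders and image files in each split.
for split in train valid test; do
  echo "$split: $(ls country211/$split | wc -l) countries, $(find country211/$split -type f | wc -l) images"
done
```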
-------------------------------------------------------------------------------- /losses/main.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from losses.TextToBert import textToBert_dim96 5 | from pytorch_transformers import BertModel, BertConfig,BertTokenizer 6 | 7 | def textToBert_dim(text): 8 | tokenizer = BertTokenizer.from_pretrained('bert-base-cased/vocab.txt') 9 | # s = 'Our Deeds are the Reason of this #earthquake May ALLAH Forgive us all' 10 | tokens = tokenizer.tokenize(text) 11 | tokens = ["[CLS]"] + tokens + ["[SEP]"] 12 | ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)]) 13 | model = BertModel.from_pretrained('bert-base-cased') 14 | # pooled output of the [CLS] token 15 | all_layers_all_words, pooled = model(ids) 16 | # print( pooled) 17 | net = nn.Sequential( nn.Linear(768, 768),nn.ReLU(),nn.Linear(768, 96)) 18 | output = net(pooled) 19 | return output 20 | def transpos(text): 21 | 22 | cond = textToBert_dim(text) 23 | print(cond) 24 | return cond 25 | if __name__ == "__main__": 26 | text = "a dog" 27 | cond = transpos(text) 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /losses/TextToBert.py: -------------------------------------------------------------------------------- 1 | 2 | from pytorch_transformers import BertModel, BertConfig,BertTokenizer 3 | import torch 4 | import torch.nn as nn 5 | 6 | def textToBert_dim96(text): 7 | # tokenizer = BertTokenizer.from_pretrained('bert-base-cased/vocab.txt') 8 | tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-chinese') 9 | # s = 'Our Deeds are the Reason of this #earthquake May ALLAH Forgive us all' 10 | tokens = tokenizer.tokenize(text) 11 | print(tokens) 12 | tokens = ["[CLS]"] + tokens + ["[SEP]"] 13 | # print(tokens) 14 | 15 | ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)]) 16 | # print(ids.shape) 17 | 18 | model = BertModel.from_pretrained('bert-base-cased') 19 | # pooled can be read as the last-layer output of the first token ([CLS]) of each sentence 20 | all_layers_all_words, pooled = model(ids) 21 | # print( pooled) 22 | # print(pooled.shape) 23 | net = nn.Sequential( nn.Linear(768, 768),nn.ReLU(),nn.Linear(768, 96),nn.Tanh()) 24 | output = (net(pooled)) 25 | return output 26 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | 4 | def get_optimizer(config, parameters): 5 | if config.optim.optimizer == 'Adam': 6 | return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay, 7 | betas=(config.optim.beta1, 0.999), amsgrad=config.optim.amsgrad, 8 | eps=config.optim.eps) 9 | elif config.optim.optimizer == 'RMSProp': 10 | return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay) 11 | elif config.optim.optimizer == 'SGD': 12 | return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9) 13 | else: 14 | raise NotImplementedError('Optimizer {} not understood.'.format(config.optim.optimizer)) 15 | 16 | 17 | def warmup_lr(optimizer, step, warmup, max_lr): 18 | if step > warmup: 19 | return max_lr 20 | lr = max_lr * min(float(step) / max(warmup, 1), 1.0) 21 | for param_group in optimizer.param_groups: 22 | param_group["lr"] = lr 23 | return lr 24 | 25 | -------------------------------------------------------------------------------- /models/better/op/upfirdn2d.cpp:
-------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 mingtiannihao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /CLIP-main/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /datasets/v2p.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import os 4 | import cv2 5 | 6 | # Source path that holds the .mp4 files (the directory should contain only .mp4 files) 7 | videos_src_path = r'D:/DeepLearning_Projects/mcvd-pytorch-master/datasets/videos/1/' 8 | # Output path; a folder named after each .mp4 file is created here to hold the extracted frames 9 | videos_save_path = r'D:/DeepLearning_Projects/mcvd-pytorch-master/datasets/videos/1' 10 | 11 | videos = os.listdir(videos_src_path) 12 | # videos = filter(lambda x: x.endswith('MP4'), videos) 13 | 14 | for each_video in videos: 15 | print('Video Name :', each_video) 16 | # get the name of each video, and make the directory to save frames 17 | each_video_name, _ = each_video.split('.') 18 | os.mkdir(videos_save_path + '/' + each_video_name) 19 | 20 | each_video_save_full_path = os.path.join(videos_save_path, each_video_name) + '/' 21 | 22 | # get the full path of each video, which will open the video to extract frames 23 | each_video_full_path = os.path.join(videos_src_path, each_video) 24 | 25 | cap = cv2.VideoCapture(each_video_full_path) 26 | # current frame index 27 | frame_count = 1 28 | # keep one frame every `frame_rate` frames 29 | frame_rate = 1 30 | success = True 31 | # number of saved frames 32 | num = 0 33 | while (success): 34 | success, frame = cap.read() 35 | if success == True: 36 | 37 | if frame_count % frame_rate == 0: 38 | cv2.imwrite(each_video_save_full_path + each_video_name + "%06d.jpg" % num, frame) 39 | num += 1 40 | 41 | frame_count = frame_count + 1 42 | print('Final frame:', num) 43 | 44 | -------------------------------------------------------------------------------- /CLIP-main/hubconf.py: -------------------------------------------------------------------------------- 1 | from clip.clip import tokenize as _tokenize, load as _load, available_models as _available_models 2 | import re 3 | import string 4 | 5 | dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"] 6 | 7 | # For compatibility (cannot include special characters in function name) 8 | model_functions = { model: re.sub(f'[{string.punctuation}]', '_', model) for model in _available_models()} 9 | 10 | def _create_hub_entrypoint(model): 11 | def entrypoint(**kwargs): 12 | return _load(model, **kwargs) 13 | 14 | entrypoint.__doc__ = f"""Loads the {model} CLIP model 15 | 16 | Parameters 17 | ---------- 18 | device : Union[str, torch.device] 19 | The device to put the loaded model 20 | 21 | jit : bool 22 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
23 | 24 | download_root: str 25 | path to download the model files; by default, it uses "~/.cache/clip" 26 | 27 | Returns 28 | ------- 29 | model : torch.nn.Module 30 | The {model} CLIP model 31 | 32 | preprocess : Callable[[PIL.Image], torch.Tensor] 33 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 34 | """ 35 | return entrypoint 36 | 37 | def tokenize(): 38 | return _tokenize 39 | 40 | _entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()} 41 | 42 | globals().update(_entrypoints) -------------------------------------------------------------------------------- /datasets/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | def __getitem__(self, index): 15 | raise NotImplementedError 16 | 17 | def __len__(self): 18 | raise NotImplementedError 19 | 20 | def __repr__(self): 21 | head = "Dataset " + self.__class__.__name__ 22 | body = ["Number of datapoints: {}".format(self.__len__())] 23 | if self.root is not None: 24 | body.append("Root location: {}".format(self.root)) 25 | body += self.extra_repr().splitlines() 26 | if hasattr(self, 'transform') and self.transform is not None: 27 | body += self._format_transform_repr(self.transform, 28 | "Transforms: ") 29 | if hasattr(self, 'target_transform') and self.target_transform is not None: 30 | body += self._format_transform_repr(self.target_transform, 31 | "Target transforms: ") 32 | lines = [head] + [" " * self._repr_indent + line for line in body] 33 | return '\n'.join(lines) 34 | 35 | def _format_transform_repr(self, transform, head): 36 | lines = transform.__repr__().splitlines() 37 | return (["{}{}".format(head, lines[0])] + 38 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 39 | 40 | def extra_repr(self): 41 | return "" 42 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | class BaseModel(): 6 | def __init__(self): 7 | pass; 8 | 9 | def name(self): 10 | return 'BaseModel' 11 | 12 | def initialize(self, use_gpu=True, gpu_ids=[0]): 13 | self.use_gpu = use_gpu 14 | self.gpu_ids = gpu_ids 15 | 16 | def forward(self): 17 | pass 18 | 19 | def get_image_paths(self): 20 | pass 21 | 22 | def optimize_parameters(self): 23 | pass 24 | 25 | def get_current_visuals(self): 26 | return self.input 27 | 28 | def get_current_errors(self): 29 | return {} 30 | 31 | def save(self, label): 32 | pass 33 | 34 | # helper saving function that can be used by subclasses 35 | def save_network(self, network, path, network_label, epoch_label): 36 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 37 | save_path = os.path.join(path, save_filename) 38 | torch.save(network.state_dict(), save_path) 39 | 40 | # helper loading function that can be used by subclasses 41 | def load_network(self, network, network_label, epoch_label): 42 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 43 | save_path = os.path.join(self.save_dir, save_filename) 44 | print('Loading network from %s'%save_path) 45 | 
network.load_state_dict(torch.load(save_path)) 46 | 47 | def update_learning_rate(): 48 | pass 49 | 50 | def get_image_paths(self): 51 | return self.image_paths 52 | 53 | def save_done(self, flag=False): 54 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 55 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') -------------------------------------------------------------------------------- /models/pndm.py: -------------------------------------------------------------------------------- 1 | ## Modified from https://github.com/luping-liu/PNDM/blob/f285e8e6da36049ea29e97b741fb71e531505ec8/runner/method.py#L20 2 | 3 | def runge_kutta(x, t_list, model, alphas_cump, ets, clip_before=False): 4 | e_1 = model(x, t_list[0]) 5 | ets.append(e_1) 6 | x_2 = transfer(x, t_list[0], t_list[1], e_1, alphas_cump, clip_before) 7 | 8 | e_2 = model(x_2, t_list[1]) 9 | x_3 = transfer(x, t_list[0], t_list[1], e_2, alphas_cump, clip_before) 10 | 11 | e_3 = model(x_3, t_list[1]) 12 | x_4 = transfer(x, t_list[0], t_list[2], e_3, alphas_cump, clip_before) 13 | 14 | e_4 = model(x_4, t_list[2]) 15 | et = (1 / 6) * (e_1 + 2 * e_2 + 2 * e_3 + e_4) 16 | 17 | return et, ets 18 | 19 | def transfer(x, t, t_next, et, alphas_cump, clip_before=False): 20 | at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1) 21 | at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1) 22 | 23 | # x0 = (1 / c_alpha.sqrt()) * (x_mod - (1 - c_alpha).sqrt() * grad) 24 | # x_mod = c_alpha_prev.sqrt() * x0 + (1 - c_alpha_prev).sqrt() * grad 25 | 26 | x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - \ 27 | 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et) 28 | 29 | x_next = x + x_delta 30 | if clip_before: 31 | x_next = x_next.clip_(-1, 1) 32 | 33 | return x_next 34 | 35 | def gen_order_1(img, t, t_next, model, alphas_cump, ets, clip_before=False): ## DDIM 36 | noise = model(img, t) 37 | ets.append(noise) 38 | img_next = transfer(img, t, t_next, noise, alphas_cump, clip_before) 39 | return img_next, ets 40 | 41 | def gen_order_4(img, t, t_next, model, alphas_cump, ets, clip_before=False): ## F-PNDM 42 | t_list = [t, (t+t_next)/2, t_next] 43 | #print(t_list) 44 | if len(ets) > 2: 45 | noise_ = model(img, t) 46 | ets.append(noise_) 47 | noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) 48 | else: 49 | noise, ets = runge_kutta(img, t_list, model, alphas_cump, ets, clip_before) 50 | 51 | img_next = transfer(img, t, t_next, noise, alphas_cump, clip_before) 52 | return img_next, ets 53 | -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch.nn as nn 3 | 4 | class EMAHelper(object): 5 | def __init__(self, mu=0.999): 6 | self.mu = mu 7 | self.shadow = {} 8 | 9 | def register(self, module): 10 | if isinstance(module, nn.DataParallel): 11 | module = module.module 12 | for name, param in module.named_parameters(): 13 | if param.requires_grad: 14 | self.shadow[name] = param.data.clone() 15 | 16 | def update(self, module): 17 | if isinstance(module, nn.DataParallel): 18 | module = module.module 19 | for name, param in module.named_parameters(): 20 | if param.requires_grad: 21 | self.shadow[name].data = (1. 
- self.mu) * param.data + self.mu * self.shadow[name].data 22 | 23 | def ema(self, module): 24 | if isinstance(module, nn.DataParallel): 25 | module = module.module 26 | for name, param in module.named_parameters(): 27 | if param.requires_grad: 28 | param.data.copy_(self.shadow[name].data) 29 | 30 | def ema_copy(self, module): 31 | if isinstance(module, nn.DataParallel): 32 | inner_module = module.module 33 | module_copy = type(inner_module)(inner_module.config).to(inner_module.config.device) 34 | module_copy.load_state_dict(inner_module.state_dict()) 35 | module_copy = nn.DataParallel(module_copy) 36 | else: 37 | module_copy = type(module)(module.config).to(module.config.device) 38 | module_copy.load_state_dict(module.state_dict()) 39 | # module_copy = copy.deepcopy(module) 40 | self.ema(module_copy) 41 | return module_copy 42 | 43 | def state_dict(self): 44 | return self.shadow 45 | 46 | def load_state_dict(self, state_dict): 47 | self.shadow = state_dict 48 | 49 | 50 | # import glob, torch, tqdm 51 | # ckpt_files = sorted(glob.glob("*.pt")) 52 | # for file in tqdm.tqdm(ckpt_files): 53 | # a = torch.load(file) 54 | # a[0]['module.unet.all_modules.52.Norm_0.weight'] = a[0].pop('module.unet.all_modules.52.weight') 55 | # a[0]['module.unet.all_modules.52.Norm_0.bias'] = a[0].pop('module.unet.all_modules.52.bias') 56 | # a[-1]['unet.all_modules.52.Norm_0.weight'] = a[-1].pop('unet.all_modules.52.weight') 57 | # a[-1]['unet.all_modules.52.Norm_0.bias'] = a[-1].pop('unet.all_modules.52.bias') 58 | # torch.save(a, file) 59 | -------------------------------------------------------------------------------- /losses/dsm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from functools import partial 4 | from torch.distributions.gamma import Gamma 5 | 6 | 7 | def anneal_dsm_score_estimation(scorenet, x, labels=None, loss_type='a', hook=None, cond=None, cond_mask=None, 8 | gamma=False, L1=False, all_frames=False, texts=None,deepth_video=None,motion = None): 9 | net = scorenet.module if hasattr(scorenet, 'module') else scorenet 10 | version = getattr(net, 'version', 'SMLD').upper() 11 | net_type = getattr(net, 'type') if isinstance(getattr(net, 'type'), str) else 'v1' 12 | 13 | if all_frames: 14 | x = torch.cat([x, cond], dim=1) 15 | cond = None 16 | 17 | # z, perturbed_x 18 | if version == "SMLD": 19 | sigmas = net.sigmas 20 | if labels is None: 21 | labels = torch.randint(0, len(sigmas), (x.shape[0],), device=x.device) 22 | used_sigmas = sigmas[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))) 23 | z = torch.randn_like(x) 24 | perturbed_x = x + used_sigmas * z 25 | elif version == "DDPM" or version == "DDIM" or version == "FPNDM": 26 | alphas = net.alphas 27 | if labels is None: 28 | labels = torch.randint(0, len(alphas), (x.shape[0],), device=x.device) 29 | used_alphas = alphas[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))) 30 | if gamma: 31 | used_k = net.k_cum[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))).repeat(1, *x.shape[1:]) 32 | used_theta = net.theta_t[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))).repeat(1, *x.shape[1:]) 33 | z = Gamma(used_k, 1 / used_theta).sample() 34 | z = (z - used_k * used_theta) / (1 - used_alphas).sqrt() 35 | else: 36 | 37 | z = torch.randn_like(x) 38 | used_alphas = alphas[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))) 39 | 40 | perturbed_x = used_alphas.sqrt() * x + (1 - used_alphas).sqrt() * z 41 | scorenet = partial(scorenet, cond=cond) 42 | 43 | # Loss 44 | if L1: 
45 | def pow_(x): 46 | return x.abs() 47 | else: 48 | def pow_(x): 49 | return 1 / 2. * x.square() 50 | 51 | loss = pow_((z - scorenet(perturbed_x, labels, cond_mask=cond_mask,texts=texts,deepth_video=deepth_video,motion=motion)).reshape(len(x), -1)).sum(dim=-1) 52 | 53 | if hook is not None: 54 | hook(loss, labels) 55 | 56 | return loss.mean(dim=0) 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# SignGen: End-to-End Sign Language Video Generation with Latent Diffusion
2 | 3 | 4 | 5 | 6 | ## Method 7 | 8 | 9 | ![framework](pic/framework_10.png "framework") 10 | 11 | 12 | 13 | 14 | 15 | ## Experiment Results 16 | 17 | #### RWTH-2014 18 | 19 | ![case1](pic/gif/1.gif "case1")![case2](pic/gif/2.gif "case2")![case3](pic/gif/3.gif "case3") 20 | ![case4](pic/gif/4.gif "case4")![case4](pic/gif/5.gif "case4")![case4](pic/gif/6.gif "case4") 21 | ![case4](pic/gif/10.gif "case4")![case4](pic/gif/11.gif "case4")![case4](pic/gif/12.gif "case4") 22 | ![case4](pic/gif/19.gif "case4")![case4](pic/gif/22.gif "case4")![case4](pic/gif/21.gif "case4") 23 | 24 | #### RWTH-2014T 25 | 26 | ![case1](pic/gif/7.gif "case1")![case2](pic/gif/8.gif "case2")![case3](pic/gif/9.gif "case3") 27 | ![case4](pic/gif/13.gif "case4")![case4](pic/gif/14.gif "case4")![case4](pic/gif/15.gif "case4") 28 | ![case4](pic/gif/16.gif "case4")![case4](pic/gif/17.gif "case4")![case4](pic/gif/23.gif "case4") 29 | 30 | #### AUTSL 31 | ![case4](pic/gif/24.gif "case4") 32 | 33 | ![case4](pic/gif/25.gif "case4") 34 | 35 | ## Running by Yourself 36 | 37 | ### 1. Installation 38 | 39 | Create a conda environment: 40 | ``` 41 | conda create -n xxx python==3.8.5 42 | ``` 43 | 44 | Then install the requirements so that your environment matches ours: 45 | ``` 46 | pip install -r requirements.txt # install all requirements 47 | ``` 48 | 49 | ### 2. Download model weights 50 | 51 | #### For LPIPS 52 | 53 | The code will do it for you! 54 | > Code will download [Alex](https://download.pytorch.org/models/alexnet-owt-7be5be79.pth) and move it into: `models/weights/v0.1/alex.pth` 55 | 56 | #### For FVD 57 | 58 | The code will do it for you! 59 | 60 | > Code will download the i3D model pretrained on [Kinetics-400](https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI) 61 | > Use `models/fvd/convert_tf_pretrained.py` to produce `i3d_pretrained_400.pt` (see the example command at the end of this README) 62 | 63 | ### 3. Datasets 64 | 65 | You can download the datasets, such as [RWTH-2014](https://www-i6.informatik.rwth-aachen.de/~koller/RWTH-PHOENIX/), [RWTH-2014T](https://www-i6.informatik.rwth-aachen.de/~koller/RWTH-PHOENIX-2014-T/) and [AUTSL](https://chalearnlap.cvc.uab.cat/dataset/40/data/66/description/). 66 | 67 | > **How the data was processed:** 68 | > 1. Download the AUTSL dataset to `/path/to/AUTSL`.\ 69 | > 2. Convert the 128x128 images to HDF5 format:\ 70 | > `python datasets/sign_language_convert.py --sl_dir 'datasets/videos' --split 'train' --out_dir 'datasets/signLanguages/train' --image_size 128 --force_h5 False` 71 | 72 | ### Training (code coming soon)
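#### Example: converting the FVD weights

`models/fvd/convert_tf_pretrained.py` (referenced in step 2 above) reads the number of Kinetics classes from its first command-line argument and writes the converted checkpoint to `fvd/i3d_pretrained_<classes>.pt`. A minimal invocation might look like the sketch below; it assumes TensorFlow and `tensorflow_hub` are available (they are only needed for this one-off conversion). Note that the script imports `InceptionI3d` from a `src_pytorch` package, so the import path may need adjusting to this repository's layout.

```bash
# One-off conversion of the TF-Hub i3D checkpoint to PyTorch (400-class Kinetics variant).
pip install tensorflow tensorflow_hub   # only required for this conversion step
python models/fvd/convert_tf_pretrained.py 400
# The converted weights are written to fvd/i3d_pretrained_400.pt;
# place the file where the FVD evaluation code expects it.
```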
73 | 74 | 75 | -------------------------------------------------------------------------------- /evaluation/pr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def calc_cdist(feat1, feat2, batch_size=10000): 4 | dists = [] 5 | for feat2_batch in feat2.split(batch_size): 6 | dists.append(torch.cdist(feat1, feat2_batch).cpu()) 7 | return torch.cat(dists, dim=1) 8 | 9 | 10 | def calculate_precision_recall_part(feat_r, feat_g, k=3, batch_size=10000): 11 | # Precision 12 | NNk_r = [] 13 | for feat_r_batch in feat_r.split(batch_size): 14 | NNk_r.append(calc_cdist(feat_r_batch, feat_r, batch_size).kthvalue(k+1).values) 15 | NNk_r = torch.cat(NNk_r) 16 | precision = [] 17 | for feat_g_batch in feat_g.split(batch_size): 18 | dist_g_r_batch = calc_cdist(feat_g_batch, feat_r, batch_size) 19 | precision.append((dist_g_r_batch <= NNk_r).any(dim=1).float()) 20 | precision = torch.cat(precision).mean().item() 21 | # Recall 22 | NNk_g = [] 23 | for feat_g_batch in feat_g.split(batch_size): 24 | NNk_g.append(calc_cdist(feat_g_batch, feat_g, batch_size).kthvalue(k+1).values) 25 | NNk_g = torch.cat(NNk_g) 26 | recall = [] 27 | for feat_r_batch in feat_r.split(batch_size): 28 | dist_r_g_batch = calc_cdist(feat_r_batch, feat_g, batch_size) 29 | recall.append((dist_r_g_batch <= NNk_g).any(dim=1).float()) 30 | recall = torch.cat(recall).mean().item() 31 | return precision, recall 32 | 33 | 34 | def calc_cdist_full(feat1, feat2, batch_size=10000): 35 | dists = [] 36 | for feat1_batch in feat1.split(batch_size): 37 | dists_batch = [] 38 | for feat2_batch in feat2.split(batch_size): 39 | dists_batch.append(torch.cdist(feat1_batch, feat2_batch).cpu()) 40 | dists.append(torch.cat(dists_batch, dim=1)) 41 | return torch.cat(dists, dim=0) 42 | 43 | 44 | def calculate_precision_recall_full(feat_r, feat_g, k=3, batch_size=10000): 45 | NNk_r = calc_cdist_full(feat_r, feat_r, batch_size).kthvalue(k+1).values 46 | NNk_g = calc_cdist_full(feat_g, feat_g, batch_size).kthvalue(k+1).values 47 | dist_g_r = calc_cdist_full(feat_g, feat_r, batch_size) 48 | dist_r_g = dist_g_r.T 49 | # Precision 50 | precision = (dist_g_r <= NNk_r).any(dim=1).float().mean().item() 51 | # Recall 52 | recall = (dist_r_g <= NNk_g).any(dim=1).float().mean().item() 53 | return precision, recall 54 | 55 | 56 | def calculate_precision_recall(feat_r, feat_g, device=torch.device('cuda'), k=3, 57 | batch_size=10000, save_cpu_ram=False, **kwargs): 58 | feat_r = feat_r.to(device) 59 | feat_g = feat_g.to(device) 60 | if save_cpu_ram: 61 | return calculate_precision_recall_part(feat_r, feat_g, k, batch_size) 62 | else: 63 | return calculate_precision_recall_full(feat_r, feat_g, k, batch_size) 64 | 65 | -------------------------------------------------------------------------------- /configs/sign_language.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 6 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 20000 9 | val_freq: 50000 10 | log_freq: 100 11 | log_all_sigmas: true 12 | 13 | sampling: 14 | batch_size: 2 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 100 23 | num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 #15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 100 31 | train: 
false 32 | num_frames_pred: 10 33 | clip_before: true 34 | max_data_iter: 100 35 | init_prev_t: -1 # -1 if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 36 | one_frame_at_a_time: false 37 | preds_per_test: 2 38 | 39 | fast_fid: 40 | batch_size: 100 41 | num_samples: 100 42 | begin_ckpt: 0 43 | freq: 100 44 | end_ckpt: 300000 45 | pr_nn_k: 3 46 | verbose: false 47 | ensemble: true 48 | step_lr: 0.0 49 | n_steps_each: 100 50 | 51 | test: 52 | begin_ckpt: 100 53 | end_ckpt: 300000 54 | batch_size: 6 55 | 56 | data: 57 | dataset: "SL" 58 | image_size: 128 59 | channels: 3 60 | logit_transform: false 61 | uniform_dequantization: false 62 | gaussian_dequantization: false 63 | random_flip: true 64 | rescaled: true 65 | color_jitter: 0.0 66 | test_subset: -1 67 | num_workers: 6 68 | num_frames: 5 69 | num_frames_cond: 1 70 | num_frames_future: 0 71 | prob_mask_cond: 0.0 #mask 72 | prob_mask_future: 0.0 73 | prob_mask_sync: false 74 | 75 | model: 76 | depth: deeper 77 | version: DDPM 78 | gamma: false 79 | arch: unetmore 80 | type: v1 81 | time_conditional: true 82 | dropout: 0.1 83 | sigma_dist: linear 84 | sigma_begin: 0.02 85 | sigma_end: 0.0001 86 | num_classes: 1000 87 | ema: true 88 | ema_rate: 0.999 89 | spec_norm: false 90 | normalization: InstanceNorm++ 91 | nonlinearity: swish 92 | ngf: 192 93 | ch_mult: 94 | - 1 95 | - 2 96 | - 3 97 | - 4 98 | num_res_blocks: 8 # 8 for traditional 99 | attn_resolutions: 100 | - 8 101 | - 16 # can use only 16 for traditional 102 | n_head_channels: 192 # -1 for traditional #96 103 | conditional: true 104 | noise_in_cond: false 105 | output_all_frames: false # could be useful especially for 3d models 106 | cond_emb: false 107 | sd_checkpoint: None 108 | spade: false 109 | spade_dim: 256 110 | use_motion: true 111 | use_depthmap: true 112 | use_pose: true 113 | concat_dim: 4 114 | context_dim: 512 115 | adapter_transformer_layers: 1 116 | 117 | optim: 118 | weight_decay: 0.001 119 | optimizer: "Adam" 120 | lr: 0.0001 121 | warmup: 5000 122 | beta1: 0.9 123 | amsgrad: false 124 | eps: 0.00000001 125 | grad_clip: 1.0 126 | -------------------------------------------------------------------------------- /models/better/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.system("unset TORCH_CUDA_ARCH_LIST") 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.autograd import Function 8 | from torch.utils.cpp_extension import load 9 | 10 | 11 | module_path = os.path.dirname(__file__) 12 | 13 | 14 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 15 | if input.device.type == "cpu": 16 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 17 | return ( 18 | F.leaky_relu( 19 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 20 | ) 21 | * scale 22 | ) 23 | 24 | else: 25 | 26 | fused = load( 27 | "fused", 28 | sources=[ 29 | os.path.join(module_path, "fused_bias_act.cpp"), 30 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 31 | ], 32 | ) 33 | 34 | class FusedLeakyReLUFunctionBackward(Function): 35 | 36 | @staticmethod 37 | def forward(ctx, grad_output, out, negative_slope, scale): 38 | ctx.save_for_backward(out) 39 | ctx.negative_slope = negative_slope 40 | ctx.scale = scale 41 | 42 | empty = grad_output.new_empty(0) 43 | 44 | grad_input = fused.fused_bias_act( 45 | grad_output, empty, out, 3, 1, negative_slope, scale 46 | ) 47 | 48 | dim = [0] 49 | 50 | if grad_input.ndim > 2: 51 | 
dim += list(range(2, grad_input.ndim)) 52 | 53 | grad_bias = grad_input.sum(dim).detach() 54 | 55 | return grad_input, grad_bias 56 | 57 | @staticmethod 58 | def backward(ctx, gradgrad_input, gradgrad_bias): 59 | out, = ctx.saved_tensors 60 | gradgrad_out = fused.fused_bias_act( 61 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 62 | ) 63 | 64 | return gradgrad_out, None, None, None 65 | 66 | 67 | class FusedLeakyReLUFunction(Function): 68 | 69 | @staticmethod 70 | def forward(ctx, input, bias, negative_slope, scale): 71 | empty = input.new_empty(0) 72 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 73 | ctx.save_for_backward(out) 74 | ctx.negative_slope = negative_slope 75 | ctx.scale = scale 76 | 77 | return out 78 | 79 | @staticmethod 80 | def backward(ctx, grad_output): 81 | out, = ctx.saved_tensors 82 | 83 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 84 | grad_output, out, ctx.negative_slope, ctx.scale 85 | ) 86 | 87 | return grad_input, grad_bias, None, None 88 | 89 | 90 | class FusedLeakyReLU(nn.Module): 91 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 92 | super().__init__() 93 | 94 | self.bias = nn.Parameter(torch.zeros(channel)) 95 | self.negative_slope = negative_slope 96 | self.scale = scale 97 | 98 | def forward(self, input): 99 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 100 | 101 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 102 | -------------------------------------------------------------------------------- /models/mha_flash.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.cuda.amp as amp 4 | import torch.nn.functional as F 5 | import math 6 | import os 7 | import time 8 | import numpy as np 9 | import random 10 | 11 | from flash_attn.flash_attention import FlashAttention 12 | 13 | class FlashAttentionBlock(nn.Module): 14 | 15 | def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4): 16 | # consider head_dim first, then num_heads 17 | num_heads = dim // head_dim if head_dim else num_heads 18 | head_dim = dim // num_heads 19 | assert num_heads * head_dim == dim 20 | super(FlashAttentionBlock, self).__init__() 21 | self.dim = dim 22 | self.context_dim = context_dim 23 | self.num_heads = num_heads 24 | self.head_dim = head_dim 25 | self.scale = math.pow(head_dim, -0.25) 26 | 27 | # layers 28 | self.norm = nn.GroupNorm(32, dim) 29 | self.to_qkv = nn.Conv2d(dim, dim * 3, 1) 30 | if context_dim is not None: 31 | self.context_kv = nn.Linear(context_dim, dim * 2) 32 | self.proj = nn.Conv2d(dim, dim, 1) 33 | 34 | if self.head_dim <= 128 and (self.head_dim % 8) == 0: 35 | new_scale = math.pow(head_dim, -0.5) 36 | self.flash_attn = FlashAttention(softmax_scale=None, attention_dropout=0.0) 37 | 38 | # zero out the last layer params 39 | nn.init.zeros_(self.proj.weight) 40 | # self.apply(self._init_weight) 41 | 42 | 43 | def _init_weight(self, module): 44 | if isinstance(module, nn.Linear): 45 | module.weight.data.normal_(mean=0.0, std=0.15) 46 | if module.bias is not None: 47 | module.bias.data.zero_() 48 | elif isinstance(module, nn.Conv2d): 49 | module.weight.data.normal_(mean=0.0, std=0.15) 50 | if module.bias is not None: 51 | module.bias.data.zero_() 52 | 53 | def forward(self, x, context=None): 54 | r"""x: [B, C, H, W]. 55 | context: [B, L, C] or None. 
56 | """ 57 | identity = x 58 | b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim 59 | 60 | # compute query, key, value 61 | x = self.norm(x) 62 | q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) 63 | if context is not None: 64 | ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1) 65 | k = torch.cat([ck, k], dim=-1) 66 | v = torch.cat([cv, v], dim=-1) 67 | cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device) 68 | q = torch.cat([q, cq], dim=-1) 69 | 70 | qkv = torch.cat([q,k,v], dim=1) 71 | origin_dtype = qkv.dtype 72 | qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous() 73 | out, _ = self.flash_attn(qkv) 74 | out.to(origin_dtype) 75 | 76 | if context is not None: 77 | out = out[:, :-4, :, :] 78 | out = out.permute(0, 2, 3, 1).reshape(b, c, h, w) 79 | 80 | # output 81 | x = self.proj(out) 82 | return x + identity 83 | 84 | if __name__ == '__main__': 85 | batch_size = 8 86 | flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64, batch_size=batch_size).cuda() 87 | 88 | x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda() 89 | context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda() 90 | # context = None 91 | flash_net.eval() 92 | 93 | with amp.autocast(enabled=True): 94 | # warm up 95 | for i in range(5): 96 | y = flash_net(x, context) 97 | torch.cuda.synchronize() 98 | s1 = time.time() 99 | for i in range(10): 100 | y = flash_net(x, context) 101 | torch.cuda.synchronize() 102 | s2 = time.time() 103 | 104 | print(f'Average cost time {(s2-s1)*1000/10} ms') -------------------------------------------------------------------------------- /models/fvd/convert_tf_pretrained.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | import tensorflow_hub as hub 4 | import torch 5 | 6 | from src_pytorch.fvd.pytorch_i3d import InceptionI3d 7 | 8 | 9 | def convert_name(name): 10 | mapping = { 11 | 'conv_3d': 'conv3d', 12 | 'batch_norm': 'bn', 13 | 'w:0': 'weight', 14 | 'b:0': 'bias', 15 | 'moving_mean:0': 'running_mean', 16 | 'moving_variance:0': 'running_var', 17 | 'beta:0': 'bias' 18 | } 19 | 20 | segs = name.split('/') 21 | new_segs = [] 22 | i = 0 23 | while i < len(segs): 24 | seg = segs[i] 25 | if 'Mixed' in seg: 26 | new_segs.append(seg) 27 | elif 'Conv' in seg and 'Mixed' not in name: 28 | new_segs.append(seg) 29 | elif 'Branch' in seg: 30 | branch_i = int(seg.split('_')[-1]) 31 | i += 1 32 | seg = segs[i] 33 | 34 | # special case due to typo in original code 35 | if 'Mixed_5b' in name and branch_i == 2: 36 | if '1x1' in seg: 37 | new_segs.append(f'b{branch_i}a') 38 | elif '3x3' in seg: 39 | new_segs.append(f'b{branch_i}b') 40 | else: 41 | raise Exception() 42 | # Either Conv3d_{i}a_... or Conv3d_{i}b_... 
43 | elif 'a' in seg: 44 | if branch_i == 0: 45 | new_segs.append('b0') 46 | else: 47 | new_segs.append(f'b{branch_i}a') 48 | elif 'b' in seg: 49 | new_segs.append(f'b{branch_i}b') 50 | else: 51 | raise Exception 52 | elif seg == 'Logits': 53 | new_segs.append('logits') 54 | i += 1 55 | elif seg in mapping: 56 | new_segs.append(mapping[seg]) 57 | else: 58 | raise Exception(f"No match found for seg {seg} in name {name}") 59 | 60 | i += 1 61 | return '.'.join(new_segs) 62 | 63 | def convert_tensor(tensor): 64 | tensor_dim = len(tensor.shape) 65 | if tensor_dim == 5: # conv or bn 66 | if all([t == 1 for t in tensor.shape[:-1]]): 67 | tensor = tensor.squeeze() 68 | else: 69 | tensor = tensor.permute(4, 3, 0, 1, 2).contiguous() 70 | elif tensor_dim == 1: # conv bias 71 | pass 72 | else: 73 | raise Exception(f"Invalid shape {tensor.shape}") 74 | return tensor 75 | 76 | n_class = int(sys.argv[1]) # 600 or 400 77 | assert n_class in [400, 600] 78 | 79 | # Converts model from https://github.com/google-research/google-research/tree/master/frechet_video_distance 80 | # to pytorch version for loading 81 | model_url = f"https://tfhub.dev/deepmind/i3d-kinetics-{n_class}/1" 82 | i3d = hub.load(model_url) 83 | name_prefix = 'RGB/inception_i3d/' 84 | 85 | print('Creating state_dict...') 86 | all_names = [] 87 | state_dict = OrderedDict() 88 | for var in i3d.variables: 89 | name = var.name[len(name_prefix):] 90 | new_name = convert_name(name) 91 | all_names.append(new_name) 92 | 93 | tensor = torch.FloatTensor(var.value().numpy()) 94 | new_tensor = convert_tensor(tensor) 95 | 96 | state_dict[new_name] = new_tensor 97 | 98 | if 'bn.bias' in new_name: 99 | new_name = new_name[:-4] + 'weight' # bn.weight 100 | new_tensor = torch.ones_like(new_tensor).float() 101 | state_dict[new_name] = new_tensor 102 | 103 | print(f'Complete state_dict with {len(state_dict)} entries') 104 | 105 | s = dict() 106 | for i, n in enumerate(all_names): 107 | s[n] = s.get(n, []) + [i] 108 | 109 | for k, v in s.items(): 110 | if len(v) > 1: 111 | print('dup', k) 112 | for i in v: 113 | print('\t', i3d.variables[i].name) 114 | 115 | print('Testing load_state_dict...') 116 | print('Creating model...') 117 | 118 | i3d = InceptionI3d(n_class, in_channels=3) 119 | 120 | print('Loading state_dict...') 121 | i3d.load_state_dict(state_dict) 122 | 123 | print(f'Saving state_dict as fvd/i3d_pretrained_{n_class}.pt') 124 | torch.save(state_dict, f'fvd/i3d_pretrained_{n_class}.pt') 125 | 126 | print('Done') 127 | 128 | -------------------------------------------------------------------------------- /datasets/sign_language.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | import os 5 | import pickle 6 | import torch 7 | 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms 11 | from tqdm import tqdm 12 | 13 | from .h5 import HDF5Dataset 14 | 15 | 16 | class SignLanguage(Dataset): 17 | 18 | def __init__(self, data_path, frames_per_sample=5, train=True, random_time=True, random_horizontal_flip=True,color_jitter=0, 19 | total_videos=-1, skip_videos=0,image_size=64): 20 | 21 | self.data_path = data_path # '/path/to/Datasets/SignLanguage_h5' (with .hdf5 file in it), or to the hdf5 file itself 22 | self.train = train 23 | self.frames_per_sample = frames_per_sample 24 | self.image_size = image_size 25 | self.random_time = random_time 26 | self.color_jitter = color_jitter 27 | self.random_horizontal_flip = random_horizontal_flip 28 | 
self.total_videos = total_videos # If we wish to restrict total number of videos (e.g. for val) 29 | self.jitter = transforms.ColorJitter(hue=color_jitter) #改变亮度 30 | # Read h5 files as dataset 31 | self.videos_ds = HDF5Dataset(self.data_path) 32 | 33 | # Train 34 | # Read h5 files as dataset 35 | # self.videos_ds = HDF5Dataset(self.data_path) 36 | 37 | print(f"Dataset length: {self.__len__()}") 38 | 39 | def len_of_vid(self, index): 40 | video_index = index % self.__len__() 41 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 42 | print(shard_idx,' ',idx_in_shard,' ') 43 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 44 | video_len = f['len'][str(idx_in_shard)][()] 45 | return video_len 46 | 47 | def __len__(self): 48 | return self.total_videos if self.total_videos > 0 else len(self.videos_ds) 49 | 50 | def max_index(self): 51 | return len(self.videos_ds) 52 | 53 | def __getitem__(self, index, time_idx=0): 54 | 55 | # Use `index` to select the video, and then 56 | # randomly choose a `frames_per_sample` window of frames in the video 57 | video_index = round(index / (self.__len__() - 1) * (self.max_index() - 1)) 58 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 59 | if int(idx_in_shard) >= self.__len__(): idx_in_shard = str(int(idx_in_shard)) 60 | prefinals = [] 61 | prefinals_depth = [] 62 | prefinals_pose = [] 63 | prefinals_motion = [] 64 | flip_p = np.random.randint(2) == 0 if self.random_horizontal_flip else 0 65 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 66 | 67 | label = f['text'][str(idx_in_shard)][()] 68 | video_len = f['len'][str(idx_in_shard)][()] 69 | if self.random_time and video_len > self.frames_per_sample: 70 | time_idx = np.random.choice(video_len - self.frames_per_sample) 71 | 72 | for i in range(time_idx, min(time_idx + self.frames_per_sample, video_len)): 73 | img = f['video'][str(idx_in_shard)][str(i)][()] 74 | arr = transforms.RandomHorizontalFlip(flip_p)(transforms.ToTensor()(img)) 75 | prefinals.append(arr) 76 | pose = f['pose_video'][str(idx_in_shard)][str(i)][()] 77 | prefinals_pose.append(torch.from_numpy(pose)) 78 | depth = f['depth'][str(idx_in_shard)][str(i)][()] 79 | prefinals_depth.append(torch.from_numpy(depth)) 80 | motion_img = f['motion'][str(idx_in_shard)][str(i)][()] 81 | prefinals_motion.append(torch.from_numpy(motion_img)) 82 | 83 | 84 | data = torch.stack(prefinals) 85 | depth_data = torch.stack(prefinals_depth) 86 | # pose = torch.permute(deepth_data, (0,3,1,2)) 87 | motion_data = torch.stack(prefinals_motion) 88 | pose_data = torch.stack(prefinals_pose) 89 | # motion_data = torch.permute(motion_data, (0, 3, 1, 2)) 90 | data = self.jitter(data) 91 | # print(deepth_data.shape, 'dataloader pose') 92 | # print(motion_data.shape,'dataloader motion') 93 | return data,label,depth_data,motion_data,pose_data 94 | 95 | if __name__ == "__main__": 96 | 97 | data_path ='/sda/home/immc_guest/dy/new_project/datasets/phoenix2014-release_T_en_clip/test/shard_0001.hdf5' 98 | dataset = SignLanguage(os.path.join(data_path), frames_per_sample=2, random_time=True, 99 | random_horizontal_flip=True, 100 | color_jitter= 0) 101 | 102 | dataset.__getitem__(0) -------------------------------------------------------------------------------- /CLIP-main/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 
| 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | 
else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /evaluation/nearest_neighbor.py: -------------------------------------------------------------------------------- 1 | """ 2 | prdc 3 | Copyright (c) 2020-present NAVER Corp. 4 | Modified by Yang Song (yangsong@cs.stanford.edu) 5 | MIT license 6 | """ 7 | import sklearn.metrics 8 | import pathlib 9 | 10 | import numpy as np 11 | import torch 12 | from torchvision.datasets import LSUN, CelebA, CIFAR10 13 | from datasets.ffhq import FFHQ 14 | from torch.utils.data import DataLoader 15 | from torchvision.transforms import Compose, Resize, CenterCrop, RandomHorizontalFlip, ToPILImage, ToTensor 16 | from torchvision.utils import save_image 17 | from scipy import linalg 18 | from torch.nn.functional import adaptive_avg_pool2d 19 | import os 20 | import argparse 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--path', type=str, required=True) 23 | parser.add_argument('--k', type=int, default=9) 24 | parser.add_argument('--n_samples', type=int, default=10) 25 | parser.add_argument('--dataset', type=str, required=True) 26 | parser.add_argument('-i', type=str, required=True) 27 | 28 | from PIL import Image 29 | 30 | try: 31 | from tqdm import tqdm 32 | except ImportError: 33 | # If not tqdm is not available, provide a mock version of it 34 | def tqdm(x): return x 35 | 36 | from evaluation.inception import InceptionV3 37 | 38 | def imread(filename): 39 | """ 40 | Loads an image file into a (height, width, 3) uint8 ndarray. 41 | """ 42 | return np.asarray(Image.open(filename), dtype=np.uint8)[..., :3] 43 | 44 | 45 | def get_activations(model, images, dims=2048): 46 | # Reshape to (n_images, 3, height, width) 47 | with torch.no_grad(): 48 | pred = model(images)[0] 49 | 50 | # If model output is not scalar, apply global spatial average pooling. 51 | # This happens if you choose a dimensionality not equal 2048. 52 | if pred.size(2) != 1 or pred.size(3) != 1: 53 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 54 | 55 | return pred.reshape(pred.size(0), -1) 56 | 57 | 58 | def _compute_features_of_path(path, model, batch_size, dims, cuda): 59 | if path.endswith('.npz'): 60 | f = np.load(path) 61 | act = f['features'][:] 62 | f.close() 63 | else: 64 | path = pathlib.Path(path) 65 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 66 | act = get_activations(files, model, batch_size, dims, cuda, verbose=False) 67 | return act 68 | 69 | 70 | def get_nearest_neighbors(dataset, path, name, n_samples, k=10, cuda=True): 71 | if not os.path.exists(path): 72 | raise RuntimeError('Invalid path: %s' % path) 73 | 74 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] 75 | 76 | model = InceptionV3([block_idx]) 77 | if cuda: 78 | model.cuda() 79 | model.eval() 80 | 81 | flipper = RandomHorizontalFlip(p=1.) 
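# (Added note: p=1. makes this a deterministic horizontal flip. Further below,
#  each generated sample is featurized both as-is and mirrored, and the smaller
#  of the two distances to every dataset image is used, so mirror-image near
#  neighbours are not missed.)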
82 | to_pil = ToPILImage() 83 | to_tensor = ToTensor() 84 | data_features = [] 85 | data = [] 86 | for x, _ in tqdm(dataset, desc="sweeping the whole dataset"): 87 | if cuda: x = x.cuda() 88 | data_features.append(get_activations(model, x).cpu()) 89 | data.append(x.cpu()) 90 | 91 | data_features = torch.cat(data_features, dim=0) 92 | data = torch.cat(data, dim=0) 93 | 94 | samples = torch.load(path)[:n_samples] 95 | flipped_samples = torch.stack([to_tensor(flipper(to_pil(img))) for img in samples], dim=0) 96 | if cuda: 97 | samples = samples.cuda() 98 | flipped_samples = flipped_samples.cuda() 99 | 100 | sample_features = get_activations(model, samples).cpu() 101 | flip_sample_feature = get_activations(model, flipped_samples).cpu() 102 | sample_cdist = torch.cdist(sample_features, data_features) 103 | flip_sample_cdist = torch.cdist(flip_sample_feature, data_features) 104 | 105 | plot_data = [] 106 | for i in tqdm(range(len(samples)), desc='find nns and save images'): 107 | plot_data.append(samples[i].cpu()) 108 | all_dists = torch.min(sample_cdist[i], flip_sample_cdist[i]) 109 | indices = torch.topk(-all_dists, k=k)[1] 110 | for ind in indices: 111 | plot_data.append(data[ind]) 112 | 113 | plot_data = torch.stack(plot_data, dim=0) 114 | save_image(plot_data, '{}.png'.format(name), nrow=k+1) 115 | 116 | 117 | if __name__ == '__main__': 118 | args = parser.parse_args() 119 | if args.dataset == 'church': 120 | transforms = Compose([ 121 | Resize(96), 122 | CenterCrop(96), 123 | ToTensor() 124 | ]) 125 | dataset = LSUN('exp/datasets/lsun', ['church_outdoor_train'], transform=transforms) 126 | 127 | elif args.dataset == 'tower' or args.dataset == 'bedroom': 128 | transforms = Compose([ 129 | Resize(128), 130 | CenterCrop(128), 131 | ToTensor() 132 | ]) 133 | dataset = LSUN('exp/datasets/lsun', ['{}_train'.format(args.dataset)], transform=transforms) 134 | 135 | elif args.dataset == 'celeba': 136 | transforms = Compose([ 137 | CenterCrop(140), 138 | Resize(64), 139 | ToTensor(), 140 | ]) 141 | dataset = CelebA('exp/datasets/celeba', split='train', transform=transforms) 142 | 143 | elif args.dataset == 'cifar10': 144 | dataset = CIFAR10('exp/datasets/cifar10', train=True, transform=ToTensor()) 145 | elif args.dataset == 'ffhq': 146 | dataset = FFHQ(path='exp/datasets/FFHQ', transform=ToTensor(), resolution=256) 147 | 148 | dataloader = DataLoader(dataset, batch_size=128, drop_last=False) 149 | get_nearest_neighbors(dataloader, args.path, args.i, args.n_samples, args.k, torch.cuda.is_available()) 150 | -------------------------------------------------------------------------------- /models/eval_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import models.dist_model as dist_model 4 | import numpy as np 5 | import models.dist_model as dist_model 6 | 7 | # Taken from https://github.com/psh01087/Vid-ODE/blob/main/eval_models/__init__.py 8 | class PerceptualLoss(torch.nn.Module): 9 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, device='cpu'): # VGG using our perceptually-learned weights (LPIPS metric) 10 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 11 | super(PerceptualLoss, self).__init__() 12 | print('Setting up Perceptual loss...') 13 | self.device = device 14 | self.spatial = spatial 15 | self.model = dist_model.DistModel() 16 | self.model.initialize(model=model, net=net, 
colorspace=colorspace, spatial=self.spatial, device=device) 17 | print('...[%s] initialized'%self.model.name()) 18 | print('...Done') 19 | 20 | def forward(self, pred, target, normalize=False): 21 | """ 22 | Pred and target are Variables. 23 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 24 | If normalize is False, assumes the images are already between [-1,+1] 25 | Inputs pred and target are Nx3xHxW 26 | Output pytorch Variable N long 27 | """ 28 | 29 | if normalize: 30 | target = 2 * target - 1 31 | pred = 2 * pred - 1 32 | 33 | return self.model.forward(target, pred) 34 | 35 | def normalize_tensor(in_feat,eps=1e-10): 36 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 37 | return in_feat/(norm_factor+eps) 38 | 39 | def l2(p0, p1, range=255.): 40 | return .5*np.mean((p0 / range - p1 / range)**2) 41 | 42 | def psnr(p0, p1, peak=255.): 43 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 44 | 45 | #def dssim(p0, p1, range=255.): 46 | # return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 47 | 48 | def rgb2lab(in_img,mean_cent=False): 49 | from skimage import color 50 | img_lab = color.rgb2lab(in_img) 51 | if(mean_cent): 52 | img_lab[:,:,0] = img_lab[:,:,0]-50 53 | return img_lab 54 | 55 | def tensor2np(tensor_obj): 56 | # change dimension of a tensor object into a numpy array 57 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 58 | 59 | def np2tensor(np_obj): 60 | # change dimenion of np array into tensor array 61 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 62 | 63 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 64 | # image tensor to lab tensor 65 | from skimage import color 66 | 67 | img = tensor2im(image_tensor) 68 | img_lab = color.rgb2lab(img) 69 | if(mc_only): 70 | img_lab[:,:,0] = img_lab[:,:,0]-50 71 | if(to_norm and not mc_only): 72 | img_lab[:,:,0] = img_lab[:,:,0]-50 73 | img_lab = img_lab/100. 74 | 75 | return np2tensor(img_lab) 76 | 77 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 78 | from skimage import color 79 | import warnings 80 | warnings.filterwarnings("ignore") 81 | 82 | lab = tensor2np(lab_tensor)*100. 83 | lab[:,:,0] = lab[:,:,0]+50 84 | 85 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 86 | if(return_inbnd): 87 | # convert back to lab, see if we match 88 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 89 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 90 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 91 | return (im2tensor(rgb_back),mask) 92 | else: 93 | return im2tensor(rgb_back) 94 | 95 | def rgb2lab(input): 96 | from skimage import color 97 | return color.rgb2lab(input / 255.) 98 | 99 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 100 | image_numpy = image_tensor[0].cpu().float().numpy() 101 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 102 | return image_numpy.astype(imtype) 103 | 104 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 105 | return torch.Tensor((image / factor - cent) 106 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 107 | 108 | def tensor2vec(vector_tensor): 109 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 110 | 111 | def voc_ap(rec, prec, use_07_metric=False): 112 | """ ap = voc_ap(rec, prec, [use_07_metric]) 113 | Compute VOC AP given precision and recall. 114 | If use_07_metric is true, uses the 115 | VOC 07 11 point method (default:False). 
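# (Added note, illustrative: psnr() above computes 10*log10(peak**2 / MSE); for
#  example, two uint8 images whose pixels all differ by exactly 1 have MSE = 1
#  and psnr = 10*log10(255**2), roughly 48.1 dB.)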
116 | """ 117 | if use_07_metric: 118 | # 11 point metric 119 | ap = 0. 120 | for t in np.arange(0., 1.1, 0.1): 121 | if np.sum(rec >= t) == 0: 122 | p = 0 123 | else: 124 | p = np.max(prec[rec >= t]) 125 | ap = ap + p / 11. 126 | else: 127 | # correct AP calculation 128 | # first append sentinel values at the end 129 | mrec = np.concatenate(([0.], rec, [1.])) 130 | mpre = np.concatenate(([0.], prec, [0.])) 131 | 132 | # compute the precision envelope 133 | for i in range(mpre.size - 1, 0, -1): 134 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 135 | 136 | # to calculate area under PR curve, look for points 137 | # where X axis (recall) changes value 138 | i = np.where(mrec[1:] != mrec[:-1])[0] 139 | 140 | # and sum (\Delta recall) * prec 141 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 142 | return ap 143 | 144 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 145 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 146 | image_numpy = image_tensor[0].cpu().float().numpy() 147 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 148 | return image_numpy.astype(imtype) 149 | 150 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 151 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 152 | return torch.Tensor((image / factor - cent) 153 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import errno 5 | from torch.utils.model_zoo import tqdm 6 | 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | 20 | def check_integrity(fpath, md5=None): 21 | if md5 is None: 22 | return True 23 | if not os.path.isfile(fpath): 24 | return False 25 | md5o = hashlib.md5() 26 | with open(fpath, 'rb') as f: 27 | # read in 1MB chunks 28 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 29 | md5o.update(chunk) 30 | md5c = md5o.hexdigest() 31 | if md5c != md5: 32 | return False 33 | return True 34 | 35 | 36 | def makedir_exist_ok(dirpath): 37 | """ 38 | Python2 support for os.makedirs(.., exist_ok=True) 39 | """ 40 | try: 41 | os.makedirs(dirpath) 42 | except OSError as e: 43 | if e.errno == errno.EEXIST: 44 | pass 45 | else: 46 | raise 47 | 48 | 49 | def download_url(url, root, filename=None, md5=None): 50 | """Download a file from a url and place it in root. 51 | 52 | Args: 53 | url (str): URL to download file from 54 | root (str): Directory to place downloaded file in 55 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 56 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 57 | """ 58 | from six.moves import urllib 59 | 60 | root = os.path.expanduser(root) 61 | if not filename: 62 | filename = os.path.basename(url) 63 | fpath = os.path.join(root, filename) 64 | 65 | makedir_exist_ok(root) 66 | 67 | # downloads file 68 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 69 | print('Using downloaded and verified file: ' + fpath) 70 | else: 71 | try: 72 | print('Downloading ' + url + ' to ' + fpath) 73 | urllib.request.urlretrieve( 74 | url, fpath, 75 | reporthook=gen_bar_updater() 76 | ) 77 | except OSError: 78 | if url[:5] == 'https': 79 | url = url.replace('https:', 'http:') 80 | print('Failed download. Trying https -> http instead.' 81 | ' Downloading ' + url + ' to ' + fpath) 82 | urllib.request.urlretrieve( 83 | url, fpath, 84 | reporthook=gen_bar_updater() 85 | ) 86 | 87 | 88 | def list_dir(root, prefix=False): 89 | """List all directories at a given root 90 | 91 | Args: 92 | root (str): Path to directory whose folders need to be listed 93 | prefix (bool, optional): If true, prepends the path to each result, otherwise 94 | only returns the name of the directories found 95 | """ 96 | root = os.path.expanduser(root) 97 | directories = list( 98 | filter( 99 | lambda p: os.path.isdir(os.path.join(root, p)), 100 | os.listdir(root) 101 | ) 102 | ) 103 | 104 | if prefix is True: 105 | directories = [os.path.join(root, d) for d in directories] 106 | 107 | return directories 108 | 109 | 110 | def list_files(root, suffix, prefix=False): 111 | """List all files ending with a suffix at a given root 112 | 113 | Args: 114 | root (str): Path to directory whose folders need to be listed 115 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 116 | It uses the Python "str.endswith" method and is passed directly 117 | prefix (bool, optional): If true, prepends the path to each result, otherwise 118 | only returns the name of the files found 119 | """ 120 | root = os.path.expanduser(root) 121 | files = list( 122 | filter( 123 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 124 | os.listdir(root) 125 | ) 126 | ) 127 | 128 | if prefix is True: 129 | files = [os.path.join(root, d) for d in files] 130 | 131 | return files 132 | 133 | 134 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 135 | """Download a Google Drive file from and place it in root. 136 | 137 | Args: 138 | file_id (str): id of file to be downloaded 139 | root (str): Directory to place downloaded file in 140 | filename (str, optional): Name to save the file under. If None, use the id of the file. 141 | md5 (str, optional): MD5 checksum of the download. 
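# (Added note, illustrative: typical use of the helpers defined above; the URL,
#  paths and suffixes are placeholders, not values from this repository.)
#   download_url('https://example.com/data.tar.gz', root='./data', md5=None)
#   frames = list_files('./data/frames', suffix=('.jpg', '.png'), prefix=True)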
If None, do not check 142 | """ 143 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 144 | import requests 145 | url = "https://docs.google.com/uc?export=download" 146 | 147 | root = os.path.expanduser(root) 148 | if not filename: 149 | filename = file_id 150 | fpath = os.path.join(root, filename) 151 | 152 | makedir_exist_ok(root) 153 | 154 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 155 | print('Using downloaded and verified file: ' + fpath) 156 | else: 157 | session = requests.Session() 158 | 159 | response = session.get(url, params={'id': file_id}, stream=True) 160 | token = _get_confirm_token(response) 161 | 162 | if token: 163 | params = {'id': file_id, 'confirm': token} 164 | response = session.get(url, params=params, stream=True) 165 | 166 | _save_response_content(response, fpath) 167 | 168 | 169 | def _get_confirm_token(response): 170 | for key, value in response.cookies.items(): 171 | if key.startswith('download_warning'): 172 | return value 173 | 174 | return None 175 | 176 | 177 | def _save_response_content(response, destination, chunk_size=32768): 178 | with open(destination, "wb") as f: 179 | pbar = tqdm(total=None) 180 | progress = 0 181 | for chunk in response.iter_content(chunk_size): 182 | if chunk: # filter out keep-alive new chunks 183 | f.write(chunk) 184 | progress += len(chunk) 185 | pbar.update(progress - pbar.n) 186 | pbar.close() 187 | -------------------------------------------------------------------------------- /models/better/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """All functions and modules related to model definition. 17 | """ 18 | 19 | import torch 20 | import sde_lib 21 | import numpy as np 22 | 23 | 24 | _MODELS = {} 25 | 26 | 27 | def register_model(cls=None, *, name=None): 28 | """A decorator for registering model classes.""" 29 | 30 | def _register(cls): 31 | if name is None: 32 | local_name = cls.__name__ 33 | else: 34 | local_name = name 35 | if local_name in _MODELS: 36 | raise ValueError(f'Already registered model with name: {local_name}') 37 | _MODELS[local_name] = cls 38 | return cls 39 | 40 | if cls is None: 41 | return _register 42 | else: 43 | return _register(cls) 44 | 45 | 46 | def get_model(name): 47 | return _MODELS[name] 48 | 49 | 50 | def get_sigmas(config): 51 | """Get sigmas --- the set of noise levels for SMLD from config files. 
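# (Added note, illustrative: the registry above is used by decorating a score
#  network class and later resolving it by the name stored in the config;
#  'my_unet' and MyUNet are hypothetical names, not models in this repository.)
#   @register_model(name='my_unet')
#   class MyUNet(torch.nn.Module): ...
#   score_net = get_model('my_unet')(config)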
52 | Args: 53 | config: A ConfigDict object parsed from the config file 54 | Returns: 55 | sigmas: a jax numpy arrary of noise levels 56 | """ 57 | sigmas = np.exp( 58 | np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales)) 59 | 60 | return sigmas 61 | 62 | 63 | def get_ddpm_params(config): 64 | """Get betas and alphas --- parameters used in the original DDPM paper.""" 65 | num_diffusion_timesteps = 1000 66 | # parameters need to be adapted if number of time steps differs from 1000 67 | beta_start = config.model.beta_min / config.model.num_scales 68 | beta_end = config.model.beta_max / config.model.num_scales 69 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 70 | 71 | alphas = 1. - betas 72 | alphas_cumprod = np.cumprod(alphas, axis=0) 73 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) 74 | sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod) 75 | 76 | return { 77 | 'betas': betas, 78 | 'alphas': alphas, 79 | 'alphas_cumprod': alphas_cumprod, 80 | 'sqrt_alphas_cumprod': sqrt_alphas_cumprod, 81 | 'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod, 82 | 'beta_min': beta_start * (num_diffusion_timesteps - 1), 83 | 'beta_max': beta_end * (num_diffusion_timesteps - 1), 84 | 'num_diffusion_timesteps': num_diffusion_timesteps 85 | } 86 | 87 | 88 | def create_model(config): 89 | """Create the score model.""" 90 | model_name = config.model.name 91 | score_model = get_model(model_name)(config) 92 | score_model = score_model.to(config.device) 93 | score_model = torch.nn.DataParallel(score_model) 94 | return score_model 95 | 96 | 97 | def get_model_fn(model, train=False): 98 | """Create a function to give the output of the score-based model. 99 | 100 | Args: 101 | model: The score model. 102 | train: `True` for training and `False` for evaluation. 103 | 104 | Returns: 105 | A model function. 106 | """ 107 | 108 | def model_fn(x, labels): 109 | """Compute the output of the score-based model. 110 | 111 | Args: 112 | x: A mini-batch of input data. 113 | labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently 114 | for different models. 115 | 116 | Returns: 117 | A tuple of (model output, new mutable states) 118 | """ 119 | if not train: 120 | model.eval() 121 | return model(x, labels) 122 | else: 123 | model.train() 124 | return model(x, labels) 125 | 126 | return model_fn 127 | 128 | 129 | def get_score_fn(sde, model, train=False, continuous=False): 130 | """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function. 131 | 132 | Args: 133 | sde: An `sde_lib.SDE` object that represents the forward SDE. 134 | model: A score model. 135 | train: `True` for training and `False` for evaluation. 136 | continuous: If `True`, the score-based model is expected to directly take continuous time steps. 137 | 138 | Returns: 139 | A score function. 140 | """ 141 | model_fn = get_model_fn(model, train=train) 142 | 143 | if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE): 144 | def score_fn(x, t): 145 | # Scale neural network output by standard deviation and flip sign 146 | if continuous or isinstance(sde, sde_lib.subVPSDE): 147 | # For VP-trained models, t=0 corresponds to the lowest noise level 148 | # The maximum value of time embedding is assumed to 999 for 149 | # continuously-trained models. 
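# (Added note: a continuous time t in [0, 1] is mapped onto the integer
#  embedding range 0..999 that the network was trained with; the raw network
#  output is then divided by the marginal std and negated, so score_fn returns
#  an actual score rather than the predicted noise.)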
150 | labels = t * 999 151 | score = model_fn(x, labels) 152 | std = sde.marginal_prob(torch.zeros_like(x), t)[1] 153 | else: 154 | # For VP-trained models, t=0 corresponds to the lowest noise level 155 | labels = t * (sde.N - 1) 156 | score = model_fn(x, labels) 157 | std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()] 158 | 159 | score = -score / std[:, None, None, None] 160 | return score 161 | 162 | elif isinstance(sde, sde_lib.VESDE): 163 | def score_fn(x, t): 164 | if continuous: 165 | labels = sde.marginal_prob(torch.zeros_like(x), t)[1] 166 | else: 167 | # For VE-trained models, t=0 corresponds to the highest noise level 168 | labels = sde.T - t 169 | labels *= sde.N - 1 170 | labels = torch.round(labels).long() 171 | 172 | score = model_fn(x, labels) 173 | return score 174 | 175 | else: 176 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") 177 | 178 | return score_fn 179 | 180 | 181 | def to_flattened_numpy(x): 182 | """Flatten a torch tensor `x` and convert it to numpy.""" 183 | return x.detach().cpu().numpy().reshape((-1,)) 184 | 185 | 186 | def from_flattened_numpy(x, shape): 187 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" 188 | return torch.from_numpy(x.reshape(shape)) -------------------------------------------------------------------------------- /models/gaussianDf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from video_diffusion_pytorch import Unet3D 6 | 7 | BERT_MODEL_DIM = 768 8 | 9 | 10 | # helpers functions 11 | 12 | def exists(x): 13 | return x is not None 14 | 15 | 16 | def noop(*args, **kwargs): 17 | pass 18 | 19 | 20 | def is_odd(n): 21 | return (n % 2) == 1 22 | 23 | 24 | def default(val, d): 25 | if exists(val): 26 | return val 27 | return d() if callable(d) else d 28 | 29 | 30 | def cycle(dl): 31 | while True: 32 | for data in dl: 33 | yield data 34 | 35 | 36 | def num_to_groups(num, divisor): 37 | groups = num // divisor 38 | remainder = num % divisor 39 | arr = [divisor] * groups 40 | if remainder > 0: 41 | arr.append(remainder) 42 | return arr 43 | 44 | 45 | def prob_mask_like(shape, prob, device): 46 | if prob == 1: 47 | return torch.ones(shape, device=device, dtype=torch.bool) 48 | elif prob == 0: 49 | return torch.zeros(shape, device=device, dtype=torch.bool) 50 | else: 51 | return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob 52 | 53 | 54 | def is_list_str(x): 55 | if not isinstance(x, (list, tuple)): 56 | return False 57 | return all([type(el) == str for el in x]) 58 | 59 | 60 | # gaussian diffusion trainer class 61 | # todo b 应该是batch 62 | def extract(a, t, x_shape): 63 | b, *_ = t.shape 64 | out = a.gather(-1, t) 65 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 66 | 67 | 68 | def cosine_beta_schedule(timesteps, s=0.008): 69 | """ 70 | cosine schedule 71 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 72 | """ 73 | steps = timesteps + 1 74 | x = torch.linspace(0, timesteps, steps, dtype=torch.float64) 75 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 76 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 77 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 78 | return torch.clip(betas, 0, 0.9999) 79 | 80 | 81 | class GaussianDiffusion(nn.Module): 82 | def __init__( 83 | self, 84 | # video, 85 | # t, 86 | # cond, 87 | *, 88 | image_size=128, 89 | 
num_frames=7, 90 | text_use_bert_cls=True, 91 | channels=3, 92 | timesteps=1000, 93 | loss_type='l1', 94 | use_dynamic_thres=False, # from the Imagen paper 95 | dynamic_thres_percentile=0.9 96 | ): 97 | super().__init__() 98 | self.channels = channels 99 | self.image_size = image_size 100 | self.num_frames = num_frames 101 | self.denoise = Unet3D(dim=64, 102 | use_bert_text_cond=True, 103 | # this must be set to True to auto-use the bert model dimensions 104 | dim_mults=(1, 2, 4, 8), ) 105 | betas = cosine_beta_schedule(timesteps) 106 | 107 | alphas = 1. - betas 108 | alphas_cumprod = torch.cumprod(alphas, axis=0) 109 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) 110 | 111 | timesteps, = betas.shape 112 | self.num_timesteps = int(timesteps) 113 | self.loss_type = loss_type 114 | 115 | # register buffer helper function that casts float64 to float32 116 | 117 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) 118 | 119 | register_buffer('betas', betas) 120 | register_buffer('alphas_cumprod', alphas_cumprod) 121 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) 122 | 123 | # calculations for diffusion q(x_t | x_{t-1}) and others 124 | 125 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) 126 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) 127 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) 128 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) 129 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) 130 | 131 | # calculations for posterior q(x_{t-1} | x_t, x_0) 132 | 133 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 134 | 135 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 136 | 137 | register_buffer('posterior_variance', posterior_variance) 138 | 139 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 140 | 141 | register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20))) 142 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) 143 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. 
- alphas_cumprod)) 144 | 145 | # text conditioning parameters 146 | 147 | self.text_use_bert_cls = text_use_bert_cls 148 | 149 | # dynamic thresholding when sampling 150 | 151 | self.use_dynamic_thres = use_dynamic_thres 152 | self.dynamic_thres_percentile = dynamic_thres_percentile 153 | 154 | def predict_start_from_noise(self, x_t, t, noise): 155 | return ( 156 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 157 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 158 | ) 159 | 160 | @torch.inference_mode() 161 | def diffusionwithCond(self, x, t,cond=None, cond_scale=1.): 162 | 163 | x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise.forward_with_cond_scale(x, t, cond=cond,cond_scale=cond_scale)) 164 | # x_recon = self.predict_start_from_noise(x, t=t, noise=Unet3D().forward_with_cond_scale(x, t, cond=cond,cond_scale=cond_scale)) 165 | return x_recon 166 | 167 | def forward(self, video, text, *args, **kwargs): 168 | t = torch.full((video.shape[0],), 1, dtype=torch.long,).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')) 169 | z = self.diffusionwithCond(video,t,text).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')) 170 | z_t = z.transpose(1, 2).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')) 171 | z = z_t[:, 2:7].reshape(len(z_t), -1, 128, 128).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')) 172 | print(z.shape) 173 | return z 174 | -------------------------------------------------------------------------------- /load_model_from_ckpt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | import torch 5 | import yaml 6 | 7 | from collections import OrderedDict 8 | from functools import partial 9 | from imageio import mimwrite 10 | from torch.utils.data import DataLoader 11 | from torchvision.utils import make_grid, save_image 12 | 13 | try: 14 | from torchvision.transforms.functional import resize, InterpolationMode 15 | interp = InterpolationMode.NEAREST 16 | except: 17 | from torchvision.transforms.functional import resize 18 | interp = 0 19 | 20 | from datasets import get_dataset, data_transform, inverse_data_transform 21 | from main import dict2namespace 22 | from models import get_sigmas, anneal_Langevin_dynamics, anneal_Langevin_dynamics_consistent, ddpm_sampler, ddim_sampler, FPNDM_sampler 23 | from models.ema import EMAHelper 24 | from runners.ncsn_runner import get_model 25 | 26 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 27 | # device = torch.device('cpu') 28 | 29 | 30 | def parse_args(): 31 | parser = argparse.ArgumentParser(description=globals()['__doc__']) 32 | parser.add_argument('--ckpt_path', type=str, required=True, help='Path to checkpoint.pt') 33 | parser.add_argument('--data_path', type=str, default='/mnt/data/scratch/data/CIFAR10', help='Path to the dataset') 34 | args = parser.parse_args() 35 | return args.ckpt_path, args.data_path 36 | 37 | 38 | # Make and load model 39 | def load_model(ckpt_path, device=device): 40 | # Parse config file 41 | with open(os.path.join(os.path.dirname(ckpt_path), 'config.yml'), 'r') as f: 42 | config = yaml.load(f, Loader=yaml.FullLoader) 43 | # Load config file 44 | config = dict2namespace(config) 45 | config.device = device 46 | # Load model 47 | scorenet = get_model(config) 48 | if config.device != torch.device('cpu'): 49 | scorenet = torch.nn.DataParallel(scorenet) 50 | states = 
torch.load(ckpt_path, map_location=config.device) 51 | else: 52 | states = torch.load(ckpt_path, map_location='cpu') 53 | states[0] = OrderedDict([(k.replace('module.', ''), v) for k, v in states[0].items()]) 54 | scorenet.load_state_dict(states[0], strict=False) 55 | if config.model.ema: 56 | ema_helper = EMAHelper(mu=config.model.ema_rate) 57 | ema_helper.register(scorenet) 58 | ema_helper.load_state_dict(states[-1]) 59 | ema_helper.ema(scorenet) 60 | scorenet.eval() 61 | return scorenet, config 62 | 63 | 64 | def get_sampler_from_config(config): 65 | version = getattr(config.model, 'version', "DDPM") 66 | # Sampler 67 | if version == "SMLD": 68 | consistent = getattr(config.sampling, 'consistent', False) 69 | sampler = anneal_Langevin_dynamics_consistent if consistent else anneal_Langevin_dynamics 70 | elif version == "DDPM": 71 | sampler = partial(ddpm_sampler, config=config) 72 | elif version == "DDIM": 73 | sampler = partial(ddim_sampler, config=config) 74 | elif version == "FPNDM": 75 | sampler = partial(FPNDM_sampler, config=config) 76 | return sampler 77 | 78 | 79 | def get_sampler(config): 80 | sampler = get_sampler_from_config(config) 81 | sampler_partial = partial(sampler, n_steps_each=config.sampling.n_steps_each, 82 | step_lr=config.sampling.step_lr, just_beta=False, 83 | final_only=True, denoise=config.sampling.denoise, 84 | subsample_steps=getattr(config.sampling, 'subsample', None), 85 | clip_before=getattr(config.sampling, 'clip_before', True), 86 | verbose=False, log=False, gamma=getattr(config.model, 'gamma', False)) 87 | def sampler_fn(init, scorenet, cond, cond_mask, subsample=getattr(config.sampling, 'subsample', None), verbose=False): 88 | init = init.to(config.device) 89 | cond = cond.to(config.device) 90 | if cond_mask is not None: 91 | cond_mask = cond_mask.to(config.device) 92 | return inverse_data_transform(config, sampler_partial(init, scorenet, cond=cond, cond_mask=cond_mask, 93 | subsample_steps=subsample, verbose=verbose)[-1].to('cpu')) 94 | return sampler_fn 95 | 96 | 97 | def init_samples(n_init_samples, config): 98 | # Initial samples 99 | # n_init_samples = min(36, config.training.batch_size) 100 | version = getattr(config.model, 'version', "DDPM") 101 | init_samples_shape = (n_init_samples, config.data.channels*config.data.num_frames, config.data.image_size, config.data.image_size) 102 | if version == "SMLD": 103 | init_samples = torch.rand(init_samples_shape) 104 | init_samples = data_transform(self.config, init_samples) 105 | elif version == "DDPM" or self.version == "DDIM" or self.version == "FPNDM": 106 | if getattr(config.model, 'gamma', False): 107 | used_k, used_theta = net.k_cum[0], net.theta_t[0] 108 | z = Gamma(torch.full(init_samples_shape, used_k), torch.full(init_samples_shape, 1 / used_theta)).sample().to(config.device) 109 | init_samples = z - used_k*used_theta # we don't scale here 110 | else: 111 | init_samples = torch.randn(init_samples_shape) 112 | return init_samples 113 | 114 | 115 | if __name__ == '__main__': 116 | # data_path = '/path/to/data/CIFAR10' 117 | ckpt_path, data_path = parse_args() 118 | 119 | scorenet, config = load_model(ckpt_path, device) 120 | 121 | # Initial samples 122 | dataset, test_dataset = get_dataset(data_path, config) 123 | dataloader = DataLoader(dataset, batch_size=config.training.batch_size, shuffle=True, 124 | num_workers=config.data.num_workers) 125 | train_iter = iter(dataloader) 126 | x, y = next(train_iter) 127 | test_loader = DataLoader(test_dataset, batch_size=config.training.batch_size, 
shuffle=False, 128 | num_workers=config.data.num_workers, drop_last=True) 129 | test_iter = iter(test_loader) 130 | test_x, test_y = next(test_iter) 131 | 132 | net = scorenet.module if hasattr(scorenet, 'module') else scorenet 133 | version = getattr(net, 'version', 'SMLD').upper() 134 | net_type = getattr(net, 'type') if isinstance(getattr(net, 'type'), str) else 'v1' 135 | 136 | if version == "SMLD": 137 | sigmas = net.sigmas 138 | labels = torch.randint(0, len(sigmas), (x.shape[0],), device=x.device) 139 | used_sigmas = sigmas[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))) 140 | device = sigmas.device 141 | 142 | elif version == "DDPM" or version == "DDIM": 143 | alphas = net.alphas 144 | labels = torch.randint(0, len(alphas), (x.shape[0],), device=x.device) 145 | used_alphas = alphas[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))) 146 | device = alphas.device 147 | 148 | 149 | # CUDA_VISIBLE_DEVICES=3 python -i load_model_from_ckpt.py --ckpt_path /path/to/ncsnv2/cifar10/BASELINE_DDPM_800k/logs/checkpoint.pt 150 | -------------------------------------------------------------------------------- /models/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | 5 | class squeezenet(torch.nn.Module): 6 | def __init__(self, requires_grad=False, pretrained=True): 7 | super(squeezenet, self).__init__() 8 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 9 | self.slice1 = torch.nn.Sequential() 10 | self.slice2 = torch.nn.Sequential() 11 | self.slice3 = torch.nn.Sequential() 12 | self.slice4 = torch.nn.Sequential() 13 | self.slice5 = torch.nn.Sequential() 14 | self.slice6 = torch.nn.Sequential() 15 | self.slice7 = torch.nn.Sequential() 16 | self.N_slices = 7 17 | for x in range(2): 18 | self.slice1.add_module(str(x), pretrained_features[x]) 19 | for x in range(2,5): 20 | self.slice2.add_module(str(x), pretrained_features[x]) 21 | for x in range(5, 8): 22 | self.slice3.add_module(str(x), pretrained_features[x]) 23 | for x in range(8, 10): 24 | self.slice4.add_module(str(x), pretrained_features[x]) 25 | for x in range(10, 11): 26 | self.slice5.add_module(str(x), pretrained_features[x]) 27 | for x in range(11, 12): 28 | self.slice6.add_module(str(x), pretrained_features[x]) 29 | for x in range(12, 13): 30 | self.slice7.add_module(str(x), pretrained_features[x]) 31 | if not requires_grad: 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | def forward(self, X): 36 | h = self.slice1(X) 37 | h_relu1 = h 38 | h = self.slice2(h) 39 | h_relu2 = h 40 | h = self.slice3(h) 41 | h_relu3 = h 42 | h = self.slice4(h) 43 | h_relu4 = h 44 | h = self.slice5(h) 45 | h_relu5 = h 46 | h = self.slice6(h) 47 | h_relu6 = h 48 | h = self.slice7(h) 49 | h_relu7 = h 50 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 51 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 52 | 53 | return out 54 | 55 | 56 | class alexnet(torch.nn.Module): 57 | def __init__(self, requires_grad=False, pretrained=True): 58 | super(alexnet, self).__init__() 59 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 60 | self.slice1 = torch.nn.Sequential() 61 | self.slice2 = torch.nn.Sequential() 62 | self.slice3 = torch.nn.Sequential() 63 | self.slice4 = torch.nn.Sequential() 64 | self.slice5 = torch.nn.Sequential() 65 | self.N_slices = 
5 66 | for x in range(2): 67 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 68 | for x in range(2, 5): 69 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 70 | for x in range(5, 8): 71 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 72 | for x in range(8, 10): 73 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 74 | for x in range(10, 12): 75 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 76 | if not requires_grad: 77 | for param in self.parameters(): 78 | param.requires_grad = False 79 | 80 | def forward(self, X): 81 | h = self.slice1(X) 82 | h_relu1 = h 83 | h = self.slice2(h) 84 | h_relu2 = h 85 | h = self.slice3(h) 86 | h_relu3 = h 87 | h = self.slice4(h) 88 | h_relu4 = h 89 | h = self.slice5(h) 90 | h_relu5 = h 91 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 92 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 93 | 94 | return out 95 | 96 | class vgg16(torch.nn.Module): 97 | def __init__(self, requires_grad=False, pretrained=True): 98 | super(vgg16, self).__init__() 99 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 100 | self.slice1 = torch.nn.Sequential() 101 | self.slice2 = torch.nn.Sequential() 102 | self.slice3 = torch.nn.Sequential() 103 | self.slice4 = torch.nn.Sequential() 104 | self.slice5 = torch.nn.Sequential() 105 | self.N_slices = 5 106 | for x in range(4): 107 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 108 | for x in range(4, 9): 109 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(9, 16): 111 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 112 | for x in range(16, 23): 113 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 114 | for x in range(23, 30): 115 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 116 | if not requires_grad: 117 | for param in self.parameters(): 118 | param.requires_grad = False 119 | 120 | def forward(self, X): 121 | h = self.slice1(X) 122 | h_relu1_2 = h 123 | h = self.slice2(h) 124 | h_relu2_2 = h 125 | h = self.slice3(h) 126 | h_relu3_3 = h 127 | h = self.slice4(h) 128 | h_relu4_3 = h 129 | h = self.slice5(h) 130 | h_relu5_3 = h 131 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 132 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 133 | 134 | return out 135 | 136 | 137 | 138 | class resnet(torch.nn.Module): 139 | def __init__(self, requires_grad=False, pretrained=True, num=18): 140 | super(resnet, self).__init__() 141 | if(num==18): 142 | self.net = tv.resnet18(pretrained=pretrained) 143 | elif(num==34): 144 | self.net = tv.resnet34(pretrained=pretrained) 145 | elif(num==50): 146 | self.net = tv.resnet50(pretrained=pretrained) 147 | elif(num==101): 148 | self.net = tv.resnet101(pretrained=pretrained) 149 | elif(num==152): 150 | self.net = tv.resnet152(pretrained=pretrained) 151 | self.N_slices = 5 152 | 153 | self.conv1 = self.net.conv1 154 | self.bn1 = self.net.bn1 155 | self.relu = self.net.relu 156 | self.maxpool = self.net.maxpool 157 | self.layer1 = self.net.layer1 158 | self.layer2 = self.net.layer2 159 | self.layer3 = self.net.layer3 160 | self.layer4 = self.net.layer4 161 | 162 | def forward(self, X): 163 | h = self.conv1(X) 164 | h = self.bn1(h) 165 | h = self.relu(h) 166 | h_relu1 = h 167 | h = self.maxpool(h) 168 | h = self.layer1(h) 169 | h_conv2 = h 170 | h = self.layer2(h) 171 | h_conv3 = h 172 | 
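# (Added note, illustrative: the wrappers in this file act as frozen multi-scale
#  feature extractors, e.g.
#    net = vgg16(requires_grad=False, pretrained=True)
#    feats = net(torch.randn(1, 3, 64, 64))   # namedtuple of 5 relu activations
#  and the LPIPS metric in networks_basic.py compares such activations between
#  two images.)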
h = self.layer3(h) 173 | h_conv4 = h 174 | h = self.layer4(h) 175 | h_conv5 = h 176 | 177 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 178 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 179 | 180 | return out 181 | -------------------------------------------------------------------------------- /models/better/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.system("unset TORCH_CUDA_ARCH_LIST") 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | print(module_path) 12 | 13 | 14 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 15 | if input.device.type == "cpu": 16 | out = upfirdn2d_native( 17 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 18 | ) 19 | 20 | else: 21 | 22 | upfirdn2d_op = load( 23 | "upfirdn2d", 24 | sources=[ 25 | os.path.join(module_path, "upfirdn2d.cpp"), 26 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 27 | ], 28 | ) 29 | 30 | class UpFirDn2dBackward(Function): 31 | 32 | @staticmethod 33 | def forward( 34 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 35 | ): 36 | 37 | up_x, up_y = up 38 | down_x, down_y = down 39 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 40 | 41 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 42 | 43 | grad_input = upfirdn2d_op.upfirdn2d( 44 | grad_output, 45 | grad_kernel, 46 | down_x, 47 | down_y, 48 | up_x, 49 | up_y, 50 | g_pad_x0, 51 | g_pad_x1, 52 | g_pad_y0, 53 | g_pad_y1, 54 | ) 55 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 56 | 57 | ctx.save_for_backward(kernel) 58 | 59 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 60 | 61 | ctx.up_x = up_x 62 | ctx.up_y = up_y 63 | ctx.down_x = down_x 64 | ctx.down_y = down_y 65 | ctx.pad_x0 = pad_x0 66 | ctx.pad_x1 = pad_x1 67 | ctx.pad_y0 = pad_y0 68 | ctx.pad_y1 = pad_y1 69 | ctx.in_size = in_size 70 | ctx.out_size = out_size 71 | 72 | return grad_input 73 | 74 | @staticmethod 75 | def backward(ctx, gradgrad_input): 76 | kernel, = ctx.saved_tensors 77 | 78 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 79 | 80 | gradgrad_out = upfirdn2d_op.upfirdn2d( 81 | gradgrad_input, 82 | kernel, 83 | ctx.up_x, 84 | ctx.up_y, 85 | ctx.down_x, 86 | ctx.down_y, 87 | ctx.pad_x0, 88 | ctx.pad_x1, 89 | ctx.pad_y0, 90 | ctx.pad_y1, 91 | ) 92 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 93 | gradgrad_out = gradgrad_out.view( 94 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 95 | ) 96 | 97 | return gradgrad_out, None, None, None, None, None, None, None, None 98 | 99 | 100 | class UpFirDn2d(Function): 101 | 102 | @staticmethod 103 | def forward(ctx, input, kernel, up, down, pad): 104 | up_x, up_y = up 105 | down_x, down_y = down 106 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 107 | 108 | kernel_h, kernel_w = kernel.shape 109 | batch, channel, in_h, in_w = input.shape 110 | ctx.in_size = input.shape 111 | 112 | input = input.reshape(-1, in_h, in_w, 1) 113 | 114 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 115 | 116 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 117 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 118 | ctx.out_size = (out_h, out_w) 119 | 120 | ctx.up = (up_x, up_y) 121 
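# (Added note: upfirdn2d fuses UPsample, FIR filter and DowNsample in one op,
#  StyleGAN2 style: zero-insertion upsampling by `up`, padding, 2-D convolution
#  with `kernel`, then subsampling by `down`. An illustrative call, with an
#  assumed separable binomial kernel and padding:
#    k = torch.tensor([1., 3., 3., 1.]); k = k[None, :] * k[:, None]; k /= k.sum()
#    y = upfirdn2d(x, k, up=2, down=1, pad=(2, 1)).)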
| ctx.down = (down_x, down_y) 122 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 123 | 124 | g_pad_x0 = kernel_w - pad_x0 - 1 125 | g_pad_y0 = kernel_h - pad_y0 - 1 126 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 127 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 128 | 129 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 130 | 131 | out = upfirdn2d_op.upfirdn2d( 132 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 133 | ) 134 | # out = out.view(major, out_h, out_w, minor) 135 | out = out.view(-1, channel, out_h, out_w) 136 | 137 | return out 138 | 139 | @staticmethod 140 | def backward(ctx, grad_output): 141 | kernel, grad_kernel = ctx.saved_tensors 142 | 143 | grad_input = UpFirDn2dBackward.apply( 144 | grad_output, 145 | kernel, 146 | grad_kernel, 147 | ctx.up, 148 | ctx.down, 149 | ctx.pad, 150 | ctx.g_pad, 151 | ctx.in_size, 152 | ctx.out_size, 153 | ) 154 | 155 | return grad_input, None, None, None, None 156 | 157 | out = UpFirDn2d.apply( 158 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 159 | ) 160 | 161 | return out 162 | 163 | 164 | def upfirdn2d_native( 165 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 166 | ): 167 | _, channel, in_h, in_w = input.shape 168 | input = input.reshape(-1, in_h, in_w, 1) 169 | 170 | _, in_h, in_w, minor = input.shape 171 | kernel_h, kernel_w = kernel.shape 172 | 173 | out = input.view(-1, in_h, 1, in_w, 1, minor) 174 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 175 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 176 | 177 | out = F.pad( 178 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 179 | ) 180 | out = out[ 181 | :, 182 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 183 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 184 | :, 185 | ] 186 | 187 | out = out.permute(0, 3, 1, 2) 188 | out = out.reshape( 189 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 190 | ) 191 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 192 | out = F.conv2d(out, w) 193 | out = out.reshape( 194 | -1, 195 | minor, 196 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 197 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 198 | ) 199 | out = out.permute(0, 2, 3, 1) 200 | out = out[:, ::down_y, ::down_x, :] 201 | 202 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 203 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 204 | 205 | return out.view(-1, channel, out_h, out_w) 206 | -------------------------------------------------------------------------------- /models/networks_basic.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | from torch.autograd import Variable 9 | import numpy as np 10 | from skimage import color 11 | from . import pretrained_networks as pn 12 | 13 | from . 
import eval_models as util 14 | 15 | def spatial_average(in_tens, keepdim=True): 16 | return in_tens.mean([2,3],keepdim=keepdim) 17 | 18 | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W 19 | in_H = in_tens.shape[2] 20 | scale_factor = 1.*out_H/in_H 21 | 22 | return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens) 23 | 24 | # Learned perceptual metric 25 | class PNetLin(nn.Module): 26 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True): 27 | super(PNetLin, self).__init__() 28 | 29 | self.pnet_type = pnet_type 30 | self.pnet_tune = pnet_tune 31 | self.pnet_rand = pnet_rand 32 | self.spatial = spatial 33 | self.lpips = lpips 34 | self.version = version 35 | self.scaling_layer = ScalingLayer() 36 | 37 | if(self.pnet_type in ['vgg','vgg16']): 38 | net_type = pn.vgg16 39 | self.chns = [64,128,256,512,512] 40 | elif(self.pnet_type=='alex'): 41 | net_type = pn.alexnet 42 | self.chns = [64,192,384,256,256] 43 | elif(self.pnet_type=='squeeze'): 44 | net_type = pn.squeezenet 45 | self.chns = [64,128,256,384,384,512,512] 46 | self.L = len(self.chns) 47 | 48 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 49 | 50 | if(lpips): 51 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 52 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 53 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 54 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 55 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 56 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 57 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 58 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 59 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 60 | self.lins+=[self.lin5,self.lin6] 61 | 62 | def forward(self, in0, in1, retPerLayer=False): 63 | # v0.0 - original release had a bug, where input was not scaled 64 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 65 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 66 | feats0, feats1, diffs = {}, {}, {} 67 | 68 | for kk in range(self.L): 69 | feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk]) 70 | diffs[kk] = (feats0[kk]-feats1[kk])**2 71 | if(self.lpips): 72 | if(self.spatial): 73 | res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] 74 | else: 75 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] 76 | else: 77 | if(self.spatial): 78 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] 79 | else: 80 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] 81 | val = res[0] 82 | for l in range(1,self.L): 83 | val += res[l] 84 | 85 | if(retPerLayer): 86 | return (val, res) 87 | else: 88 | return val 89 | 90 | class ScalingLayer(nn.Module): 91 | def __init__(self): 92 | super(ScalingLayer, self).__init__() 93 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) 94 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) 95 | 96 | def forward(self, inp): 97 | return (inp - self.shift) / self.scale 98 | 99 | 100 | class NetLinLayer(nn.Module): 101 | ''' 
A single linear layer which does a 1x1 conv ''' 102 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 103 | super(NetLinLayer, self).__init__() 104 | 105 | layers = [nn.Dropout(),] if(use_dropout) else [] 106 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 107 | self.model = nn.Sequential(*layers) 108 | 109 | 110 | class Dist2LogitLayer(nn.Module): 111 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 112 | def __init__(self, chn_mid=32, use_sigmoid=True): 113 | super(Dist2LogitLayer, self).__init__() 114 | 115 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 116 | layers += [nn.LeakyReLU(0.2,True),] 117 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 118 | layers += [nn.LeakyReLU(0.2,True),] 119 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 120 | if(use_sigmoid): 121 | layers += [nn.Sigmoid(),] 122 | self.model = nn.Sequential(*layers) 123 | 124 | def forward(self,d0,d1,eps=0.1): 125 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 126 | 127 | class BCERankingLoss(nn.Module): 128 | def __init__(self, chn_mid=32): 129 | super(BCERankingLoss, self).__init__() 130 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 131 | # self.parameters = list(self.net.parameters()) 132 | self.loss = torch.nn.BCELoss() 133 | 134 | def forward(self, d0, d1, judge): 135 | per = (judge+1.)/2. 136 | self.logit = self.net.forward(d0,d1) 137 | return self.loss(self.logit, per) 138 | 139 | # L2, DSSIM metrics 140 | class FakeNet(nn.Module): 141 | def __init__(self, device='cpu', colorspace='Lab'): 142 | super(FakeNet, self).__init__() 143 | self.device = device 144 | self.colorspace=colorspace 145 | 146 | class L2(FakeNet): 147 | 148 | def forward(self, in0, in1, retPerLayer=None): 149 | assert(in0.size()[0]==1) # currently only supports batchSize 1 150 | 151 | if(self.colorspace=='RGB'): 152 | (N,C,X,Y) = in0.size() 153 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 154 | return value 155 | elif(self.colorspace=='Lab'): 156 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 157 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 158 | ret_var = Variable( torch.Tensor((value,) ) ).to(self.device) 159 | return ret_var 160 | 161 | class DSSIM(FakeNet): 162 | 163 | def forward(self, in0, in1, retPerLayer=None): 164 | assert(in0.size()[0]==1) # currently only supports batchSize 1 165 | 166 | if(self.colorspace=='RGB'): 167 | value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float') 168 | elif(self.colorspace=='Lab'): 169 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 170 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 171 | ret_var = Variable( torch.Tensor((value,) ) ).to(self.device) 172 | return ret_var 173 | 174 | def print_network(net): 175 | num_params = 0 176 | for param in net.parameters(): 177 | num_params += param.numel() 178 | print('Network',net) 179 | print('Total number of parameters: %d' % num_params) 180 | -------------------------------------------------------------------------------- /CLIP-main/model-card.md: -------------------------------------------------------------------------------- 1 | # Model Card: CLIP 2 | 3 | Inspired by [Model 
Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we’re providing some accompanying information about the multimodal model. 4 | 5 | ## Model Details 6 | 7 | The CLIP model was developed by researchers at OpenAI to learn about what contributes to robustness in computer vision tasks. The model was also developed to test the ability of models to generalize to arbitrary image classification tasks in a zero-shot manner. It was not developed for general model deployment - to deploy models like CLIP, researchers will first need to carefully study their capabilities in relation to the specific context they’re being deployed within. 8 | 9 | ### Model Date 10 | 11 | January 2021 12 | 13 | ### Model Type 14 | 15 | The base model uses a ResNet50 with several modifications as an image encoder and uses a masked self-attention Transformer as a text encoder. These encoders are trained to maximize the similarity of (image, text) pairs via a contrastive loss. There is also a variant of the model where the ResNet image encoder is replaced with a Vision Transformer. 16 | 17 | ### Model Versions 18 | 19 | Initially, we’ve released one CLIP model based on the Vision Transformer architecture equivalent to ViT-B/32, along with the RN50 model, using the architecture equivalent to ResNet-50. 20 | 21 | As part of the staged release process, we have also released the RN101 model, as well as RN50x4, a RN50 scaled up 4x according to the [EfficientNet](https://arxiv.org/abs/1905.11946) scaling rule. In July 2021, we additionally released the RN50x16 and ViT-B/16 models, and in January 2022, the RN50x64 and ViT-L/14 models were released. Lastly, the ViT-L/14@336px model was released in April 2022. 22 | 23 | Please see the paper linked below for further details about their specification. 24 | 25 | ### Documents 26 | 27 | - [Blog Post](https://openai.com/blog/clip/) 28 | - [CLIP Paper](https://arxiv.org/abs/2103.00020) 29 | 30 | 31 | 32 | ## Model Use 33 | 34 | ### Intended Use 35 | 36 | The model is intended as a research output for research communities. We hope that this model will enable researchers to better understand and explore zero-shot, arbitrary image classification. We also hope it can be used for interdisciplinary studies of the potential impact of such models - the CLIP paper includes a discussion of potential downstream impacts to provide an example for this sort of analysis. 37 | 38 | #### Primary intended uses 39 | 40 | The primary intended users of these models are AI researchers. 41 | 42 | We primarily imagine the model will be used by researchers to better understand robustness, generalization, and other capabilities, biases, and constraints of computer vision models. 43 | 44 | ### Out-of-Scope Use Cases 45 | 46 | **Any** deployed use case of the model - whether commercial or not - is currently out of scope. Non-deployed use cases such as image search in a constrained environment, are also not recommended unless there is thorough in-domain testing of the model with a specific, fixed class taxonomy. This is because our safety assessment demonstrated a high need for task specific testing especially given the variability of CLIP’s performance with different class taxonomies. This makes untested and unconstrained deployment of the model in any use case currently potentially harmful. 
47 | 48 | Certain use cases which would fall under the domain of surveillance and facial recognition are always out-of-scope regardless of performance of the model. This is because the use of artificial intelligence for tasks such as these can be premature currently given the lack of testing norms and checks to ensure its fair use. 49 | 50 | Since the model has not been purposefully trained in or evaluated on any languages other than English, its use should be limited to English language use cases. 51 | 52 | 53 | 54 | ## Data 55 | 56 | The model was trained on publicly available image-caption data. This was done through a combination of crawling a handful of websites and using commonly-used pre-existing image datasets such as [YFCC100M](http://projects.dfki.uni-kl.de/yfcc100m/). A large portion of the data comes from our crawling of the internet. This means that the data is more representative of people and societies most connected to the internet which tend to skew towards more developed nations, and younger, male users. 57 | 58 | ### Data Mission Statement 59 | 60 | Our goal with building this dataset was to test out robustness and generalizability in computer vision tasks. As a result, the focus was on gathering large quantities of data from different publicly-available internet data sources. The data was gathered in a mostly non-interventionist manner. However, we only crawled websites that had policies against excessively violent and adult images and allowed us to filter out such content. We do not intend for this dataset to be used as the basis for any commercial or deployed model and will not be releasing the dataset. 61 | 62 | 63 | 64 | ## Performance and Limitations 65 | 66 | ### Performance 67 | 68 | We have evaluated the performance of CLIP on a wide range of benchmarks across a variety of computer vision datasets such as OCR to texture recognition to fine-grained classification. The paper describes model performance on the following datasets: 69 | 70 | - Food101 71 | - CIFAR10 72 | - CIFAR100 73 | - Birdsnap 74 | - SUN397 75 | - Stanford Cars 76 | - FGVC Aircraft 77 | - VOC2007 78 | - DTD 79 | - Oxford-IIIT Pet dataset 80 | - Caltech101 81 | - Flowers102 82 | - MNIST 83 | - SVHN 84 | - IIIT5K 85 | - Hateful Memes 86 | - SST-2 87 | - UCF101 88 | - Kinetics700 89 | - Country211 90 | - CLEVR Counting 91 | - KITTI Distance 92 | - STL-10 93 | - RareAct 94 | - Flickr30 95 | - MSCOCO 96 | - ImageNet 97 | - ImageNet-A 98 | - ImageNet-R 99 | - ImageNet Sketch 100 | - ObjectNet (ImageNet Overlap) 101 | - Youtube-BB 102 | - ImageNet-Vid 103 | 104 | ## Limitations 105 | 106 | CLIP and our analysis of it have a number of limitations. CLIP currently struggles with respect to certain tasks such as fine grained classification and counting objects. CLIP also poses issues with regards to fairness and bias which we discuss in the paper and briefly in the next section. Additionally, our approach to testing CLIP also has an important limitation- in many cases we have used linear probes to evaluate the performance of CLIP and there is evidence suggesting that linear probes can underestimate model performance. 107 | 108 | ### Bias and Fairness 109 | 110 | We find that the performance of CLIP - and the specific biases it exhibits - can depend significantly on class design and the choices one makes for categories to include and exclude. 
We tested the risk of certain kinds of denigration with CLIP by classifying images of people from [Fairface](https://arxiv.org/abs/1908.04913) into crime-related and non-human animal categories. We found significant disparities with respect to race and gender. Additionally, we found that these disparities could shift based on how the classes were constructed. (Details captured in the Broader Impacts Section in the paper). 111 | 112 | We also tested the performance of CLIP on gender, race and age classification using the Fairface dataset (We default to using race categories as they are constructed in the Fairface dataset.) in order to assess quality of performance across different demographics. We found accuracy >96% across all races for gender classification with ‘Middle Eastern’ having the highest accuracy (98.4%) and ‘White’ having the lowest (96.5%). Additionally, CLIP averaged ~93% for racial classification and ~63% for age classification. Our use of evaluations to test for gender, race and age classification as well as denigration harms is simply to evaluate performance of the model across people and surface potential risks and not to demonstrate an endorsement/enthusiasm for such tasks. 113 | 114 | 115 | 116 | ## Feedback 117 | 118 | ### Where to send questions or comments about the model 119 | 120 | Please use [this Google Form](https://forms.gle/Uv7afRH5dvY34ZEs9) 121 | -------------------------------------------------------------------------------- /CLIP-main/README.md: -------------------------------------------------------------------------------- 1 | # CLIP 2 | 3 | [[Blog]](https://openai.com/blog/clip/) [[Paper]](https://arxiv.org/abs/2103.00020) [[Model Card]](model-card.md) [[Colab]](https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb) 4 | 5 | CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and 3. We found CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision. 6 | 7 | 8 | 9 | ## Approach 10 | 11 | ![CLIP](CLIP.png) 12 | 13 | 14 | 15 | ## Usage 16 | 17 | First, [install PyTorch 1.7.1](https://pytorch.org/get-started/locally/) (or later) and torchvision, as well as small additional dependencies, and then install this repo as a Python package. On a CUDA GPU machine, the following will do the trick: 18 | 19 | ```bash 20 | $ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0 21 | $ pip install ftfy regex tqdm 22 | $ pip install git+https://github.com/openai/CLIP.git 23 | ``` 24 | 25 | Replace `cudatoolkit=11.0` above with the appropriate CUDA version on your machine or `cpuonly` when installing on a machine without a GPU. 
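To quickly verify the installation, you can list the bundled model names (the full usage example follows below):

```python
import clip

# Prints the available model names, e.g. ['RN50', ..., 'ViT-B/32', ...]
print(clip.available_models())
```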
26 | 27 | ```python 28 | import torch 29 | import clip 30 | from PIL import Image 31 | 32 | device = "cuda" if torch.cuda.is_available() else "cpu" 33 | model, preprocess = clip.load("ViT-B/32", device=device) 34 | 35 | image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device) 36 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) 37 | 38 | with torch.no_grad(): 39 | image_features = model.encode_image(image) 40 | text_features = model.encode_text(text) 41 | 42 | logits_per_image, logits_per_text = model(image, text) 43 | probs = logits_per_image.softmax(dim=-1).cpu().numpy() 44 | 45 | print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]] 46 | ``` 47 | 48 | 49 | ## API 50 | 51 | The CLIP module `clip` provides the following methods: 52 | 53 | #### `clip.available_models()` 54 | 55 | Returns the names of the available CLIP models. 56 | 57 | #### `clip.load(name, device=..., jit=False)` 58 | 59 | Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The `name` argument can also be a path to a local checkpoint. 60 | 61 | The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When `jit` is `False`, a non-JIT version of the model will be loaded. 62 | 63 | #### `clip.tokenize(text: Union[str, List[str]], context_length=77)` 64 | 65 | Returns a LongTensor containing tokenized sequences of given text input(s). This can be used as the input to the model 66 | 67 | --- 68 | 69 | The model returned by `clip.load()` supports the following methods: 70 | 71 | #### `model.encode_image(image: Tensor)` 72 | 73 | Given a batch of images, returns the image features encoded by the vision portion of the CLIP model. 74 | 75 | #### `model.encode_text(text: Tensor)` 76 | 77 | Given a batch of text tokens, returns the text features encoded by the language portion of the CLIP model. 78 | 79 | #### `model(image: Tensor, text: Tensor)` 80 | 81 | Given a batch of images and a batch of text tokens, returns two Tensors, containing the logit scores corresponding to each image and text input. The values are cosine similarities between the corresponding image and text features, times 100. 82 | 83 | 84 | 85 | ## More Examples 86 | 87 | ### Zero-Shot Prediction 88 | 89 | The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset. 
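As a bridge between the API summary above and the zero-shot example below, here is a minimal sketch that re-derives the scores returned by `model(image, text)` from the separately encoded features. It assumes the `model`, `image`, and `text` variables from the usage example above; the factor of 100 reflects the "cosine similarities ... times 100" description of the combined forward pass.

```python
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    # cosine similarity of the L2-normalized features, scaled by 100
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    manual_logits = 100.0 * image_features @ text_features.T

    # should closely match the logits from the combined forward pass
    logits_per_image, logits_per_text = model(image, text)
```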
90 | 91 | ```python 92 | import os 93 | import clip 94 | import torch 95 | from torchvision.datasets import CIFAR100 96 | 97 | # Load the model 98 | device = "cuda" if torch.cuda.is_available() else "cpu" 99 | model, preprocess = clip.load('ViT-B/32', device) 100 | 101 | # Download the dataset 102 | cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False) 103 | 104 | # Prepare the inputs 105 | image, class_id = cifar100[3637] 106 | image_input = preprocess(image).unsqueeze(0).to(device) 107 | text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device) 108 | 109 | # Calculate features 110 | with torch.no_grad(): 111 | image_features = model.encode_image(image_input) 112 | text_features = model.encode_text(text_inputs) 113 | 114 | # Pick the top 5 most similar labels for the image 115 | image_features /= image_features.norm(dim=-1, keepdim=True) 116 | text_features /= text_features.norm(dim=-1, keepdim=True) 117 | similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1) 118 | values, indices = similarity[0].topk(5) 119 | 120 | # Print the result 121 | print("\nTop predictions:\n") 122 | for value, index in zip(values, indices): 123 | print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%") 124 | ``` 125 | 126 | The output will look like the following (the exact numbers may be slightly different depending on the compute device): 127 | 128 | ``` 129 | Top predictions: 130 | 131 | snake: 65.31% 132 | turtle: 12.29% 133 | sweet_pepper: 3.83% 134 | lizard: 1.88% 135 | crocodile: 1.75% 136 | ``` 137 | 138 | Note that this example uses the `encode_image()` and `encode_text()` methods that return the encoded features of given inputs. 139 | 140 | 141 | ### Linear-probe evaluation 142 | 143 | The example below uses [scikit-learn](https://scikit-learn.org/) to perform logistic regression on image features. 144 | 145 | ```python 146 | import os 147 | import clip 148 | import torch 149 | 150 | import numpy as np 151 | from sklearn.linear_model import LogisticRegression 152 | from torch.utils.data import DataLoader 153 | from torchvision.datasets import CIFAR100 154 | from tqdm import tqdm 155 | 156 | # Load the model 157 | device = "cuda" if torch.cuda.is_available() else "cpu" 158 | model, preprocess = clip.load('ViT-B/32', device) 159 | 160 | # Load the dataset 161 | root = os.path.expanduser("~/.cache") 162 | train = CIFAR100(root, download=True, train=True, transform=preprocess) 163 | test = CIFAR100(root, download=True, train=False, transform=preprocess) 164 | 165 | 166 | def get_features(dataset): 167 | all_features = [] 168 | all_labels = [] 169 | 170 | with torch.no_grad(): 171 | for images, labels in tqdm(DataLoader(dataset, batch_size=100)): 172 | features = model.encode_image(images.to(device)) 173 | 174 | all_features.append(features) 175 | all_labels.append(labels) 176 | 177 | return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy() 178 | 179 | # Calculate the image features 180 | train_features, train_labels = get_features(train) 181 | test_features, test_labels = get_features(test) 182 | 183 | # Perform logistic regression 184 | classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) 185 | classifier.fit(train_features, train_labels) 186 | 187 | # Evaluate using the logistic regression classifier 188 | predictions = classifier.predict(test_features) 189 | accuracy = np.mean((test_labels == predictions).astype(float)) * 100. 
190 | print(f"Accuracy = {accuracy:.3f}") 191 | ``` 192 | 193 | Note that the `C` value should be determined via a hyperparameter sweep using a validation split. 194 | 195 | 196 | ## See Also 197 | 198 | * [OpenCLIP](https://github.com/mlfoundations/open_clip): includes larger and independently trained CLIP models up to ViT-G/14 199 | * [Hugging Face implementation of CLIP](https://huggingface.co/docs/transformers/model_doc/clip): for easier integration with the HF ecosystem 200 | -------------------------------------------------------------------------------- /models/better/normalization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Normalization layers.""" 17 | import torch.nn as nn 18 | import torch 19 | import functools 20 | 21 | 22 | def get_normalization(config, conditional=False): 23 | """Obtain normalization modules from the config file.""" 24 | norm = config.model.normalization 25 | if conditional: 26 | if norm == 'InstanceNorm++': 27 | return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes) 28 | else: 29 | raise NotImplementedError(f'{norm} not implemented yet.') 30 | else: 31 | if norm == 'InstanceNorm': 32 | return nn.InstanceNorm2d 33 | elif norm == 'InstanceNorm++': 34 | return InstanceNorm2dPlus 35 | elif norm == 'VarianceNorm': 36 | return VarianceNorm2d 37 | elif norm == 'GroupNorm': 38 | return nn.GroupNorm 39 | else: 40 | raise ValueError('Unknown normalization: %s' % norm) 41 | 42 | 43 | class ConditionalBatchNorm2d(nn.Module): 44 | def __init__(self, num_features, num_classes, bias=True): 45 | super().__init__() 46 | self.num_features = num_features 47 | self.bias = bias 48 | self.bn = nn.BatchNorm2d(num_features, affine=False) 49 | if self.bias: 50 | self.embed = nn.Embedding(num_classes, num_features * 2) 51 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 52 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 53 | else: 54 | self.embed = nn.Embedding(num_classes, num_features) 55 | self.embed.weight.data.uniform_() 56 | 57 | def forward(self, x, y): 58 | out = self.bn(x) 59 | if self.bias: 60 | gamma, beta = self.embed(y).chunk(2, dim=1) 61 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) 62 | else: 63 | gamma = self.embed(y) 64 | out = gamma.view(-1, self.num_features, 1, 1) * out 65 | return out 66 | 67 | 68 | class ConditionalInstanceNorm2d(nn.Module): 69 | def __init__(self, num_features, num_classes, bias=True): 70 | super().__init__() 71 | self.num_features = num_features 72 | self.bias = bias 73 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 74 | if bias: 75 | self.embed = nn.Embedding(num_classes, num_features * 2) 76 | self.embed.weight.data[:, 
:num_features].uniform_() # Initialise scale at N(1, 0.02) 77 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 78 | else: 79 | self.embed = nn.Embedding(num_classes, num_features) 80 | self.embed.weight.data.uniform_() 81 | 82 | def forward(self, x, y): 83 | h = self.instance_norm(x) 84 | if self.bias: 85 | gamma, beta = self.embed(y).chunk(2, dim=-1) 86 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 87 | else: 88 | gamma = self.embed(y) 89 | out = gamma.view(-1, self.num_features, 1, 1) * h 90 | return out 91 | 92 | 93 | class ConditionalVarianceNorm2d(nn.Module): 94 | def __init__(self, num_features, num_classes, bias=False): 95 | super().__init__() 96 | self.num_features = num_features 97 | self.bias = bias 98 | self.embed = nn.Embedding(num_classes, num_features) 99 | self.embed.weight.data.normal_(1, 0.02) 100 | 101 | def forward(self, x, y): 102 | vars = torch.var(x, dim=(2, 3), keepdim=True) 103 | h = x / torch.sqrt(vars + 1e-5) 104 | 105 | gamma = self.embed(y) 106 | out = gamma.view(-1, self.num_features, 1, 1) * h 107 | return out 108 | 109 | 110 | class VarianceNorm2d(nn.Module): 111 | def __init__(self, num_features, bias=False): 112 | super().__init__() 113 | self.num_features = num_features 114 | self.bias = bias 115 | self.alpha = nn.Parameter(torch.zeros(num_features)) 116 | self.alpha.data.normal_(1, 0.02) 117 | 118 | def forward(self, x): 119 | vars = torch.var(x, dim=(2, 3), keepdim=True) 120 | h = x / torch.sqrt(vars + 1e-5) 121 | 122 | out = self.alpha.view(-1, self.num_features, 1, 1) * h 123 | return out 124 | 125 | 126 | class ConditionalNoneNorm2d(nn.Module): 127 | def __init__(self, num_features, num_classes, bias=True): 128 | super().__init__() 129 | self.num_features = num_features 130 | self.bias = bias 131 | if bias: 132 | self.embed = nn.Embedding(num_classes, num_features * 2) 133 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 134 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 135 | else: 136 | self.embed = nn.Embedding(num_classes, num_features) 137 | self.embed.weight.data.uniform_() 138 | 139 | def forward(self, x, y): 140 | if self.bias: 141 | gamma, beta = self.embed(y).chunk(2, dim=-1) 142 | out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1) 143 | else: 144 | gamma = self.embed(y) 145 | out = gamma.view(-1, self.num_features, 1, 1) * x 146 | return out 147 | 148 | 149 | class NoneNorm2d(nn.Module): 150 | def __init__(self, num_features, bias=True): 151 | super().__init__() 152 | 153 | def forward(self, x): 154 | return x 155 | 156 | 157 | class InstanceNorm2dPlus(nn.Module): 158 | def __init__(self, num_features, bias=True): 159 | super().__init__() 160 | self.num_features = num_features 161 | self.bias = bias 162 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 163 | self.alpha = nn.Parameter(torch.zeros(num_features)) 164 | self.gamma = nn.Parameter(torch.zeros(num_features)) 165 | self.alpha.data.normal_(1, 0.02) 166 | self.gamma.data.normal_(1, 0.02) 167 | if bias: 168 | self.beta = nn.Parameter(torch.zeros(num_features)) 169 | 170 | def forward(self, x): 171 | means = torch.mean(x, dim=(2, 3)) 172 | m = torch.mean(means, dim=-1, keepdim=True) 173 | v = torch.var(means, dim=-1, keepdim=True) 174 | means = (means - m) / (torch.sqrt(v + 1e-5)) 175 | h = self.instance_norm(x) 176 | 177 | if self.bias: 178 | h = h + means[..., 
None, None] * self.alpha[..., None, None] 179 | out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1) 180 | else: 181 | h = h + means[..., None, None] * self.alpha[..., None, None] 182 | out = self.gamma.view(-1, self.num_features, 1, 1) * h 183 | return out 184 | 185 | 186 | class ConditionalInstanceNorm2dPlus(nn.Module): 187 | def __init__(self, num_features, num_classes, bias=True): 188 | super().__init__() 189 | self.num_features = num_features 190 | self.bias = bias 191 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 192 | if bias: 193 | self.embed = nn.Embedding(num_classes, num_features * 3) 194 | self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02) 195 | self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0 196 | else: 197 | self.embed = nn.Embedding(num_classes, 2 * num_features) 198 | self.embed.weight.data.normal_(1, 0.02) 199 | 200 | def forward(self, x, y): 201 | means = torch.mean(x, dim=(2, 3)) 202 | m = torch.mean(means, dim=-1, keepdim=True) 203 | v = torch.var(means, dim=-1, keepdim=True) 204 | means = (means - m) / (torch.sqrt(v + 1e-5)) 205 | h = self.instance_norm(x) 206 | 207 | if self.bias: 208 | gamma, alpha, beta = self.embed(y).chunk(3, dim=-1) 209 | h = h + means[..., None, None] * alpha[..., None, None] 210 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 211 | else: 212 | gamma, alpha = self.embed(y).chunk(2, dim=-1) 213 | h = h + means[..., None, None] * alpha[..., None, None] 214 | out = gamma.view(-1, self.num_features, 1, 1) * h 215 | return out 216 | -------------------------------------------------------------------------------- /models/better/up_or_down_sampling.py: -------------------------------------------------------------------------------- 1 | """Layers used for up-sampling or down-sampling images. 2 | 3 | Many functions are ported from https://github.com/NVlabs/stylegan2. 
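Most of the helpers below are wrappers around the fused `upfirdn2d` primitive
(upsample, apply an FIR filter, then downsample).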
4 | """ 5 | 6 | import torch.nn as nn 7 | import torch 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | from .op.upfirdn2d import upfirdn2d 12 | 13 | 14 | # Function ported from StyleGAN2 15 | def get_weight(module, 16 | shape, 17 | weight_var='weight', 18 | kernel_init=None): 19 | """Get/create weight tensor for a convolution or fully-connected layer.""" 20 | 21 | return module.param(weight_var, kernel_init, shape) 22 | 23 | 24 | class Conv2d(nn.Module): 25 | """Conv2d layer with optimal upsampling and downsampling (StyleGAN2).""" 26 | 27 | def __init__(self, in_ch, out_ch, kernel, up=False, down=False, 28 | resample_kernel=(1, 3, 3, 1), 29 | use_bias=True, 30 | kernel_init=None): 31 | super().__init__() 32 | assert not (up and down) 33 | assert kernel >= 1 and kernel % 2 == 1 34 | self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel)) 35 | if kernel_init is not None: 36 | self.weight.data = kernel_init(self.weight.data.shape) 37 | if use_bias: 38 | self.bias = nn.Parameter(torch.zeros(out_ch)) 39 | 40 | self.up = up 41 | self.down = down 42 | self.resample_kernel = resample_kernel 43 | self.kernel = kernel 44 | self.use_bias = use_bias 45 | 46 | def forward(self, x): 47 | if self.up: 48 | x = upsample_conv_2d(x, self.weight, k=self.resample_kernel) 49 | elif self.down: 50 | x = conv_downsample_2d(x, self.weight, k=self.resample_kernel) 51 | else: 52 | x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2) 53 | 54 | if self.use_bias: 55 | x = x + self.bias.reshape(1, -1, 1, 1) 56 | 57 | return x 58 | 59 | 60 | def naive_upsample_2d(x, factor=2): 61 | _N, C, H, W = x.shape 62 | x = torch.reshape(x, (-1, C, H, 1, W, 1)) 63 | x = x.repeat(1, 1, 1, factor, 1, factor) 64 | return torch.reshape(x, (-1, C, H * factor, W * factor)) 65 | 66 | 67 | def naive_downsample_2d(x, factor=2): 68 | _N, C, H, W = x.shape 69 | x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor)) 70 | return torch.mean(x, dim=(3, 5)) 71 | 72 | 73 | def upsample_conv_2d(x, w, k=None, factor=2, gain=1): 74 | """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. 75 | 76 | Padding is performed only once at the beginning, not between the 77 | operations. 78 | The fused op is considerably more efficient than performing the same 79 | calculation 80 | using standard TensorFlow ops. It supports gradients of arbitrary order. 81 | Args: 82 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 83 | C]`. 84 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 85 | outChannels]`. Grouped convolution can be performed by `inChannels = 86 | x.shape[0] // numGroups`. 87 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 88 | (separable). The default is `[1] * factor`, which corresponds to 89 | nearest-neighbor upsampling. 90 | factor: Integer upsampling factor (default: 2). 91 | gain: Scaling factor for signal magnitude (default: 1.0). 92 | 93 | Returns: 94 | Tensor of the shape `[N, C, H * factor, W * factor]` or 95 | `[N, H * factor, W * factor, C]`, and same datatype as `x`. 96 | """ 97 | 98 | assert isinstance(factor, int) and factor >= 1 99 | 100 | # Check weight shape. 101 | assert len(w.shape) == 4 102 | convH = w.shape[2] 103 | convW = w.shape[3] 104 | inC = w.shape[1] 105 | outC = w.shape[0] 106 | 107 | assert convW == convH 108 | 109 | # Setup filter kernel. 
110 | if k is None: 111 | k = [1] * factor 112 | k = _setup_kernel(k) * (gain * (factor ** 2)) 113 | p = (k.shape[0] - factor) - (convW - 1) 114 | 115 | stride = (factor, factor) 116 | 117 | # Determine data dimensions. 118 | stride = [1, 1, factor, factor] 119 | output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW) 120 | output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, 121 | output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW) 122 | assert output_padding[0] >= 0 and output_padding[1] >= 0 123 | num_groups = _shape(x, 1) // inC 124 | 125 | # Transpose weights. 126 | w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) 127 | w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) 128 | w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) 129 | 130 | x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) 131 | ## Original TF code. 132 | # x = tf.nn.conv2d_transpose( 133 | # x, 134 | # w, 135 | # output_shape=output_shape, 136 | # strides=stride, 137 | # padding='VALID', 138 | # data_format=data_format) 139 | ## JAX equivalent 140 | 141 | return upfirdn2d(x, torch.tensor(k, device=x.device), 142 | pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) 143 | 144 | 145 | def conv_downsample_2d(x, w, k=None, factor=2, gain=1): 146 | """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. 147 | 148 | Padding is performed only once at the beginning, not between the operations. 149 | The fused op is considerably more efficient than performing the same 150 | calculation 151 | using standard TensorFlow ops. It supports gradients of arbitrary order. 152 | Args: 153 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 154 | C]`. 155 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 156 | outChannels]`. Grouped convolution can be performed by `inChannels = 157 | x.shape[0] // numGroups`. 158 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 159 | (separable). The default is `[1] * factor`, which corresponds to 160 | average pooling. 161 | factor: Integer downsampling factor (default: 2). 162 | gain: Scaling factor for signal magnitude (default: 1.0). 163 | 164 | Returns: 165 | Tensor of the shape `[N, C, H // factor, W // factor]` or 166 | `[N, H // factor, W // factor, C]`, and same datatype as `x`. 167 | """ 168 | 169 | assert isinstance(factor, int) and factor >= 1 170 | _outC, _inC, convH, convW = w.shape 171 | assert convW == convH 172 | if k is None: 173 | k = [1] * factor 174 | k = _setup_kernel(k) * gain 175 | p = (k.shape[0] - factor) + (convW - 1) 176 | s = [factor, factor] 177 | x = upfirdn2d(x, torch.tensor(k, device=x.device), 178 | pad=((p + 1) // 2, p // 2)) 179 | return F.conv2d(x, w, stride=s, padding=0) 180 | 181 | 182 | def _setup_kernel(k): 183 | k = np.asarray(k, dtype=np.float32) 184 | if k.ndim == 1: 185 | k = np.outer(k, k) 186 | k /= np.sum(k) 187 | assert k.ndim == 2 188 | assert k.shape[0] == k.shape[1] 189 | return k 190 | 191 | 192 | def _shape(x, dim): 193 | return x.shape[dim] 194 | 195 | 196 | def upsample_2d(x, k=None, factor=2, gain=1): 197 | r"""Upsample a batch of 2D images with the given filter. 198 | 199 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 200 | and upsamples each image with the given filter. The filter is normalized so 201 | that 202 | if the input pixels are constant, they will be scaled by the specified 203 | `gain`. 
204 | Pixels outside the image are assumed to be zero, and the filter is padded 205 | with 206 | zeros so that its shape is a multiple of the upsampling factor. 207 | Args: 208 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 209 | C]`. 210 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 211 | (separable). The default is `[1] * factor`, which corresponds to 212 | nearest-neighbor upsampling. 213 | factor: Integer upsampling factor (default: 2). 214 | gain: Scaling factor for signal magnitude (default: 1.0). 215 | 216 | Returns: 217 | Tensor of the shape `[N, C, H * factor, W * factor]` 218 | """ 219 | assert isinstance(factor, int) and factor >= 1 220 | if k is None: 221 | k = [1] * factor 222 | k = _setup_kernel(k) * (gain * (factor ** 2)) 223 | p = k.shape[0] - factor 224 | return upfirdn2d(x, torch.tensor(k, device=x.device), 225 | up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) 226 | 227 | 228 | def downsample_2d(x, k=None, factor=2, gain=1): 229 | r"""Downsample a batch of 2D images with the given filter. 230 | 231 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 232 | and downsamples each image with the given filter. The filter is normalized 233 | so that 234 | if the input pixels are constant, they will be scaled by the specified 235 | `gain`. 236 | Pixels outside the image are assumed to be zero, and the filter is padded 237 | with 238 | zeros so that its shape is a multiple of the downsampling factor. 239 | Args: 240 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 241 | C]`. 242 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 243 | (separable). The default is `[1] * factor`, which corresponds to 244 | average pooling. 245 | factor: Integer downsampling factor (default: 2). 246 | gain: Scaling factor for signal magnitude (default: 1.0). 
247 | 248 | Returns: 249 | Tensor of the shape `[N, C, H // factor, W // factor]` 250 | """ 251 | 252 | assert isinstance(factor, int) and factor >= 1 253 | if k is None: 254 | k = [1] * factor 255 | k = _setup_kernel(k) * gain 256 | p = k.shape[0] - factor 257 | return upfirdn2d(x, torch.tensor(k, device=x.device), 258 | down=factor, pad=((p + 1) // 2, p // 2)) 259 | -------------------------------------------------------------------------------- /CLIP-main/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with 
tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 | jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | with open(model_path, 'rb') as opened_file: 127 | try: 128 | # loading JIT archive 129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 130 | state_dict = None 131 | except RuntimeError: 132 | # loading saved state dict 133 | if jit: 134 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 135 | jit = False 136 | state_dict = torch.load(opened_file, map_location="cpu") 137 | 138 | if not jit: 139 | model = build_model(state_dict or model.state_dict()).to(device) 140 | if str(device) == "cpu": 141 | model.float() 142 | return model, _transform(model.visual.input_resolution) 143 | 144 | # patch the device names 145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 147 | 148 | def patch_device(module): 149 | try: 150 | graphs = [module.graph] if hasattr(module, "graph") else [] 151 | except RuntimeError: 152 | graphs = [] 153 | 154 | if hasattr(module, "forward1"): 155 | graphs.append(module.forward1.graph) 156 | 157 | for graph in graphs: 158 | for node in graph.findAllNodes("prim::Constant"): 159 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 160 | node.copyAttributes(device_node) 161 | 162 | model.apply(patch_device) 163 | patch_device(model.encode_image) 164 | patch_device(model.encode_text) 165 | 166 | # patch dtype to float32 on CPU 167 | if str(device) == "cpu": 168 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 169 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 170 | float_node = float_input.node() 171 | 172 | def patch_float(module): 173 | try: 174 | graphs = [module.graph] if hasattr(module, "graph") else [] 175 | except RuntimeError: 176 | graphs = [] 177 | 178 | if hasattr(module, "forward1"): 179 | graphs.append(module.forward1.graph) 180 | 181 | for graph in graphs: 182 | for node in graph.findAllNodes("aten::to"): 183 | inputs = list(node.inputs()) 184 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 185 | if inputs[i].node()["value"] == 5: 186 | inputs[i].node().copyAttributes(float_node) 187 | 188 | model.apply(patch_float) 189 | patch_float(model.encode_image) 190 | patch_float(model.encode_text) 191 | 192 | model.float() 193 | 194 | return model, _transform(model.input_resolution.item()) 195 | 196 | 197 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 198 | """ 199 | Returns the tokenized representation of given input string(s) 200 | 201 | Parameters 202 | ---------- 203 | texts : Union[str, List[str]] 204 | An input string or a list of input strings to tokenize 205 | 206 | context_length : int 207 | The context length to use; all CLIP models use 77 as the context length 208 | 209 | truncate: bool 210 | Whether to truncate the text in case its encoding is longer than the context length 211 | 212 | Returns 213 | ------- 214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 215 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
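    Examples
    --------
    >>> clip.tokenize(["a diagram", "a dog"]).shape   # illustrative
    torch.Size([2, 77])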
216 | """ 217 | if isinstance(texts, str): 218 | texts = [texts] 219 | 220 | sot_token = _tokenizer.encoder["<|startoftext|>"] 221 | eot_token = _tokenizer.encoder["<|endoftext|>"] 222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 223 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 225 | else: 226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 227 | 228 | for i, tokens in enumerate(all_tokens): 229 | if len(tokens) > context_length: 230 | if truncate: 231 | tokens = tokens[:context_length] 232 | tokens[-1] = eot_token 233 | else: 234 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 235 | result[i, :len(tokens)] = torch.tensor(tokens) 236 | 237 | return result 238 | -------------------------------------------------------------------------------- /evaluation/fid_score_OLD.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 3 | 4 | The FID metric calculates the distance between two distributions of images. 5 | Typically, we have summary statistics (mean & covariance matrix) of one 6 | of these distributions, while the 2nd distribution is given by a GAN. 7 | 8 | When run as a stand-alone program, it compares the distribution of 9 | images that are stored as PNG/JPEG at a specified location with a 10 | distribution given by summary statistics (in pickle format). 11 | 12 | The FID is calculated by assuming that X_1 and X_2 are the activations of 13 | the pool_3 layer of the inception net for generated samples and real world 14 | samples respectively. 15 | 16 | See --help to see further details. 17 | 18 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 19 | of Tensorflow 20 | 21 | Copyright 2018 Institute of Bioinformatics, JKU Linz 22 | 23 | Licensed under the Apache License, Version 2.0 (the "License"); 24 | you may not use this file except in compliance with the License. 25 | You may obtain a copy of the License at 26 | 27 | http://www.apache.org/licenses/LICENSE-2.0 28 | 29 | Unless required by applicable law or agreed to in writing, software 30 | distributed under the License is distributed on an "AS IS" BASIS, 31 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 32 | See the License for the specific language governing permissions and 33 | limitations under the License. 34 | """ 35 | import os 36 | import pathlib 37 | 38 | import numpy as np 39 | import torch 40 | from scipy import linalg 41 | from torch.nn.functional import adaptive_avg_pool2d 42 | import os 43 | 44 | from PIL import Image 45 | 46 | try: 47 | from tqdm import tqdm 48 | except ImportError: 49 | # If not tqdm is not available, provide a mock version of it 50 | def tqdm(x): return x 51 | 52 | from .inception import InceptionV3 53 | 54 | def imread(filename): 55 | """ 56 | Loads an image file into a (height, width, 3) uint8 ndarray. 57 | """ 58 | return np.asarray(Image.open(filename), dtype=np.uint8)[..., :3] 59 | 60 | 61 | def get_activations(files, model, batch_size=50, dims=2048, 62 | cuda=False, verbose=False): 63 | """Calculates the activations of the pool_3 layer for all images. 
64 | 65 | Params: 66 | -- files : List of image files paths 67 | -- model : Instance of inception model 68 | -- batch_size : Batch size of images for the model to process at once. 69 | Make sure that the number of samples is a multiple of 70 | the batch size, otherwise some samples are ignored. This 71 | behavior is retained to match the original FID score 72 | implementation. 73 | -- dims : Dimensionality of features returned by Inception 74 | -- cuda : If set to True, use GPU 75 | -- verbose : If set to True and parameter out_step is given, the number 76 | of calculated batches is reported. 77 | Returns: 78 | -- A numpy array of dimension (num images, dims) that contains the 79 | activations of the given tensor when feeding inception with the 80 | query tensor. 81 | """ 82 | model.eval() 83 | 84 | if batch_size > len(files): 85 | print(('Warning: batch size is bigger than the data size. ' 86 | 'Setting batch size to data size')) 87 | batch_size = len(files) 88 | 89 | pred_arr = np.empty((len(files), dims)) 90 | 91 | for i in tqdm(range(0, len(files), batch_size)): 92 | if verbose: 93 | print('\rPropagating batch %d/%d' % (i + 1, n_batches), 94 | end='', flush=True) 95 | start = i 96 | end = i + batch_size 97 | 98 | images = np.array([imread(str(f)).astype(np.float32) 99 | for f in files[start:end]]) 100 | 101 | # Reshape to (n_images, 3, height, width) 102 | images = images.transpose((0, 3, 1, 2)) 103 | images /= 255 104 | 105 | batch = torch.from_numpy(images).type(torch.FloatTensor) 106 | if cuda: 107 | batch = batch.cuda() 108 | 109 | pred = model(batch)[0] 110 | 111 | # If model output is not scalar, apply global spatial average pooling. 112 | # This happens if you choose a dimensionality not equal 2048. 113 | if pred.size(2) != 1 or pred.size(3) != 1: 114 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 115 | 116 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(pred.size(0), -1) 117 | 118 | if verbose: 119 | print(' done') 120 | 121 | return pred_arr 122 | 123 | 124 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 125 | """Numpy implementation of the Frechet Distance. 126 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 127 | and X_2 ~ N(mu_2, C_2) is 128 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 129 | 130 | Stable version by Dougal J. Sutherland. 131 | 132 | Params: 133 | -- mu1 : Numpy array containing the activations of a layer of the 134 | inception net (like returned by the function 'get_predictions') 135 | for generated samples. 136 | -- mu2 : The sample mean over activations, precalculated on an 137 | representative data set. 138 | -- sigma1: The covariance matrix over activations for generated samples. 139 | -- sigma2: The covariance matrix over activations, precalculated on an 140 | representative data set. 141 | 142 | Returns: 143 | -- : The Frechet Distance. 
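    For intuition: with scalar inputs (1-D Gaussians), where sigma1 and sigma2
    are variances, the formula reduces to
    (mu1 - mu2)**2 + (sqrt(sigma1) - sqrt(sigma2))**2,
    i.e. it penalizes both a shift in mean and a mismatch in spread.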
144 | """ 145 | 146 | mu1 = np.atleast_1d(mu1) 147 | mu2 = np.atleast_1d(mu2) 148 | 149 | sigma1 = np.atleast_2d(sigma1) 150 | sigma2 = np.atleast_2d(sigma2) 151 | 152 | assert mu1.shape == mu2.shape, \ 153 | 'Training and test mean vectors have different lengths' 154 | assert sigma1.shape == sigma2.shape, \ 155 | 'Training and test covariances have different dimensions' 156 | 157 | diff = mu1 - mu2 158 | 159 | # Product might be almost singular 160 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 161 | if not np.isfinite(covmean).all(): 162 | msg = ('fid calculation produces singular product; ' 163 | 'adding %s to diagonal of cov estimates') % eps 164 | print(msg) 165 | offset = np.eye(sigma1.shape[0]) * eps 166 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 167 | 168 | # Numerical error might give slight imaginary component 169 | if np.iscomplexobj(covmean): 170 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 171 | m = np.max(np.abs(covmean.imag)) 172 | raise ValueError('Imaginary component {}'.format(m)) 173 | covmean = covmean.real 174 | 175 | tr_covmean = np.trace(covmean) 176 | 177 | return (diff.dot(diff) + np.trace(sigma1) + 178 | np.trace(sigma2) - 2 * tr_covmean) 179 | 180 | 181 | def calculate_activation_statistics(files, model, batch_size=50, 182 | dims=2048, cuda=False, verbose=False): 183 | """Calculation of the statistics used by the FID. 184 | Params: 185 | -- files : List of image files paths 186 | -- model : Instance of inception model 187 | -- batch_size : The images numpy array is split into batches with 188 | batch size batch_size. A reasonable batch size 189 | depends on the hardware. 190 | -- dims : Dimensionality of features returned by Inception 191 | -- cuda : If set to True, use GPU 192 | -- verbose : If set to True and parameter out_step is given, the 193 | number of calculated batches is reported. 194 | Returns: 195 | -- mu : The mean over samples of the activations of the pool_3 layer of 196 | the inception model. 197 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 198 | the inception model. 
199 | """ 200 | act = get_activations(files, model, batch_size, dims, cuda, verbose) 201 | mu = np.mean(act, axis=0) 202 | sigma = np.cov(act, rowvar=False) 203 | return mu, sigma 204 | 205 | 206 | def _compute_statistics_of_path(path, model, batch_size, dims, cuda): 207 | if path.endswith('.npz'): 208 | f = np.load(path) 209 | m, s = f['mu'][:], f['sigma'][:] 210 | f.close() 211 | else: 212 | path = pathlib.Path(path) 213 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 214 | m, s = calculate_activation_statistics(files, model, batch_size, 215 | dims, cuda) 216 | 217 | return m, s 218 | 219 | 220 | def calculate_fid_given_paths(paths, batch_size, cuda, dims): 221 | """Calculates the FID of two paths""" 222 | for p in paths: 223 | if not os.path.exists(p): 224 | raise RuntimeError('Invalid path: %s' % p) 225 | 226 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 227 | 228 | model = InceptionV3([block_idx]) 229 | if cuda: 230 | model.cuda() 231 | 232 | m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, 233 | dims, cuda) 234 | m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, 235 | dims, cuda) 236 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 237 | 238 | return fid_value 239 | 240 | 241 | def get_fid(path1, path2): 242 | fid_value = calculate_fid_given_paths([path1, path2], 243 | 50, 244 | True, 245 | 2048) 246 | return fid_value 247 | 248 | links = { 249 | 'CIFAR10': 'http://bioinf.jku.at/research/ttur/ttur_stats/fid_stats_cifar10_train.npz', 250 | 'LSUN': 'http://bioinf.jku.at/research/ttur/ttur_stats/fid_stats_lsun_train.npz' 251 | } 252 | 253 | def get_fid_stats_path(args, config, download=True): 254 | if config.data.dataset == 'CIFAR10': 255 | path = os.path.join(args.exp, 'datasets', 'cifar10_fid.npz') 256 | if not os.path.exists(path): 257 | if not download: 258 | raise FileNotFoundError("no statistics file founded") 259 | else: 260 | import urllib 261 | urllib.request.urlretrieve( 262 | links[config.data.dataset], path 263 | ) 264 | elif config.data.dataset == 'CELEBA': 265 | path = os.path.join(args.exp, 'datasets', 'celeba_test_fid_stats.npz') 266 | if not os.path.exists(path): 267 | raise FileNotFoundError('no statistics file founded') 268 | 269 | return path -------------------------------------------------------------------------------- /models/fvd/fvd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import os.path as osp 5 | import math 6 | import torch.nn.functional as F 7 | 8 | # try: 9 | # from torchvision.models.utils import load_state_dict_from_url 10 | # except ImportError: 11 | # from torch.utils.model_zoo import load_url as load_state_dict_from_url 12 | 13 | # i3D_WEIGHTS_URL = "https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI" 14 | 15 | # def load_i3d_pretrained(device=torch.device('cpu')): 16 | # from .pytorch_i3d import InceptionI3d 17 | # i3d = InceptionI3d(400, in_channels=3).to(device) 18 | # try: # can't access internet from compute canada, so need a local version 19 | # filepath = 'models/i3d_pretrained_400.pt' 20 | # i3d.load_state_dict(torch.load(filepath, map_location=device)) 21 | # except: 22 | # state_dict = load_state_dict_from_url(i3D_WEIGHTS_URL, progress=True, map_location=device) 23 | # i3d.load_state_dict(state_dict) 24 | # i3d = torch.nn.DataParallel(i3d) 25 | # i3d.eval() 26 | # return i3d 27 | 28 | 29 | # 
https://github.com/universome/fvd-comparison 30 | i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt" 31 | 32 | def load_i3d_pretrained(device=torch.device('cpu')): 33 | 34 | filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt') 35 | print(filepath) 36 | if not os.path.exists(filepath): 37 | os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}") 38 | # filepath = "i3d_torchscript.pt" 39 | i3d = torch.jit.load(filepath).eval().to(device) 40 | i3d = torch.nn.DataParallel(i3d) 41 | return i3d 42 | 43 | 44 | def get_feats(videos, detector, device, bs=10): 45 | # videos : torch.tensor BCTHW [0, 1] 46 | detector_kwargs = dict(rescale=False, resize=False, return_features=True) # Return raw features before the softmax layer. 47 | feats = np.empty((0, 400)) 48 | device = torch.device("cuda:0") if device is not torch.device("cpu") else device 49 | with torch.no_grad(): 50 | for i in range((len(videos)-1)//bs + 1): 51 | feats = np.vstack([feats, detector(torch.stack([preprocess_single(video) for video in videos[i*bs:(i+1)*bs]]).to(device), **detector_kwargs).detach().cpu().numpy()]) 52 | return feats 53 | 54 | 55 | def get_fvd_feats(videos, i3d, device, bs=10): 56 | # videos in [0, 1] as torch tensor BCTHW 57 | # videos = [preprocess_single(video) for video in videos] 58 | embeddings = get_feats(videos, i3d, device, bs) 59 | return embeddings 60 | 61 | # """ 62 | # Copy-pasted from Copy-pasted from https://github.com/NVlabs/stylegan2-ada-pytorch 63 | # """ 64 | 65 | # import ctypes 66 | # import fnmatch 67 | # import importlib 68 | # import inspect 69 | # import numpy as np 70 | # import os 71 | # import shutil 72 | # import sys 73 | # import types 74 | # import io 75 | # import pickle 76 | # import re 77 | # import requests 78 | # import html 79 | # import hashlib 80 | # import glob 81 | # import tempfile 82 | # import urllib 83 | # import urllib.request 84 | # import uuid 85 | 86 | # from distutils.util import strtobool 87 | # from typing import Any, List, Tuple, Union, Dict 88 | 89 | # def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any: 90 | # """Download the given URL and return a binary-mode file object to access the data.""" 91 | # assert num_attempts >= 1 92 | 93 | # # Doesn't look like an URL scheme so interpret it as a local filename. 94 | # if not re.match('^[a-z]+://', url): 95 | # return url if return_filename else open(url, "rb") 96 | 97 | # # Handle file URLs. This code handles unusual file:// patterns that 98 | # # arise on Windows: 99 | # # 100 | # # file:///c:/foo.txt 101 | # # 102 | # # which would translate to a local '/c:/foo.txt' filename that's 103 | # # invalid. Drop the forward slash for such pathnames. 104 | # # 105 | # # If you touch this code path, you should test it on both Linux and 106 | # # Windows. 107 | # # 108 | # # Some internet resources suggest using urllib.request.url2pathname() but 109 | # # but that converts forward slashes to backslashes and this causes 110 | # # its own set of problems. 111 | # if url.startswith('file://'): 112 | # filename = urllib.parse.urlparse(url).path 113 | # if re.match(r'^/[a-zA-Z]:', filename): 114 | # filename = filename[1:] 115 | # return filename if return_filename else open(filename, "rb") 116 | 117 | # url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 118 | 119 | # # Download. 
120 | # url_name = None 121 | # url_data = None 122 | # with requests.Session() as session: 123 | # if verbose: 124 | # print("Downloading %s ..." % url, end="", flush=True) 125 | # for attempts_left in reversed(range(num_attempts)): 126 | # try: 127 | # with session.get(url) as res: 128 | # res.raise_for_status() 129 | # if len(res.content) == 0: 130 | # raise IOError("No data received") 131 | 132 | # if len(res.content) < 8192: 133 | # content_str = res.content.decode("utf-8") 134 | # if "download_warning" in res.headers.get("Set-Cookie", ""): 135 | # links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 136 | # if len(links) == 1: 137 | # url = requests.compat.urljoin(url, links[0]) 138 | # raise IOError("Google Drive virus checker nag") 139 | # if "Google Drive - Quota exceeded" in content_str: 140 | # raise IOError("Google Drive download quota exceeded -- please try again later") 141 | 142 | # match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 143 | # url_name = match[1] if match else url 144 | # url_data = res.content 145 | # if verbose: 146 | # print(" done") 147 | # break 148 | # except KeyboardInterrupt: 149 | # raise 150 | # except: 151 | # if not attempts_left: 152 | # if verbose: 153 | # print(" failed") 154 | # raise 155 | # if verbose: 156 | # print(".", end="", flush=True) 157 | 158 | # # Return data as file object. 159 | # assert not return_filename 160 | # return io.BytesIO(url_data) 161 | 162 | 163 | def preprocess_single(video, resolution=224, sequence_length=None): 164 | # video: CTHW, [0, 1] 165 | c, t, h, w = video.shape 166 | 167 | # temporal crop 168 | if sequence_length is not None: 169 | assert sequence_length <= t 170 | video = video[:, :sequence_length] 171 | 172 | # scale shorter side to resolution 173 | scale = resolution / min(h, w) 174 | if h < w: 175 | target_size = (resolution, math.ceil(w * scale)) 176 | else: 177 | target_size = (math.ceil(h * scale), resolution) 178 | video = F.interpolate(video, size=target_size, mode='bilinear', align_corners=False) 179 | 180 | # center crop 181 | c, t, h, w = video.shape 182 | w_start = (w - resolution) // 2 183 | h_start = (h - resolution) // 2 184 | video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution] 185 | 186 | # [0, 1] -> [-1, 1] 187 | video = (video - 0.5) * 2 188 | 189 | return video.contiguous() 190 | 191 | 192 | def get_logits(i3d, videos, device): 193 | #assert videos.shape[0] % 2 == 0 194 | logits = torch.empty(0, 400) 195 | with torch.no_grad(): 196 | for i in range(len(videos)): 197 | # logits.append(i3d(preprocess_single(videos[i]).unsqueeze(0).to(device)).detach().cpu()) 198 | logits = torch.vstack([logits, i3d(preprocess_single(videos[i]).unsqueeze(0).to(device)).detach().cpu()]) 199 | # logits = torch.cat(logits, dim=0) 200 | return logits 201 | 202 | 203 | def get_fvd_logits(videos, i3d, device): 204 | # videos in [0, 1] as torch tensor BCTHW 205 | # videos = [preprocess_single(video) for video in videos] 206 | embeddings = get_logits(i3d, videos, device) 207 | return embeddings 208 | 209 | 210 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161 211 | def _symmetric_matrix_square_root(mat, eps=1e-10): 212 | u, s, v = torch.linalg.svd(mat) 213 | si = torch.where(s < eps, s, torch.sqrt(s)) 214 | return torch.matmul(torch.matmul(u, torch.diag(si)), v.t()) 215 | 216 | 217 | # 
https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400 218 | def trace_sqrt_product(sigma, sigma_v): 219 | sqrt_sigma = _symmetric_matrix_square_root(sigma) 220 | sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma)) 221 | return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) 222 | 223 | 224 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 225 | def cov(m, rowvar=False): 226 | '''Estimate a covariance matrix given data. 227 | 228 | Covariance indicates the level to which two variables vary together. 229 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 230 | then the covariance matrix element `C_{ij}` is the covariance of 231 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 232 | 233 | Args: 234 | m: A 1-D or 2-D array containing multiple variables and observations. 235 | Each row of `m` represents a variable, and each column a single 236 | observation of all those variables. 237 | rowvar: If `rowvar` is True, then each row represents a 238 | variable, with observations in the columns. Otherwise, the 239 | relationship is transposed: each column represents a variable, 240 | while the rows contain observations. 241 | 242 | Returns: 243 | The covariance matrix of the variables. 244 | ''' 245 | if m.dim() > 2: 246 | raise ValueError('m has more than 2 dimensions') 247 | if m.dim() < 2: 248 | m = m.view(1, -1) 249 | if not rowvar and m.size(0) != 1: 250 | m = m.t() 251 | 252 | fact = 1.0 / (m.size(1) - 1) # unbiased estimate 253 | m -= torch.mean(m, dim=1, keepdim=True) 254 | mt = m.t() # if complex: mt = m.t().conj() 255 | return fact * m.matmul(mt).squeeze() 256 | 257 | 258 | # def frechet_distance(x1, x2): 259 | # x1 = x1.flatten(start_dim=1) 260 | # x2 = x2.flatten(start_dim=1) 261 | # m, m_w = x1.mean(dim=0), x2.mean(dim=0) 262 | # sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False) 263 | # sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) 264 | # trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component 265 | # mean = torch.sum((m - m_w) ** 2) 266 | # fd = trace + mean 267 | # return fd 268 | 269 | 270 | """ 271 | Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py 272 | """ 273 | from typing import Tuple 274 | from scipy.linalg import sqrtm 275 | import numpy as np 276 | 277 | 278 | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 279 | mu = feats.mean(axis=0) # [d] 280 | sigma = np.cov(feats, rowvar=False) # [d, d] 281 | return mu, sigma 282 | 283 | 284 | def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float: 285 | mu_gen, sigma_gen = compute_stats(feats_fake) 286 | mu_real, sigma_real = compute_stats(feats_real) 287 | m = np.square(mu_gen - mu_real).sum() 288 | s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 289 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 290 | return float(fid) 291 | -------------------------------------------------------------------------------- /models/dist_model.py: -------------------------------------------------------------------------------- 1 | # Taken from https://github.com/psh01087/Vid-ODE/blob/main/eval_models/dist_model.py 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | import os 9 | from collections import 
OrderedDict 10 | from torch.autograd import Variable 11 | import itertools 12 | from .base_model import BaseModel 13 | from scipy.ndimage import zoom 14 | import fractions 15 | import functools 16 | import skimage.transform 17 | from tqdm import tqdm 18 | 19 | from . import networks_basic as networks 20 | from . import eval_models as util 21 | 22 | class DistModel(BaseModel): 23 | def name(self): 24 | return self.model_name 25 | 26 | def initialize(self, device, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, printNet=False, spatial=False, 27 | is_train=False, lr=.0001, beta1=0.5, version='0.1'): 28 | ''' 29 | INPUTS 30 | model - ['net-lin'] for linearly calibrated network 31 | ['net'] for off-the-shelf network 32 | ['L2'] for L2 distance in Lab colorspace 33 | ['SSIM'] for ssim in RGB colorspace 34 | net - ['squeeze','alex','vgg'] 35 | model_path - if None, will look in weights/[NET_NAME].pth 36 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 37 | use_gpu - bool - whether or not to use a GPU 38 | printNet - bool - whether or not to print network architecture out 39 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 40 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). 41 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. 42 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 43 | is_train - bool - [True] for training mode 44 | lr - float - initial learning rate 45 | beta1 - float - initial momentum term for adam 46 | version - 0.1 for latest, 0.0 was original (with a bug) 47 | gpu_ids - int array - [0] by default, gpus to use 48 | ''' 49 | BaseModel.initialize(self, device) 50 | 51 | self.model = model 52 | self.net = net 53 | self.is_train = is_train 54 | self.spatial = spatial 55 | self.device = device 56 | self.model_name = '%s [%s]'%(model,net) 57 | use_gpu = True if torch.cuda.is_available() else False 58 | self.use_gpu = use_gpu 59 | 60 | if(self.model == 'net-lin'): # pretrained net + linear layer 61 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 62 | use_dropout=True, spatial=spatial, version=version, lpips=True) 63 | kw = {} 64 | if not use_gpu: 65 | kw['map_location'] = 'cpu' 66 | if(model_path is None): 67 | import inspect 68 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) 69 | 70 | if(not is_train): 71 | print('Loading model from: %s'%model_path) 72 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 73 | 74 | elif(self.model=='net'): # pretrained network 75 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 76 | elif(self.model in ['L2','l2']): 77 | self.net = networks.L2(device=self.device,colorspace=colorspace) # not really a network, only for testing 78 | self.model_name = 'L2' 79 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 80 | self.net = networks.DSSIM(device=self.device,colorspace=colorspace) 81 | self.model_name = 'SSIM' 82 | else: 83 | raise ValueError("Model [%s] not recognized." 
% self.model) 84 | 85 | self.parameters = list(self.net.parameters()) 86 | 87 | if self.is_train: # training mode 88 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 89 | self.rankLoss = networks.BCERankingLoss() 90 | self.parameters += list(self.rankLoss.net.parameters()) 91 | self.lr = lr 92 | self.old_lr = lr 93 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 94 | else: # test mode 95 | self.net.eval() 96 | self.net.to(device=self.device) 97 | #self.net = torch.nn.DataParallel(self.net) 98 | if(self.is_train): 99 | self.rankLoss = self.rankLoss.to(device=device) # just put this on GPU0 100 | 101 | if(printNet): 102 | print('---------- Networks initialized -------------') 103 | networks.print_network(self.net) 104 | print('-----------------------------------------------') 105 | 106 | def forward(self, in0, in1, retPerLayer=False): 107 | ''' Function computes the distance between image patches in0 and in1 108 | INPUTS 109 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 110 | OUTPUT 111 | computed distances between in0 and in1 112 | ''' 113 | 114 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 115 | 116 | # ***** TRAINING FUNCTIONS ***** 117 | def optimize_parameters(self): 118 | self.forward_train() 119 | self.optimizer_net.zero_grad() 120 | self.backward_train() 121 | self.optimizer_net.step() 122 | self.clamp_weights() 123 | 124 | def clamp_weights(self): 125 | for module in self.net.modules(): 126 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 127 | module.weight.data = torch.clamp(module.weight.data,min=0) 128 | 129 | def set_input(self, data): 130 | self.input_ref = data['ref'] 131 | self.input_p0 = data['p0'] 132 | self.input_p1 = data['p1'] 133 | self.input_judge = data['judge'] 134 | 135 | if(self.use_gpu): 136 | self.input_ref = self.input_ref.to(device=self.device) 137 | self.input_p0 = self.input_p0.to(device=self.device) 138 | self.input_p1 = self.input_p1.to(device=self.device) 139 | self.input_judge = self.input_judge.to(device=self.device) 140 | 141 | self.var_ref = Variable(self.input_ref,requires_grad=True) 142 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 143 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 144 | 145 | def forward_train(self): # run forward pass 146 | # print(self.net.module.scaling_layer.shift) 147 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 148 | 149 | self.d0 = self.forward(self.var_ref, self.var_p0) 150 | self.d1 = self.forward(self.var_ref, self.var_p1) 151 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 152 | 153 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 154 | 155 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 
156 | 157 | return self.loss_total 158 | 159 | def backward_train(self): 160 | torch.mean(self.loss_total).backward() 161 | 162 | def compute_accuracy(self,d0,d1,judge): 163 | ''' d0, d1 are Variables, judge is a Tensor ''' 164 | d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten() 165 | judge_per = judge.cpu().numpy().flatten() 166 | return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per) 199 | def update_learning_rate(self, nepoch_decay): 200 | lrd = self.lr / nepoch_decay 201 | lr = self.old_lr - lrd 202 | 203 | for param_group in self.optimizer_net.param_groups: 204 | param_group['lr'] = lr 205 | 206 | print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr)) 207 | self.old_lr = lr 208 | 209 | def score_2afc_dataset(data_loader, func, name=''): 210 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 211 | distance function 'func' in dataset 'data_loader' 212 | INPUTS 213 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 214 | func - callable distance function - calling d=func(in0,in1) should take 2 215 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 216 | OUTPUTS 217 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 218 | [1] - dictionary with following elements 219 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 220 | gts - N array in [0,1], preferred patch selected by human evaluators 221 | (closer to "0" for left patch p0, "1" for right patch p1, 222 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 223 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 224 | CONSTS 225 | N - number of test triplets in data_loader 226 | ''' 227 | 228 | d0s = [] 229 | d1s = [] 230 | gts = [] 231 | 232 | for data in tqdm(data_loader.load_data(), desc=name): 233 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 234 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 235 | gts+=data['judge'].cpu().numpy().flatten().tolist() 236 | 237 | d0s = np.array(d0s) 238 | d1s = np.array(d1s) 239 | gts = np.array(gts) 240 | scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5 241 | 242 | return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores)) -------------------------------------------------------------------------------- /models/better/layers3d.py: -------------------------------------------------------------------------------- 58 | w = torch.einsum('bchw,bcij->bhwij', q.reshape(B * self.n_heads, C, H, W), k.reshape(B * self.n_heads, C, H, W)) * (int(C) ** (-0.5)) 59 | w = torch.reshape(w, (B * self.n_heads, H, W, H * W)) 60 | w = F.softmax(w, dim=-1) 61 | w = torch.reshape(w, (B * self.n_heads, H, W, H, W)) 62 | h = torch.einsum('bhwij,bcij->bchw', w, v.reshape(B * self.n_heads, C, H, W)) 63 | h = h.reshape(B, C * self.n_heads, H, W) 64 | h = self.NIN_3(h) 65 | if not self.skip_rescale: 66 | return x + h 67 | else: 68 | return (x + h) / np.sqrt(2.) 69 | 70 | class NIN1d(nn.Module): 71 | def __init__(self, in_dim, num_units, init_scale=0.1): 72 | super().__init__() 73 | self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True) 74 | self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) 75 | 76 | def forward(self, x): 77 | x = x.permute(0, 2, 1) 78 | y = contract_inner(x, self.W) + self.b 79 | return y.permute(0, 2, 1) 80 | 81 | class AttnBlockpp1d(nn.Module): 82 | """Channel-wise self-attention block. Modified from DDPM. 
in 1D""" 83 | 84 | def __init__(self, channels, skip_rescale=False, init_scale=0., n_heads=1, n_head_channels=-1): 85 | super().__init__() 86 | num_groups = min(channels // 4, 32) 87 | while (channels % num_groups != 0): 88 | num_groups -= 1 89 | self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, 90 | eps=1e-6) 91 | self.NIN_0 = NIN1d(channels, channels) 92 | self.NIN_1 = NIN1d(channels, channels) 93 | self.NIN_2 = NIN1d(channels, channels) 94 | self.NIN_3 = NIN1d(channels, channels, init_scale=init_scale) 95 | self.skip_rescale = skip_rescale 96 | if n_head_channels == -1: 97 | self.n_heads = n_heads 98 | else: 99 | if channels < n_head_channels: 100 | self.n_heads = 1 101 | else: 102 | assert channels % n_head_channels == 0 103 | self.n_heads = channels // n_head_channels 104 | 105 | def forward(self, x): 106 | B, C, T = x.shape 107 | h = self.GroupNorm_0(x) 108 | q = self.NIN_0(h) 109 | k = self.NIN_1(h) 110 | v = self.NIN_2(h) 111 | 112 | C = C // self.n_heads 113 | w = torch.einsum('bct,bci->bti', q.reshape(B * self.n_heads, C, T), k.reshape(B * self.n_heads, C, T)) * (int(C) ** (-0.5)) 114 | w = torch.reshape(w, (B * self.n_heads, T, T)) 115 | w = F.softmax(w, dim=-1) 116 | w = torch.reshape(w, (B * self.n_heads, T, T)) 117 | h = torch.einsum('bti,bci->bct', w, v.reshape(B * self.n_heads, C, T)) 118 | h = h.reshape(B, C * self.n_heads, T) 119 | h = self.NIN_3(h) 120 | if not self.skip_rescale: 121 | return x + h 122 | else: 123 | return (x + h) / np.sqrt(2.) 124 | 125 | 126 | class NIN3d(nn.Module): 127 | def __init__(self, in_dim, num_units, init_scale=0.1): 128 | super().__init__() 129 | self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True) 130 | self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) 131 | 132 | def forward(self, x): 133 | x = x.permute(0, 2, 3, 4, 1) # BxCxNxHxW to BxNxHxWxC 134 | y = contract_inner(x, self.W) + self.b 135 | return y.permute(0, 4, 1, 2, 3) 136 | 137 | 138 | class AttnBlockpp3d_old(nn.Module): # over time, height, width; crazy memory demands like 9GB for one att block!!! 
Not worth it 139 | """Channel-wise 3d self-attention block.""" 140 | 141 | def __init__(self, channels, skip_rescale=False, init_scale=0., n_heads=1, n_head_channels=-1, n_frames=1): 142 | super().__init__() 143 | self.N = n_frames 144 | self.channels = self.Cin = channels // n_frames 145 | num_groups = min(self.channels // 4, 32) 146 | while (self.channels % num_groups != 0): 147 | num_groups -= 1 148 | self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=self.channels, 149 | eps=1e-6) 150 | self.NIN_0 = NIN3d(self.channels, self.channels) 151 | self.NIN_1 = NIN3d(self.channels, self.channels) 152 | self.NIN_2 = NIN3d(self.channels, self.channels) 153 | self.NIN_3 = NIN3d(self.channels, self.channels, init_scale=init_scale) 154 | self.skip_rescale = skip_rescale 155 | if n_head_channels == -1: 156 | self.n_heads = n_heads 157 | else: 158 | if self.channels < n_head_channels: 159 | self.n_heads = 1 160 | else: 161 | assert self.channels % n_head_channels == 0 162 | self.n_heads = self.channels // n_head_channels 163 | 164 | def forward(self, x): 165 | # to 3d shape 166 | B, CN, H, W = x.shape 167 | C = self.Cin 168 | N = self.N 169 | x = x.reshape(B, C, N, H, W) 170 | 171 | h = self.GroupNorm_0(x) 172 | q = self.NIN_0(h) 173 | k = self.NIN_1(h) 174 | v = self.NIN_2(h) 175 | 176 | C = C // self.n_heads 177 | 178 | w = torch.einsum('bcnhw,bcnij->bnhwij', q.reshape(B * self.n_heads, C, N, H, W), k.reshape(B * self.n_heads, C, N, H, W)) * (int(C) ** (-0.5)) 179 | w = torch.reshape(w, (B * self.n_heads, N, H, W, N * H * W)) 180 | w = F.softmax(w, dim=-1) 181 | w = torch.reshape(w, (B * self.n_heads, N, H, W, N, H, W)) 182 | h = torch.einsum('bnhwijk,bcijk->bcnhw', w, v.reshape(B * self.n_heads, C, N, H, W)) 183 | h = h.reshape(B, C * self.n_heads, N, H, W) 184 | h = self.NIN_3(h) 185 | if not self.skip_rescale: 186 | x = x + h 187 | else: 188 | x = (x + h) / np.sqrt(2.) 
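        # Added note: fold the frame axis back into the channel axis, (B, C, N, H, W) -> (B, C*N, H, W), so the output matches the 4D layout used by the surrounding 2D code.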
189 | return x.reshape(B, C*N, H, W) 190 | 191 | class AttnBlockpp3d(nn.Module): 192 | """Channel-wise 3d self-attention block.""" 193 | # Because doing attn over space-time is very memory demanding, we do space, then time 194 | # 1) space-only attn block 195 | # 2) time-only attn block 196 | 197 | def __init__(self, channels, skip_rescale=False, init_scale=0., n_heads=1, n_head_channels=-1, n_frames=1, act=None): 198 | super().__init__() 199 | self.N = n_frames 200 | self.channels = self.Cin = channels // n_frames 201 | self.space_att = AttnBlockpp(channels=self.channels, skip_rescale=skip_rescale, init_scale=init_scale, n_heads=n_heads, n_head_channels=n_head_channels) 202 | self.time_att = AttnBlockpp1d(channels=self.channels, skip_rescale=skip_rescale, init_scale=init_scale, n_heads=n_heads, n_head_channels=n_head_channels) 203 | self.act = act 204 | 205 | def forward(self, x): 206 | B, CN, H, W = x.shape 207 | C = self.Cin 208 | N = self.N 209 | 210 | # Space attention 211 | x = x.reshape(B, C, N, H, W).permute(0, 2, 1, 3, 4).reshape(B*N, C, H, W) 212 | x = self.space_att(x) 213 | x = x.reshape(B, N, C, H, W).permute(0, 2, 1, 3, 4) # B, C, N, H, W 214 | 215 | if self.act is not None: 216 | x = self.act(x) 217 | 218 | # Time attention 219 | x = x.permute(0, 3, 4, 1, 2).reshape(B*H*W, C, N) 220 | x = self.time_att(x) 221 | x = x.reshape(B, H, W, C, N).permute(0, 3, 4, 1, 2).reshape(B, C*N, H, W) 222 | 223 | return x 224 | 225 | class MyConv3d(nn.Module): 226 | """3d convolution.""" 227 | 228 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, bias=True, init_scale=1., padding=0, dilation=1, n_frames=1): 229 | super().__init__() 230 | self.N = n_frames 231 | self.Cin = in_planes // n_frames 232 | self.Cout = out_planes // n_frames 233 | self.conv = nn.Conv3d(self.Cin, self.Cout, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias, dilation=dilation) 234 | self.conv.weight.data = default_init(init_scale)(self.conv.weight.data.shape) 235 | nn.init.zeros_(self.conv.bias) 236 | 237 | def forward(self, x): 238 | # to 3d shape 239 | B, CN, H, W = x.shape 240 | x = x.reshape(B, self.Cin, self.N, H, W) 241 | x = self.conv(x) 242 | x = x.reshape(B, self.Cout*self.N, H, W) 243 | return x 244 | 245 | def ddpm_conv1x1_3d(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0, n_frames=1): 246 | """1x1 convolution with DDPM initialization.""" 247 | conv = MyConv3d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias, n_frames=n_frames) 248 | return conv 249 | 250 | def ddpm_conv3x3_3d(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1, n_frames=1): 251 | """3x3 convolution with DDPM initialization.""" 252 | conv = MyConv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, 253 | dilation=dilation, bias=bias, n_frames=n_frames) 254 | return conv 255 | 256 | 257 | class PseudoConv3d(nn.Module): 258 | """Pseudo3d convolution.""" 259 | # Because doing conv over space-time is very memory demanding, we do space, then time 260 | # 1) space-only conv2d 261 | # activation function 262 | # 2) time-only conv1d 263 | 264 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, bias=True, init_scale=1., padding=0, dilation=1, n_frames=1, act=None): 265 | super().__init__() 266 | self.N = n_frames 267 | self.Cin = in_planes // n_frames 268 | self.Cout = out_planes // n_frames 269 | 270 | self.space_conv = nn.Conv2d(self.Cin, self.Cout, kernel_size=kernel_size, stride=stride, 
padding=padding, bias=bias, dilation=dilation) 271 | self.space_conv.weight.data = default_init(init_scale)(self.space_conv.weight.data.shape) 272 | nn.init.zeros_(self.space_conv.bias) 273 | 274 | self.time_conv = nn.Conv1d(self.Cout, self.Cout, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias, dilation=dilation) 275 | self.time_conv.weight.data = default_init(init_scale)(self.time_conv.weight.data.shape) 276 | nn.init.zeros_(self.time_conv.bias) 277 | 278 | self.act=act 279 | 280 | def forward(self, x): 281 | B, CN, H, W = x.shape 282 | C = self.Cin 283 | N = self.N 284 | 285 | # Space conv2d B*N, C, H, W 286 | x = x.reshape(B, C, N, H, W).permute(0, 2, 1, 3, 4).reshape(B*N, C, H, W) 287 | x = self.space_conv(x) 288 | C = self.Cout 289 | x = x.reshape(B, N, C, H, W).permute(0, 2, 1, 3, 4) # B, C, N, H, W 290 | 291 | if self.act is not None: 292 | x = self.act(x) 293 | 294 | # Time conv1d B*H*W, C, N 295 | x = x.permute(0, 3, 4, 1, 2).reshape(B*H*W, C, N) 296 | x = self.time_conv(x) 297 | x = x.reshape(B, H, W, C, N).permute(0, 3, 4, 1, 2).reshape(B, C*N, H, W) 298 | 299 | return x 300 | 301 | def ddpm_conv1x1_pseudo3d(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0, n_frames=1, act=None): 302 | """1x1 Pseudo convolution with DDPM initialization.""" 303 | conv = PseudoConv3d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias, n_frames=n_frames, act=act) 304 | return conv 305 | 306 | def ddpm_conv3x3_pseudo3d(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1, n_frames=1, act=None): 307 | """3x3 Pseudo convolution with DDPM initialization.""" 308 | conv = PseudoConv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, 309 | dilation=dilation, bias=bias, n_frames=n_frames, act=act) 310 | return conv 311 | --------------------------------------------------------------------------------
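The pseudo-3D layers above avoid full space-time convolution and attention by factoring them into a per-frame 2D step followed by a per-pixel 1D step over frames, all on tensors whose frame axis is folded into the channel axis. The following is a minimal, self-contained sketch of that reshape pattern; it uses plain nn.Conv2d/nn.Conv1d (not the repository's DDPM-initialized layers) and hypothetical shapes, so it illustrates the layout rather than reproducing PseudoConv3d exactly.

import torch
import torch.nn as nn

class FactorizedSpaceTimeConv(nn.Module):
    # Input/output layout follows layers3d.py: (B, C*N, H, W), with N frames folded into channels.
    def __init__(self, channels, n_frames, kernel_size=3):
        super().__init__()
        self.N = n_frames
        self.C = channels // n_frames
        pad = kernel_size // 2
        self.space = nn.Conv2d(self.C, self.C, kernel_size, padding=pad)  # per-frame 2D conv
        self.time = nn.Conv1d(self.C, self.C, kernel_size, padding=pad)   # per-pixel 1D conv over frames

    def forward(self, x):
        B, CN, H, W = x.shape
        C, N = self.C, self.N
        # space: treat each frame as an independent image -> (B*N, C, H, W)
        x = x.reshape(B, C, N, H, W).permute(0, 2, 1, 3, 4).reshape(B * N, C, H, W)
        x = self.space(x)
        # time: treat each spatial location as an independent length-N sequence -> (B*H*W, C, N)
        x = x.reshape(B, N, C, H, W).permute(0, 3, 4, 2, 1).reshape(B * H * W, C, N)
        x = self.time(x)
        # back to the folded layout (B, C*N, H, W)
        return x.reshape(B, H, W, C, N).permute(0, 3, 4, 1, 2).reshape(B, C * N, H, W)

if __name__ == "__main__":
    layer = FactorizedSpaceTimeConv(channels=64, n_frames=4)  # 16 channels per frame
    video = torch.randn(2, 64, 32, 32)                        # (B, C*N, H, W)
    print(layer(video).shape)                                 # torch.Size([2, 64, 32, 32])

Per output element, the factorized version costs on the order of k*k + k multiply-adds instead of k*k*k for a full 3D kernel, which is the memory and compute argument made in the comments of PseudoConv3d and AttnBlockpp3d above.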