├── LICENSE
├── README.md
├── config.py
├── configs
│   └── basic_conf.yaml
├── depth_inference.py
├── dream.py
├── examples
│   ├── dream_example.gif
│   └── example_img
│       └── test_img1.jpg
├── requirements.txt
├── spyNet.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Jonas Massa
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch Headtrip
2 | 
3 | - Work in Progress
4 | 
5 | ![](examples/dream_example.gif)
6 | 
7 | [![](https://img.youtube.com/vi/Cd5LNeT5wHI/0.jpg)](https://youtu.be/Cd5LNeT5wHI)
8 | 
9 | Single Deep Dreaming and Sequence Dreaming with Optical Flow and Depth Estimation in PyTorch.
10 | 
11 | Check out my [article](https://towardsdatascience.com/sequence-dreaming-with-depth-estimation-in-pytorch-d754cba14d30) for the setup!
12 | 
13 | [@aertist](https://github.com/aertist) made a Colab notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1hZFeBtLTY1nUBxC0G_qvbz4ZilUh67rr?usp=sharing)
14 | 
15 | Features:
16 | 
17 | - Sequence Dreaming with Optical Flow (Farneback or SpyNet)
18 | - Depth Estimation with [MiDaS](https://pytorch.org/hub/intelisl_midas_v2/)
19 | - Supports multiple PyTorch architectures
20 | - Dream a single ImageNet class, check my post [here](https://towardsdatascience.com/deep-lucid-dreaming-94fecd3cd46d)
21 | 
22 | # Install
23 | 
24 | ## Requirements
25 | 
26 | 1. Python 3.7
27 | 2. PyTorch 1.7
28 | 3. OpenCV
29 | 4. Matplotlib
30 | 
31 | ```
32 | pip install -r requirements.txt
33 | ```
34 | 
35 | Depending on your CUDA version, you might get an error installing PyTorch in your environment.
36 | 
37 | # Usage
38 | 
39 | For inference you need a config file like _basic_conf.yaml_ in the _configs_ folder.
40 | 
41 | ```
42 | python dream.py --config configs/basic_conf.yaml
43 | ```
44 | 
45 | Currently, it is only possible to dream on a sequence of frames from a video that has been
46 | extracted beforehand, e.g. with _ffmpeg_.
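47 | 
48 | For example, assuming your clip is saved as _video.mp4_ (the file name, output folder and frame rate below are only placeholders), the following writes numbered frames into a _frames_ folder:
49 | 
50 | ```
51 | mkdir -p frames
52 | ffmpeg -i video.mp4 -vf fps=24 frames/%05d.jpg
53 | ```
54 | 
55 | Point _input_ in the config file at that folder and set _seq: True_ to dream on the whole sequence.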
56 | 
57 | - The SpyNet Code is adapted from this [github repository](https://github.com/sniklaus/pytorch-spynet)
58 | 
59 | [**Buy me a coffee! :coffee:**](https://www.buymeacoffee.com/beinabih)
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | import torch
5 | import yaml
6 | 
7 | 
8 | def load_config():
9 |     parser = argparse.ArgumentParser(description="deep dream config")
10 |     parser.add_argument(
11 |         "--config", type=str, help="Path to the YAML config file", required=True
12 |     )
13 |     args = parser.parse_args()
14 |     config = _load_config_yaml(args.config)
15 |     return config
16 | 
17 | 
18 | def _load_config_yaml(config_file):
19 |     with open(config_file, "r") as f:
20 |         return yaml.safe_load(f)
--------------------------------------------------------------------------------
/configs/basic_conf.yaml:
--------------------------------------------------------------------------------
1 | ######### Hyperparameters
2 | num_iterations: 10
3 | num_octaves: 20
4 | octave_scale: 1.2
5 | lr: 0.008
6 | # Random Hyperparameters
7 | random: False
8 | #########
9 | # Guided Dreaming on/off
10 | guided: True
11 | # takes the max class
12 | max_output: False
13 | # takes the max class on each octave
14 | pyramid_max: True
15 | # takes all the classes from the list
16 | channel_list: [863]
17 | # if you want to dream on other layers
18 | no_class: False
19 | at_layer: 26
20 | ##########
21 | # optical flow: spynet or opencv farneback
22 | use_spynet: False
23 | ###########
24 | use_depth: False
25 | depth_str: 1.2
26 | # use threshold on mask to set specific areas to 0
27 | use_threshold: True
28 | th_val: 0.2
29 | invert_depth: False # foreground/background
30 | #########
31 | pretrained: True
32 | ########
33 | # input and output path
34 | input: "examples/example_img/"
35 | outpath: "examples/output"
36 | # sequence start position when sorted
37 | start_position: 0
38 | ######
39 | # change model here, available are
40 | # resnet, vgg19, alexnet, inception, densenet, mobile, shuffle, resnetx, squeeze, masnet, googlenet
41 | model: vgg19
42 | 
43 | # set true for video sequence
44 | seq: True
45 | 
46 | # float precision 16
47 | fp16: True
48 | 
--------------------------------------------------------------------------------
/depth_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import urllib.request
3 | 
4 | import cv2
5 | import matplotlib.pyplot as plt
6 | import torch
7 | 
8 | 
9 | class MiDaS:
10 |     def __init__(self, use_large_model):
11 | 
12 |         os.makedirs("depth_model", exist_ok=True)
13 |         torch.hub.set_dir("depth_model")
14 | 
15 |         if use_large_model:
16 |             self.midas = torch.hub.load(
17 |                 "intel-isl/MiDaS", "MiDaS", _use_new_zipfile_serialization=False
18 |             )
19 |         else:
20 |             self.midas = torch.hub.load(
21 |                 "intel-isl/MiDaS", "MiDaS_small", _use_new_zipfile_serialization=False
22 |             )
23 | 
24 |         self.device = (
25 |             torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
26 |         )
27 |         self.midas.to(self.device)
28 |         self.midas.eval()
29 | 
30 |         midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
31 | 
32 |         if use_large_model:
33 |             self.transform = midas_transforms.default_transform
34 |         else:
35 |             self.transform = midas_transforms.small_transform
36 | 
37 |     def inference(self, img):
38 | 
39 |         # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
40 | 
41 |         input_batch = self.transform(img).to(self.device)
42 | 
43 |         with torch.no_grad():
44 |             prediction = self.midas(input_batch)
45 | 
46 |             prediction = 
torch.nn.functional.interpolate( 47 | prediction.unsqueeze(1), 48 | size=img.shape[:2], 49 | mode="bicubic", 50 | align_corners=False, 51 | ).squeeze() 52 | 53 | return prediction.detach().cpu().numpy() 54 | -------------------------------------------------------------------------------- /dream.py: -------------------------------------------------------------------------------- 1 | """Summary 2 | """ 3 | import argparse 4 | import glob 5 | import os 6 | import random 7 | 8 | 9 | import cv2 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import scipy.ndimage as nd 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | import torchvision.transforms as transforms 17 | import tqdm 18 | from PIL import Image 19 | from torch.autograd import Variable 20 | from torchvision import models 21 | 22 | from config import load_config 23 | from depth_inference import MiDaS 24 | from spyNet import calc_opflow 25 | from utils import clip, convert, deprocess, get_octaves, preprocess, warp 26 | 27 | 28 | class Dreamer: 29 | def __init__(self, img_p, outpath, config): 30 | """Summary 31 | 32 | Args: 33 | model (TYPE): Description 34 | batchsize (TYPE): Description 35 | img_p (TYPE): Description 36 | outpath (TYPE): Description 37 | config (TYPE): Description 38 | """ 39 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 40 | 41 | self.config = config 42 | self.img_p = img_p 43 | self.outpath = outpath 44 | self.init_model() 45 | self.norm_str = 5 46 | 47 | self.octave_list = [1.1, 1.2, 1.3, 1.4, 1.5] 48 | self.num_octaves_para = config["num_octaves"] 49 | self.octave_scale = config["octave_scale"] 50 | self.at_layer_para = config["at_layer"] 51 | self.lr = config["lr"] 52 | self.random = config["random"] 53 | self.no_class = config["no_class"] 54 | self.ch_list = config["channel_list"] 55 | self.img_list = sorted(glob.glob(img_p)) 56 | self.depth = config["use_depth"] 57 | self.depth_w = config["depth_str"] 58 | 59 | self.depth_model = MiDaS(False) 60 | 61 | self.loss = nn.BCEWithLogitsLoss() 62 | 63 | self.transform = transforms.Compose( 64 | [ 65 | transforms.ToTensor(), 66 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 67 | ] 68 | ) 69 | 70 | if self.random: 71 | self.random_para() 72 | 73 | if self.no_class: 74 | self.layers = list(self.model.features.children()) 75 | self.model = nn.Sequential(*self.layers[: (self.at_layer_para + 1)]) 76 | self.norm_str = 1 77 | 78 | print( 79 | self.config["num_iterations"], 80 | self.octave_scale, 81 | self.config["num_octaves"], 82 | self.lr, 83 | ) 84 | 85 | def init_model(self): 86 | """initializes the model with the config file""" 87 | if self.config["model"] == "resnet": 88 | network = models.resnext50_32x4d(pretrained=True) 89 | elif self.config["model"] == "vgg19": 90 | network = models.vgg19(pretrained=True) 91 | elif self.config["model"] == "densenet": 92 | network = models.densenet121(pretrained=True) 93 | elif self.config["model"] == "inception": 94 | network = models.inception_v3(pretrained=True) 95 | elif self.config["model"] == "mobile": 96 | network = models.mobilenet_v2(pretrained=True) 97 | elif self.config["model"] == "shuffle": 98 | network = models.shufflenet_v2_x0_5(pretrained=True) 99 | elif self.config["model"] == "squeeze": 100 | network = models.squeezenet1_1(pretrained=True) 101 | elif self.config["model"] == "resnetx": 102 | network = models.resnext101_32x8d(pretrained=True) 103 | elif self.config["model"] == "masnet": 104 | network = 
models.mnasnet1_0(pretrained=True)
105 |         elif self.config["model"] == "googlenet":
106 |             network = models.googlenet(pretrained=True)
107 |         elif self.config["model"] == "alexnet":
108 |             network = models.alexnet(pretrained=True)
109 |         else:
110 |             raise ValueError("Invalid model name: {}".format(self.config["model"]))
111 | 
112 |         network.eval()
113 |         self.model = network.to(self.device)
114 | 
115 |         if self.config["fp16"]:
116 |             scaler = torch.cuda.amp.GradScaler()  # unused; the image gradient is rescaled manually in forward()
117 |             self.model = self.model.half()
118 |             # amp.register_float_function(torch, "batch_norm")
119 |             # self.model = amp.initialize(self.model, opt_level="O2")
120 | 
121 |     def forward(self, model, image, z, d_img=None, mask=None):
122 |         """Performs one gradient-ascent step on the input image.
123 | 
124 |         Args:
125 |             model (nn.Module): network to dream with
126 |             image (Tensor): input image with requires_grad=True
127 |             z (int): current iteration index
128 |             d_img (ndarray, optional): depth map used to weight the gradient
129 |             mask (ndarray, optional): mask used to weight the gradient
130 | 
131 |         Returns:
132 |             Tensor: the updated image
133 |         """
134 |         model.zero_grad()
135 | 
136 |         if self.config["fp16"]:
137 |             with torch.cuda.amp.autocast():
138 |                 out = model(image)
139 |         else:
140 |             out = model(image)
141 | 
142 |         if self.config["guided"]:
143 |             target = self.get_target(self.config, z, out)
144 |             loss = -self.loss(out, target)
145 |         else:
146 |             loss = out.norm()
147 | 
148 |         loss.backward()
149 | 
150 |         avg_grad = np.abs(image.grad.data.cpu().numpy()).mean()
151 |         norm_lr = self.lr / avg_grad
152 |         grad = image.grad.data
153 | 
154 |         dream_grad = grad * (norm_lr * self.norm_str)
155 | 
156 |         if self.depth:
157 |             d_img = torch.from_numpy(d_img)
158 |             d_img = d_img[0, 0].to(self.device)
159 | 
160 |             dream_grad *= d_img * self.depth_w
161 | 
162 |         if mask is not None:
163 |             mask = torch.from_numpy(mask)
164 |             dream_grad *= mask.to(self.device)
165 | 
166 |         image.data += dream_grad
167 |         image.data = clip(image.data)
168 |         image.grad.data.zero_()
169 | 
170 |         return image
171 | 
172 |     def get_target(self, config, z, out):
173 |         """Builds the target tensor used for guided dreaming.
174 | 
175 |         Args:
176 |             config (Dictionary): Config File
177 |             z (Integer): Iteration Value
178 |             out (tensor): Model Output
179 | 
180 |         Returns:
181 |             Tensor: Target Tensor for guided dreaming
182 |         """
183 |         target = torch.zeros((1, 1000)).to(self.device)
184 | 
185 |         if config["max_output"]:
186 |             out = out.float()
187 | 
188 |             if config["pyramid_max"]:
189 |                 if z == 0:
190 |                     self.channel = out.argmax()
191 |                     target[0, self.channel] = 100
192 |                 else:
193 |                     out[0, self.channel] = 0
194 |                     self.channel = out.argmax()
195 |                     target = torch.zeros((1, 1000)).to(self.device)
196 |                     target[0, self.channel] = 100
197 |             else:
198 |                 self.channel = out.argmax()
199 |                 target[0, self.channel] = 100
200 | 
201 |         else:
202 |             for ch in config["channel_list"]:
203 |                 target[0, ch] = 100
204 | 
205 |         return target
206 | 
207 |     def dream(self, image, model, d_img=None, mask=None):
208 |         """Updates the image to maximize outputs for n iterations
209 | 
210 |         Args:
211 |             image (ndarray): input image for the current octave
212 |             model (nn.Module): network to dream with
213 |             d_img (ndarray, optional): depth map for the current octave
214 |             mask (ndarray, optional): mask for the current octave
215 | 
216 |         Returns:
217 |             ndarray: the dreamed image
218 |         """
219 |         Tensor = (
220 |             torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
221 |         )
222 |         image = Variable(Tensor(image), requires_grad=True)
223 | 
224 |         for n in range(self.config["num_iterations"]):
225 |             image = self.forward(model, image, n, d_img, mask)
226 | 
227 |         return image.cpu().data.numpy()
228 | 
229 |     def deep_dream(self, image, model, i, seq, mask=None):
230 |         """Main deep dream method
231 | 
232 |         Args:
233 |             image (TYPE): Description
234 |             model (TYPE): 
Description 235 | i (TYPE): Description 236 | seq (TYPE): Description 237 | mask (None, optional): Description 238 | 239 | Returns: 240 | TYPE: Description 241 | """ 242 | 243 | image_p = image.unsqueeze(0).cpu().detach().numpy() 244 | args = [] 245 | 246 | octaves = get_octaves(image_p, self.config["num_octaves"], self.octave_scale) 247 | 248 | if self.depth: 249 | d_img = self.depth_model.inference(convert(image_p)) 250 | d_img = d_img / np.max(d_img) 251 | 252 | if self.config["invert_depth"]: 253 | d_img = 1 - d_img 254 | 255 | if self.config["use_threshold"]: 256 | d_img[d_img < self.config["th_val"]] = 0 257 | 258 | d_img = np.expand_dims(d_img, 0) 259 | d_img = np.expand_dims(d_img, 0) 260 | d_img_octaves = get_octaves( 261 | d_img, self.config["num_octaves"], self.octave_scale 262 | ) 263 | d_img_octaves = d_img_octaves[::-1] 264 | args.append(d_img_octaves) 265 | 266 | if mask is not None: 267 | mask = np.transpose(mask, (2, 0, 1)) 268 | mask = np.expand_dims(mask, 0) 269 | octaves_mask = get_octaves( 270 | mask, self.config["num_octaves"], self.octave_scale 271 | ) 272 | octaves_mask = octaves_mask[::-1] 273 | args.append(octaves_mask) 274 | 275 | kernel = np.ones((5, 5), np.uint8) 276 | self.detail = np.zeros_like(octaves[-1]) 277 | for octave, octave_base in enumerate(tqdm.tqdm(octaves[::-1], desc="Dreaming")): 278 | if octave > 0: 279 | # Upsample detail to new octave dimension 280 | self.detail = nd.zoom( 281 | self.detail, 282 | np.array(octave_base.shape) / np.array(self.detail.shape), 283 | order=1, 284 | ) 285 | 286 | input_image = octave_base + self.detail 287 | 288 | dreamed_image = self.dream( 289 | input_image, model, *map(lambda x: x[octave], args) 290 | ) 291 | 292 | self.detail = dreamed_image - octave_base 293 | 294 | return input_image 295 | 296 | def save_img(self, img, suffix, iter_): 297 | """Summary 298 | 299 | Args: 300 | img (numpy array): Output Image 301 | suffix (string): filename suffix 302 | iter_ (integer): the iteration value 303 | """ 304 | img = deprocess(img) 305 | img = np.clip(img, 0, 1) 306 | file_name = self.img_list[self.config["start_position"] + iter_] 307 | file_name = file_name.split("/")[-1] 308 | plt.imsave(self.outpath + "/{}{}".format(suffix, file_name), img) 309 | 310 | def get_opflow_image(self, img1, dream_img, img2): 311 | """Calculates the optical flow with opencv and the spynet 312 | 313 | Args: 314 | img1 (TYPE): Description 315 | dream_img (TYPE): Description 316 | img2 (TYPE): Description 317 | 318 | Returns: 319 | TYPE: Description 320 | """ 321 | img1 = np.float32(img1) 322 | dream_img = np.float32(dream_img) 323 | img2 = np.float32(img2) 324 | 325 | h, w, c = img1.shape 326 | if self.config["use_spynet"]: 327 | flow = calc_opflow(np.uint8(img1), np.uint8(img2)) 328 | flow = np.transpose(np.float32(flow), (1, 2, 0)) 329 | else: 330 | grayImg1 = cv2.cvtColor(img1, cv2.COLOR_RGB2GRAY) 331 | grayImg2 = cv2.cvtColor(img2, cv2.COLOR_RGB2GRAY) 332 | 333 | flow = cv2.calcOpticalFlowFarneback( 334 | grayImg1, 335 | grayImg2, 336 | pyr_scale=0.5, 337 | levels=3, 338 | winsize=15, 339 | iterations=3, 340 | poly_n=3, 341 | poly_sigma=1.2, 342 | flags=0, 343 | flow=1, 344 | ) 345 | 346 | inv_flow = flow 347 | flow = -flow 348 | 349 | flow[:, :, 0] += np.arange(w) 350 | flow[:, :, 1] += np.arange(h)[:, np.newaxis] 351 | 352 | halludiff = cv2.addWeighted(img2, 0.1, dream_img, 0.9, 0) - img1 353 | halludiff = cv2.remap(halludiff, flow, None, cv2.INTER_LINEAR) 354 | hallu_flow = img2 + halludiff 355 | 356 | magnitude, angle = 
cv2.cartToPolar(inv_flow[..., 0], inv_flow[..., 1]) 357 | norm_mag = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX) 358 | ret, mask = cv2.threshold(norm_mag, 6, 255, cv2.THRESH_BINARY) 359 | flow_mask = mask.astype(np.uint8).reshape((h, w, 1)) 360 | 361 | blendstatic = 0.1 362 | background_blendimg = cv2.addWeighted( 363 | img2, (1 - blendstatic), dream_img, blendstatic, 0 364 | ) 365 | background_masked = cv2.bitwise_and( 366 | background_blendimg, background_blendimg, mask=cv2.bitwise_not(flow_mask) 367 | ) 368 | 369 | return hallu_flow, background_masked 370 | 371 | def random_para(self): 372 | """chooses random parameters""" 373 | self.config["num_iterations"] = random.randint(2, 14) 374 | if not self.config["guided"]: 375 | self.at_layer_para = random.randint(10, 38) 376 | self.config["num_octaves"] = random.randint(30, 40) 377 | self.lr = random.choice([0.01, 0.009, 0.008, 0.02, 0.03, 0.007]) 378 | self.octave_scale = random.choice(self.octave_list) 379 | 380 | def dream_single(self): 381 | """Dreams independent frames""" 382 | for i, path in enumerate(self.img_list): 383 | img1 = Image.open(path) 384 | d_img = self.deep_dream(self.transform(img1), self.model, i, seq="first") 385 | 386 | self.save_img(d_img, "", i) 387 | 388 | def dream_seq(self): 389 | """Dreams a sequence with optical flow""" 390 | 391 | for i, path in enumerate(self.img_list[self.config["start_position"] :]): 392 | 393 | if i == 0: 394 | img1 = Image.open(path) 395 | d_img = self.deep_dream( 396 | self.transform(img1), self.model, i, seq="first" 397 | ) 398 | 399 | self.save_img(d_img, "", i) 400 | d_img = convert(d_img) 401 | flow_iter = 0 402 | 403 | # the iterations needs to be reduced 404 | self.config["num_iterations"] -= 5 405 | 406 | if i > 0: 407 | img2 = Image.open(path) 408 | feature_img, background_masked = self.get_opflow_image( 409 | img1, d_img, img2 410 | ) 411 | 412 | feature_img = np.clip(feature_img, 0, 255) 413 | 414 | background_masked[background_masked > 0] = 1 - (flow_iter * 0.1) # 0.5 415 | background_masked[background_masked == 0] = flow_iter * 0.1 416 | 417 | d_img = self.deep_dream( 418 | self.transform(np.uint8(feature_img)), 419 | self.model, 420 | i, 421 | seq="first", 422 | mask=background_masked, 423 | ) 424 | 425 | # change position 426 | img1 = img2 427 | self.save_img(d_img, "", i) 428 | d_img = convert(d_img) 429 | flow_iter += 1 430 | flow_iter = 0 if flow_iter > 5 else flow_iter 431 | 432 | 433 | def start_dreamer(config): 434 | """ 435 | 436 | Args: 437 | config (Dictionary): The config file 438 | """ 439 | pretrained = config["pretrained"] 440 | 441 | # Load image 442 | if os.path.isdir(config["input"]): 443 | img_p = config["input"] + "/*" 444 | elif os.path.isfile(config["input"]): 445 | img_p = config["input"] 446 | else: 447 | raise Exception("Wrong Input") 448 | 449 | outpath = config["outpath"] 450 | os.makedirs(outpath, exist_ok=True) 451 | 452 | dreamer = Dreamer(img_p, outpath, config) 453 | if config["seq"]: 454 | dreamer.dream_seq() 455 | else: 456 | dreamer.dream_single() 457 | 458 | 459 | if __name__ == "__main__": 460 | 461 | parser = argparse.ArgumentParser() 462 | parser.add_argument("--config", default="", type=str) 463 | opt = parser.parse_args() 464 | 465 | config = load_config() 466 | 467 | start_dreamer(config) 468 | -------------------------------------------------------------------------------- /examples/dream_example.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Beinabih/Pytorch-HeadTrip/8c7f1955219cb465ebc5aefef94b4dc8103c3205/examples/dream_example.gif -------------------------------------------------------------------------------- /examples/example_img/test_img1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Beinabih/Pytorch-HeadTrip/8c7f1955219cb465ebc5aefef94b4dc8103c3205/examples/example_img/test_img1.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | awscli==1.16.230 2 | certifi==2020.12.5 3 | colorama==0.3.9 4 | cycler==0.10.0 5 | docutils==0.15.2 6 | jmespath==0.9.4 7 | kiwisolver==1.3.1 8 | matplotlib==3.3.4 9 | numpy==1.20.1 10 | opencv-python==4.5.1.48 11 | Pillow==8.1.1 12 | pyasn1==0.4.7 13 | pyparsing==2.4.7 14 | python-dateutil==2.8.1 15 | PyYAML==5.4.1 16 | rsa==3.4.2 17 | scipy==1.6.1 18 | six==1.15.0 19 | torch==1.7.1+cu101 20 | torchaudio==0.7.2 21 | torchvision==0.8.2+cu101 22 | tqdm==4.58.0 23 | typing-extensions==3.7.4.3 24 | -------------------------------------------------------------------------------- /spyNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # This code was adapted from https://github.com/sniklaus/pytorch-spynet.git 4 | 5 | import getopt 6 | import math 7 | import os 8 | import sys 9 | 10 | import numpy 11 | import PIL 12 | import PIL.Image 13 | import torch 14 | 15 | ########################################################## 16 | 17 | assert ( 18 | int(str("").join(torch.__version__.split(".")[0:2])) >= 13 19 | ) # requires at least pytorch version 1.3.0 20 | 21 | # torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance 22 | 23 | torch.backends.cudnn.enabled = ( 24 | True # make sure to use cudnn for computational performance 25 | ) 26 | 27 | # ########################################################## 28 | 29 | arguments_strModel = "sintel-final" 30 | arguments_strFirst = "./images/first.png" 31 | arguments_strSecond = "./images/second.png" 32 | arguments_strOut = "./out.flo" 33 | 34 | # for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]: 35 | # if strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use, see below 36 | # if strOption == '--first' and strArgument != '': arguments_strFirst = strArgument # path to the first frame 37 | # if strOption == '--second' and strArgument != '': arguments_strSecond = strArgument # path to the second frame 38 | # if strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored 39 | # # end 40 | 41 | # ########################################################## 42 | 43 | backwarp_tenGrid = {} 44 | 45 | 46 | def backwarp(tenInput, tenFlow): 47 | if str(tenFlow.shape) not in backwarp_tenGrid: 48 | tenHor = ( 49 | torch.linspace( 50 | -1.0 + (1.0 / tenFlow.shape[3]), 51 | 1.0 - (1.0 / tenFlow.shape[3]), 52 | tenFlow.shape[3], 53 | ) 54 | .view(1, 1, 1, -1) 55 | .expand(-1, -1, tenFlow.shape[2], -1) 56 | ) 57 | tenVer = ( 58 | torch.linspace( 59 | -1.0 + (1.0 / tenFlow.shape[2]), 60 | 1.0 - (1.0 / tenFlow.shape[2]), 61 | tenFlow.shape[2], 62 | ) 63 | .view(1, 1, -1, 1) 64 | .expand(-1, -1, -1, tenFlow.shape[3]) 65 | ) 66 | 67 | 
backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([tenHor, tenVer], 1).cuda() 68 | # end 69 | 70 | tenFlow = torch.cat( 71 | [ 72 | tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), 73 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0), 74 | ], 75 | 1, 76 | ) 77 | 78 | return torch.nn.functional.grid_sample( 79 | input=tenInput, 80 | grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), 81 | mode="bilinear", 82 | padding_mode="border", 83 | align_corners=False, 84 | ) 85 | 86 | 87 | # end 88 | 89 | ########################################################## 90 | 91 | 92 | class Network(torch.nn.Module): 93 | def __init__(self): 94 | super(Network, self).__init__() 95 | 96 | class Preprocess(torch.nn.Module): 97 | def __init__(self): 98 | super(Preprocess, self).__init__() 99 | 100 | # end 101 | 102 | def forward(self, tenInput): 103 | tenBlue = (tenInput[:, 0:1, :, :] - 0.406) / 0.225 104 | tenGreen = (tenInput[:, 1:2, :, :] - 0.456) / 0.224 105 | tenRed = (tenInput[:, 2:3, :, :] - 0.485) / 0.229 106 | 107 | return torch.cat([tenRed, tenGreen, tenBlue], 1) 108 | 109 | # end 110 | 111 | # end 112 | 113 | class Basic(torch.nn.Module): 114 | def __init__(self, intLevel): 115 | super(Basic, self).__init__() 116 | 117 | self.netBasic = torch.nn.Sequential( 118 | torch.nn.Conv2d( 119 | in_channels=8, 120 | out_channels=32, 121 | kernel_size=7, 122 | stride=1, 123 | padding=3, 124 | ), 125 | torch.nn.ReLU(inplace=False), 126 | torch.nn.Conv2d( 127 | in_channels=32, 128 | out_channels=64, 129 | kernel_size=7, 130 | stride=1, 131 | padding=3, 132 | ), 133 | torch.nn.ReLU(inplace=False), 134 | torch.nn.Conv2d( 135 | in_channels=64, 136 | out_channels=32, 137 | kernel_size=7, 138 | stride=1, 139 | padding=3, 140 | ), 141 | torch.nn.ReLU(inplace=False), 142 | torch.nn.Conv2d( 143 | in_channels=32, 144 | out_channels=16, 145 | kernel_size=7, 146 | stride=1, 147 | padding=3, 148 | ), 149 | torch.nn.ReLU(inplace=False), 150 | torch.nn.Conv2d( 151 | in_channels=16, 152 | out_channels=2, 153 | kernel_size=7, 154 | stride=1, 155 | padding=3, 156 | ), 157 | ) 158 | 159 | # end 160 | 161 | def forward(self, tenInput): 162 | return self.netBasic(tenInput) 163 | 164 | # end 165 | 166 | # end 167 | 168 | self.netPreprocess = Preprocess() 169 | 170 | self.netBasic = torch.nn.ModuleList([Basic(intLevel) for intLevel in range(6)]) 171 | 172 | self.load_state_dict( 173 | { 174 | strKey.replace("module", "net"): tenWeight 175 | for strKey, tenWeight in torch.hub.load_state_dict_from_url( 176 | url="http://content.sniklaus.com/github/pytorch-spynet/network-" 177 | + arguments_strModel 178 | + ".pytorch", 179 | file_name="spynet-" + arguments_strModel, 180 | ).items() 181 | } 182 | ) 183 | 184 | # end 185 | 186 | def forward(self, tenFirst, tenSecond): 187 | tenFlow = [] 188 | 189 | tenFirst = [self.netPreprocess(tenFirst)] 190 | tenSecond = [self.netPreprocess(tenSecond)] 191 | 192 | for intLevel in range(5): 193 | if tenFirst[0].shape[2] > 32 or tenFirst[0].shape[3] > 32: 194 | tenFirst.insert( 195 | 0, 196 | torch.nn.functional.avg_pool2d( 197 | input=tenFirst[0], 198 | kernel_size=2, 199 | stride=2, 200 | count_include_pad=False, 201 | ), 202 | ) 203 | tenSecond.insert( 204 | 0, 205 | torch.nn.functional.avg_pool2d( 206 | input=tenSecond[0], 207 | kernel_size=2, 208 | stride=2, 209 | count_include_pad=False, 210 | ), 211 | ) 212 | # end 213 | # end 214 | 215 | tenFlow = tenFirst[0].new_zeros( 216 | [ 217 | tenFirst[0].shape[0], 218 | 2, 219 | int(math.floor(tenFirst[0].shape[2] 
/ 2.0)), 220 | int(math.floor(tenFirst[0].shape[3] / 2.0)), 221 | ] 222 | ) 223 | 224 | for intLevel in range(len(tenFirst)): 225 | tenUpsampled = ( 226 | torch.nn.functional.interpolate( 227 | input=tenFlow, scale_factor=2, mode="bilinear", align_corners=True 228 | ) 229 | * 2.0 230 | ) 231 | 232 | if tenUpsampled.shape[2] != tenFirst[intLevel].shape[2]: 233 | tenUpsampled = torch.nn.functional.pad( 234 | input=tenUpsampled, pad=[0, 0, 0, 1], mode="replicate" 235 | ) 236 | if tenUpsampled.shape[3] != tenFirst[intLevel].shape[3]: 237 | tenUpsampled = torch.nn.functional.pad( 238 | input=tenUpsampled, pad=[0, 1, 0, 0], mode="replicate" 239 | ) 240 | 241 | tenFlow = ( 242 | self.netBasic[intLevel]( 243 | torch.cat( 244 | [ 245 | tenFirst[intLevel], 246 | backwarp( 247 | tenInput=tenSecond[intLevel], tenFlow=tenUpsampled 248 | ), 249 | tenUpsampled, 250 | ], 251 | 1, 252 | ) 253 | ) 254 | + tenUpsampled 255 | ) 256 | # end 257 | 258 | return tenFlow 259 | 260 | # end 261 | 262 | 263 | # end 264 | 265 | netNetwork = None 266 | 267 | ########################################################## 268 | 269 | 270 | def estimate(tenFirst, tenSecond): 271 | global netNetwork 272 | 273 | if netNetwork is None: 274 | netNetwork = Network().cuda().eval() 275 | # end 276 | 277 | assert tenFirst.shape[1] == tenSecond.shape[1] 278 | assert tenFirst.shape[2] == tenSecond.shape[2] 279 | 280 | intWidth = tenFirst.shape[2] 281 | intHeight = tenFirst.shape[1] 282 | 283 | # assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 284 | # assert(intHeight == 416) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 285 | 286 | tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth) 287 | tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth) 288 | 289 | intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 32.0) * 32.0)) 290 | intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 32.0) * 32.0)) 291 | 292 | tenPreprocessedFirst = torch.nn.functional.interpolate( 293 | input=tenPreprocessedFirst, 294 | size=(intPreprocessedHeight, intPreprocessedWidth), 295 | mode="bilinear", 296 | align_corners=False, 297 | ) 298 | tenPreprocessedSecond = torch.nn.functional.interpolate( 299 | input=tenPreprocessedSecond, 300 | size=(intPreprocessedHeight, intPreprocessedWidth), 301 | mode="bilinear", 302 | align_corners=False, 303 | ) 304 | 305 | tenFlow = torch.nn.functional.interpolate( 306 | input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond), 307 | size=(intHeight, intWidth), 308 | mode="bilinear", 309 | align_corners=False, 310 | ) 311 | 312 | tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) 313 | tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) 314 | 315 | return tenFlow[0, :, :, :].cpu() 316 | 317 | 318 | # end 319 | 320 | ########################################################## 321 | 322 | 323 | def calc_opflow(img1, img2): 324 | 325 | torch.set_grad_enabled(False) 326 | 327 | img1 = PIL.Image.fromarray(img1) 328 | img2 = PIL.Image.fromarray(img2) 329 | 330 | tenFirst = torch.FloatTensor( 331 | numpy.ascontiguousarray( 332 | numpy.array(img1)[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) 333 | * (1.0 / 255.0) 334 | ) 335 | ) 336 | tenSecond = torch.FloatTensor( 337 | numpy.ascontiguousarray( 338 | numpy.array(img2)[:, :, ::-1].transpose(2, 0, 
1).astype(numpy.float32) 339 | * (1.0 / 255.0) 340 | ) 341 | ) 342 | 343 | tenOutput = estimate(tenFirst, tenSecond) 344 | 345 | torch.set_grad_enabled(True) 346 | 347 | return tenOutput 348 | 349 | 350 | if __name__ == "__main__": 351 | tenFirst = torch.FloatTensor( 352 | numpy.ascontiguousarray( 353 | numpy.array(PIL.Image.open(arguments_strFirst))[:, :, ::-1] 354 | .transpose(2, 0, 1) 355 | .astype(numpy.float32) 356 | * (1.0 / 255.0) 357 | ) 358 | ) 359 | tenSecond = torch.FloatTensor( 360 | numpy.ascontiguousarray( 361 | numpy.array(PIL.Image.open(arguments_strSecond))[:, :, ::-1] 362 | .transpose(2, 0, 1) 363 | .astype(numpy.float32) 364 | * (1.0 / 255.0) 365 | ) 366 | ) 367 | 368 | tenOutput = estimate(tenFirst, tenSecond) 369 | 370 | objOutput = open(arguments_strOut, "wb") 371 | 372 | # numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput) 373 | # numpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput) 374 | numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput) 375 | 376 | objOutput.close() 377 | # end 378 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | 4 | import cv2 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import scipy.ndimage as nd 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | from torchvision import transforms 13 | 14 | mean = np.array([0.485, 0.456, 0.406]) 15 | std = np.array([0.229, 0.224, 0.225]) 16 | 17 | preprocess = transforms.Compose( 18 | [transforms.ToTensor(), transforms.Normalize(mean, std)] 19 | ) 20 | 21 | 22 | def show_img(img): 23 | img = convert(img.detach().cpu().numpy()) 24 | plt.imshow(img) 25 | plt.show() 26 | 27 | 28 | def warp(x, flo): 29 | """ 30 | warp an image/tensor (im2) back to im1, according to the optical flow 31 | x: [B, C, H, W] (im2) 32 | flo: [B, 2, H, W] flow 33 | """ 34 | B, C, H, W = x.size() 35 | # mesh grid 36 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1) 37 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W) 38 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) 39 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) 40 | grid = torch.cat((xx, yy), 1).float() 41 | 42 | if x.is_cuda: 43 | grid = grid.cuda() 44 | flo = flo.cuda() 45 | vgrid = Variable(grid) + flo 46 | 47 | # scale grid to [-1,1] 48 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :] / max(W - 1, 1) - 1.0 49 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :] / max(H - 1, 1) - 1.0 50 | 51 | vgrid = vgrid.permute(0, 2, 3, 1) 52 | output = nn.functional.grid_sample(x, vgrid) 53 | mask = torch.autograd.Variable(torch.ones(x.size())).cuda() 54 | mask = nn.functional.grid_sample(mask, vgrid) 55 | 56 | # if W==128: 57 | # np.save('mask.npy', mask.cpu().data.numpy()) 58 | # np.save('warp.npy', output.cpu().data.numpy()) 59 | 60 | mask[mask < 0.9999] = 0 61 | mask[mask > 0] = 1 62 | 63 | return output * mask, mask 64 | 65 | 66 | def get_octaves(img, num_octaves, octave_scale): 67 | octaves = [img] 68 | for _ in range(num_octaves - 1): 69 | new_octave = nd.zoom( 70 | octaves[-1], (1, 1, 1 / octave_scale, 1 / octave_scale), order=1 71 | ) 72 | 73 | if new_octave.shape[2] > 32 and new_octave.shape[3] > 32: 74 | octaves.append(new_octave) 75 | 76 | return octaves 77 | 78 | 79 | def deprocess(image_np): 80 | image_np = image_np.squeeze().transpose(1, 2, 0) 81 | image_np = 
image_np * std.reshape((1, 1, 3)) + mean.reshape((1, 1, 3))
82 |     image_np = np.clip(image_np, 0.0, 255.0)
83 |     return image_np
84 | 
85 | 
86 | def clip(image_tensor):
87 |     for c in range(3):
88 |         m, s = mean[c], std[c]
89 |         image_tensor[0, c] = torch.clamp(image_tensor[0, c], -m / s, (1 - m) / s)
90 |     return image_tensor
91 | 
92 | 
93 | def convert(img):
94 |     image_np = img.squeeze().transpose(1, 2, 0)
95 |     image_np = image_np * std.reshape((1, 1, 3)) + mean.reshape((1, 1, 3))
96 |     image_np = image_np * 255
97 |     image_np = np.clip(image_np, 0.0, 255.0)
98 |     return image_np.astype(np.uint8)
99 | 
100 | 
101 | def find_contours(img):
102 |     img = convert(img)
103 |     imgray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
104 |     ret, thresh = cv2.threshold(imgray, 127, 255, 0)
105 |     contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
106 | 
107 |     # for contour in contours:
108 |     #     hull = cv2.convexHull(contour)
109 |     #     cv2.drawContours(imgrey,[hull],-1,0,-1)
110 |     if contours:
111 |         area = []
112 |         for cnt in contours:
113 |             area.append(cv2.contourArea(cnt))
114 | 
115 |         hull = []
116 |         arr = np.array(area)
117 |         sorted_arr = np.flip((np.argsort(arr)))
118 | 
119 |         mask = np.zeros(img.shape)
120 | 
121 |         if len(contours) == 1:
122 |             hull.append(cv2.convexHull(contours[sorted_arr[0]]))
123 |             # img = cv2.drawContours(mask, [hull[0]], 0, (255,255,255), 3)
124 |             img = cv2.fillPoly(img, pts=hull, color=(255, 255, 255))
125 |         elif len(contours) == 2:
126 |             for i in range(2):
127 |                 hull.append(cv2.convexHull(contours[sorted_arr[i]]))
128 |             img = cv2.drawContours(mask, [hull[0]], 0, (255, 255, 255), 3)
129 |             img = cv2.drawContours(mask, [hull[1]], 0, (255, 255, 255), 3)
130 |             img = cv2.fillPoly(img, pts=hull, color=(255, 255, 255))
131 |         elif len(contours) == 3:
132 |             for i in range(3):
133 |                 hull.append(cv2.convexHull(contours[sorted_arr[i]]))
134 |             img = cv2.drawContours(mask, [hull[0]], 0, (255, 255, 255), 3)
135 |             img = cv2.drawContours(mask, [hull[1]], 0, (255, 255, 255), 3)
136 |             img = cv2.drawContours(mask, [hull[2]], 0, (255, 255, 255), 3)
137 |             img = cv2.fillPoly(img, pts=hull, color=(255, 255, 255))
138 |         elif len(contours) >= 4:
139 |             for i in range(4):
140 |                 hull.append(cv2.convexHull(contours[sorted_arr[i]]))
141 | 
142 |             for i in range(4):
143 |                 img = cv2.drawContours(mask, [hull[i]], 0, (255, 255, 255), 3)
144 |                 img = cv2.fillPoly(img, pts=[hull[i]], color=(255, 255, 255))
145 |             # img = cv2.drawContours(mask, [contours[sorted_arr[1]]], 0, (255,255,255), 3)
146 |             # img = cv2.drawContours(mask, [contours[sorted_arr[2]]], 0, (255,255,255), 3)
147 |             # img = cv2.drawContours(mask, [contours[sorted_arr[3]]], 0, (255,255,255), 3)
148 | 
149 |     # img = cv2.drawContours(img, [hull[0]], -1, (0,255,0), -1)
150 |     # cv2.imshow('imgcont', img)
151 |     # cv2.waitKey(0)
152 | 
153 |     # img = img.transpose(2, 0, 1)
154 |     # img = np.expand_dims(img, 0)
155 |     return img[:, :, 0]
156 | 
157 | 
158 | def make_video(path, name, framerate=30):  # framerate parameter with an assumed 30 fps default
159 | 
160 |     image = cv2.imread(path[0])
161 |     height, width, layers = image.shape
162 |     size = (width, height)
163 | 
164 |     out = cv2.VideoWriter(
165 |         name,
166 |         cv2.VideoWriter_fourcc(*"MP4V"),
167 |         framerate,
168 |         size,
169 |     )
170 | 
171 |     for i in range(65):
172 |         img_array = []
173 |         for filename in path:
174 |             image = cv2.imread(filename)
175 |             img_array.append(image)
176 | 
177 |         for i in range(len(img_array)):
178 |             out.write(img_array[i])
179 |     out.release()
180 | 
181 | 
182 | def blend(self, img1, img2, blend):
183 |     return img1 * (1.0 - blend) + img2 * blend
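# NOTE: blend() above and blend_intermediate() below take a `self` argument and
# reference attributes such as self.batchsize, self.detail_first and
# self.detail_last, so they appear to have been lifted out of a class; they are
# not called from the other modules in this repository.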
184 | 185 | 186 | def blend_intermediate(self): 187 | for i in range(1, self.batchsize - 1): 188 | n = i * 1 / (self.batchsize - 1) 189 | blend_grad = self.blend(self.detail_first, self.detail_last, n) 190 | 191 | if i > 1: 192 | self.blend_gradients = np.concatenate((self.blend_gradients, blend_grad), 0) 193 | else: 194 | self.blend_gradients = blend_grad 195 | 196 | # if self.config['seq'] and i > 0: 197 | # ###optical flow 198 | # flow = calc_opflow(self.prev_img, unp_image) 199 | # flow = -flow 200 | # warped_out = warp(self.prev_out[octave], flow) 201 | # loss = -(self.loss(out, target) + self.l1(out, warped_out)) 202 | 203 | 204 | def smooth_grad(grad, octave): 205 | if octave < 5: 206 | smoothing = GaussianSmoothing(3, 7, 10) 207 | inp = F.pad(grad, (3, 3, 3, 3), mode="reflect") 208 | return smoothing(inp) 209 | else: 210 | return grad 211 | 212 | 213 | class GaussianSmoothing(nn.Module): 214 | """ 215 | Apply gaussian smoothing on a 216 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 217 | in the input using a depthwise convolution. 218 | Arguments: 219 | channels (int, sequence): Number of channels of the input tensors. Output will 220 | have this number of channels as well. 221 | kernel_size (int, sequence): Size of the gaussian kernel. 222 | sigma (float, sequence): Standard deviation of the gaussian kernel. 223 | dim (int, optional): The number of dimensions of the data. 224 | Default value is 2 (spatial). 225 | """ 226 | 227 | def __init__(self, channels, kernel_size, sigma, dim=2): 228 | super(GaussianSmoothing, self).__init__() 229 | if isinstance(kernel_size, numbers.Number): 230 | kernel_size = [kernel_size] * dim 231 | if isinstance(sigma, numbers.Number): 232 | sigma = [sigma] * dim 233 | 234 | # The gaussian kernel is the product of the 235 | # gaussian function of each dimension. 236 | kernel = 1 237 | meshgrids = torch.meshgrid( 238 | [torch.arange(size, dtype=torch.float32) for size in kernel_size] 239 | ) 240 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 241 | mean = (size - 1) / 2 242 | kernel *= ( 243 | 1 244 | / (std * math.sqrt(2 * math.pi)) 245 | * torch.exp(-(((mgrid - mean) / std) ** 2) / 2) 246 | ) 247 | 248 | # Make sure sum of values in gaussian kernel equals 1. 249 | kernel = kernel / torch.sum(kernel) 250 | 251 | # Reshape to depthwise convolutional weight 252 | kernel = kernel.view(1, 1, *kernel.size()) 253 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 254 | 255 | self.register_buffer("weight", kernel) 256 | self.groups = channels 257 | 258 | if dim == 1: 259 | self.conv = F.conv1d 260 | elif dim == 2: 261 | self.conv = F.conv2d 262 | elif dim == 3: 263 | self.conv = F.conv3d 264 | else: 265 | raise RuntimeError( 266 | "Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim) 267 | ) 268 | 269 | def forward(self, input): 270 | """ 271 | Apply gaussian filter to input. 272 | Arguments: 273 | input (torch.Tensor): Input to apply gaussian filter on. 274 | Returns: 275 | filtered (torch.Tensor): Filtered output. 276 | """ 277 | return self.conv(input, weight=self.weight.cuda(), groups=self.groups) 278 | --------------------------------------------------------------------------------