├── Flows_dataset_raft.ipynb ├── README.md ├── create_txt.ipynb ├── datagen.py ├── image_utils.py ├── inference_future_frames.ipynb ├── inference_online.ipynb ├── matched_features_dataset.ipynb ├── metrics.py ├── requirements.txt ├── stabilize_future_frames.py ├── stabilize_online.py ├── train_vgg19_16x16_future_frames.ipynb ├── train_vgg19_16x16_online.ipynb └── trainlist.txt /Flows_dataset_raft.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.nn.functional as F\n", 12 | "import warnings\n", 13 | "import numpy as np\n", 14 | "import os\n", 15 | "import cv2\n", 16 | "import tqdm\n", 17 | "warnings.filterwarnings(\"ignore\")\n", 18 | "device = 'cuda'\n", 19 | "shape = (H,W,C) = (256,256,3)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 7, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "from torchvision import models\n", 29 | "raft = models.optical_flow.raft_small(weights = 'Raft_Small_Weights.C_T_V2').eval().to(device)\n", 30 | "\n", 31 | "def get_flow(img1, img2):\n", 32 | " img1_t = torch.from_numpy(img1/255.0).permute(-1,0,1).unsqueeze(0).float().to(device)\n", 33 | " img2_t = torch.from_numpy(img2/255.0).permute(-1,0,1).unsqueeze(0).float().to(device)\n", 34 | " flow = raft(img1_t,img2_t)[-1].detach().cpu().permute(0,2,3,1).squeeze(0).numpy()\n", 35 | " return flow\n", 36 | "\n", 37 | "def show_flow(flow):\n", 38 | " hsv_mask = np.zeros(shape= flow.shape[:-1] +(3,),dtype = np.uint8)\n", 39 | " hsv_mask[...,1] = 255\n", 40 | " mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1],angleInDegrees=True)\n", 41 | " hsv_mask[:,:,0] = ang /2 \n", 42 | " hsv_mask[:,:,2] = cv2.normalize(mag,None,0,255,cv2.NORM_MINMAX)\n", 43 | " rgb = cv2.cvtColor(hsv_mask,cv2.COLOR_HSV2RGB)\n", 44 | " 
return(rgb)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 8, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stderr", 54 | "output_type": "stream", 55 | "text": [ 56 | " 9%|▊ | 2/23 [00:25<04:22, 12.52s/it]\n" 57 | ] 58 | }, 59 | { 60 | "ename": "KeyboardInterrupt", 61 | "evalue": "", 62 | "output_type": "error", 63 | "traceback": [ 64 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 65 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 66 | "Cell \u001b[1;32mIn[8], line 11\u001b[0m\n\u001b[0;32m 9\u001b[0m flows \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m 10\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[1;32m---> 11\u001b[0m ret,curr \u001b[38;5;241m=\u001b[39m \u001b[43mcap\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m ret: \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[0;32m 13\u001b[0m curr \u001b[38;5;241m=\u001b[39m cv2\u001b[38;5;241m.\u001b[39mresize(curr,(W,H))\n", 67 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: " 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "videos_path = 'E:/Datasets/DeepStab_Dataset/stable/'\n", 73 | "flows_path = 'E:/Datasets/Flows/'\n", 74 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n", 75 | "videos = os.listdir(videos_path)\n", 76 | "for video in tqdm.tqdm(videos):\n", 77 | " cap = cv2.VideoCapture(os.path.join(videos_path, video))\n", 78 | " ret,prev = cap.read()\n", 79 | " prev = cv2.resize(prev,(W,H))\n", 80 | " flows = []\n", 81 | " while True:\n", 82 | " ret,curr = cap.read()\n", 83 | " if not ret: break\n", 84 | " curr = cv2.resize(curr,(W,H))\n", 85 | " flow = get_flow(prev,curr)\n", 86 | " flows.append(flow)\n", 87 | " prev = curr\n", 88 | " cv2.imshow('window',show_flow(flow))\n", 89 | " if cv2.waitKey(1) & 
0xFF == ord('9'):\n", 90 | " break\n", 91 | " flows = np.array(flows).astype(np.float32)\n", 92 | " output_path = os.path.join(flows_path,video.split('.')[0] + '.npy')\n", 93 | " np.save(output_path,flows)\n", 94 | "cv2.destroyAllWindows()" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [] 103 | } 104 | ], 105 | "metadata": { 106 | "kernelspec": { 107 | "display_name": "DUTCode", 108 | "language": "python", 109 | "name": "python3" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | "name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.9.18" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 2 126 | } 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Online Video Stabilization With Multi-Grid Warping Transformation Learning 2 | 3 | This is a PyTorch implementation of the [paper](https://cg.cs.tsinghua.edu.cn/papers/TIP-2019-VideoStab.pdf). 4 | 5 | ![Video Stabilization Example](https://github.com/btxviny/Deep-Motion-Blind-Video-Stabilization/blob/main/result.gif). 6 | 7 | I provide the original online algorithm described in the paper and a second implementation using a buffer of future frames. The latter can no longer be categorized as an online algorithm but it achieves better stabilization results 8 | 9 | ## Inference Instructions 10 | 11 | Follow these instructions to perform video stabilization using the pretrained model: 12 | 13 | 1. **Download the pretrained models:** 14 | - Download the pretrained models [weights](https://drive.google.com/drive/folders/1K8HfenNEr_0Joi6RdX4SfKVnCg-GjhvW?usp=drive_link). 
15 | - Place the downloaded weights folder in the main folder of your project. 16 | 17 | 2. **Run the Stabilization Script:** 18 | - For the original model, run: 19 | ```bash 20 | python stabilize_online.py --in_path unstable_video_path --out_path result_path 21 | ``` 22 | - Replace `unstable_video_path` with the path to your input unstable video. 23 | - Replace `result_path` with the desired path for the stabilized output video. 24 | - For the second model with future frames, run: 25 | ```bash 26 | python stabilize_future_frames.py --in_path unstable_video_path --out_path result_path 27 | ``` 28 | 29 | Make sure you have the necessary dependencies installed and that your environment is set up correctly before running the stabilization scripts. 30 | ```bash 31 | pip install numpy opencv-python torch==2.1.2 matplotlib 32 | ``` 33 | 34 | 35 | 36 | ## Training Instructions 37 | 38 | Follow these instructions to train the model: 39 | 40 | 1. **Download Datasets:** 41 | - Download the training dataset: [DeepStab](https://cg.cs.tsinghua.edu.cn/people/~miao/stabnet/data.zip). 42 | - Extract the contents of the downloaded dataset to a location on your machine. 43 | 44 | 2. **Create Datasets for Loss Functions:** 45 | - Create the optical flow and matched-feature datasets used by the loss functions described in the paper: 46 | - [Flows_dataset_raft.ipynb](https://github.com/btxviny/StabNet/blob/main/Flows_dataset_raft.ipynb) for the optical flow dataset. 47 | - [matched_features_dataset.ipynb](https://github.com/btxviny/StabNet/blob/main/matched_features_dataset.ipynb) for the matched-feature dataset. 48 | - Create a `trainlist.txt` containing the file paths for each sample input, using [create_txt.ipynb](https://github.com/btxviny/StabNet/blob/main/create_txt.ipynb) (adjust paths as needed). 49 | 50 | 3.
**Training Notebooks:** 51 | - Online version: [train_vgg19_16x16_online.ipynb](https://github.com/btxviny/StabNet/blob/main/train_vgg19_16x16_online.ipynb) 52 | - Future frame version: [train_vgg19_16x16_future_frames.ipynb](https://github.com/btxviny/StabNet/edit/main/train_vgg19_16x16_future_frames.ipynb) 53 | - Make sure to change `ckpt_dir` to the destination you want the model checkpoints to be saved at. 54 | 55 | 4. **Metrics Calculation:** 56 | - Use `metrics.py` to compute cropping, distortion, and stability scores for the generated results. 57 | -------------------------------------------------------------------------------- /create_txt.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import cv2\n", 11 | "import os \n", 12 | "import random\n", 13 | "import torch\n", 14 | "from torchvision import transforms" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 9, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "stable_prefix = 'E:/Datasets/DeepStab Modded/Stable_60/'\n", 24 | "unstable_prefix = 'E:/Datasets/DeepStab Modded/Unstable/'\n", 25 | "flow_prefix = 'E:/Datasets/DeepStab Modded/Flows_256x256'\n", 26 | "feature_prefix = 'E:/Datasets/DeepStab Modded/matched_features'\n", 27 | "videos = os.listdir(feature_prefix)\n", 28 | "with open('./trainlist.txt','w') as f:\n", 29 | " for vid in videos:\n", 30 | " video_name = vid.split('.')[0] + '.avi'\n", 31 | " s_path = os.path.join(stable_prefix, video_name)\n", 32 | " u_path = os.path.join(unstable_prefix, video_name)\n", 33 | " flow_path = os.path.join(flow_prefix,vid)\n", 34 | " feature_path = os.path.join(feature_prefix,vid)\n", 35 | " s_cap = cv2.VideoCapture(s_path)\n", 36 | " u_cap = cv2.VideoCapture(u_path)\n", 37 | " num_frames = 
class Datagen:
    """Yield shuffled training samples described by a comma-separated list file.

    Every line of the list file has the form
    ``stable_video,unstable_video,flow_npy,feature_npy,idx`` where ``idx`` is
    the index of the frame the sample is centred on.
    """

    # Exponents of the look-back offsets, oldest first. For exponent ``e`` the
    # frames at ``idx - (2**e + 1)`` and ``idx - 2**e`` are read, so the oldest
    # frame touched is ``idx - 33`` and the newest is ``idx - 1``.
    _HISTORY_EXPONENTS = (5, 4, 3, 2, 1, 0)

    def __init__(self, shape=(256, 256), txt_path='./trainlist.txt'):
        """Store the target ``(height, width)`` and read the sample list.

        Args:
            shape: ``(height, width)`` every frame is resized to.
            txt_path: path of the sample list, one sample per line.
        """
        self.shape = shape
        with open(txt_path, 'r') as f:
            self.trainlist = f.read().splitlines()

    def preprocess(self, img, gray=True):
        """Resize a BGR frame to ``self.shape`` and scale it to ``[0, 1]``.

        Args:
            img: frame as returned by ``cv2.VideoCapture.read``.
            gray: convert to single-channel grayscale before resizing.

        Returns:
            Floating point array with values in ``[0, 1]``.
        """
        h, w = self.shape
        if gray:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = cv2.resize(img, (w, h))  # cv2 expects (width, height)
        return img / 255.0

    @staticmethod
    def _read_frame(cap, path):
        """Read the next frame from *cap*, failing loudly instead of returning None."""
        ok, frame = cap.read()
        if not ok:
            raise RuntimeError(f'could not read a frame from {path}')
        return frame

    @staticmethod
    def _stack_history(frames):
        """Stack oldest-first frames into a float tensor ordered most-recent-first.

        The frames are already scaled to ``[0, 1]``; they must stay floating
        point (casting them to ``uint8`` would truncate nearly every pixel to 0).
        """
        stacked = np.asarray(frames, dtype=np.float32)
        return torch.from_numpy(stacked[::-1].copy())

    @staticmethod
    def _to_chw(img):
        """Convert an ``(H, W, C)`` array to a ``(C, H, W)`` float tensor."""
        return torch.from_numpy(img).permute(-1, 0, 1).float()

    def _load_sample(self, sample):
        """Load every tensor of one sample line; video handles are always released."""
        fields = sample.split(',')
        s_path, u_path, flow_path, feature_path = fields[:4]
        idx = int(fields[4])

        s_cap = cv2.VideoCapture(s_path)
        u_cap = cv2.VideoCapture(u_path)
        try:
            seq1, seq2 = [], []
            for exponent in self._HISTORY_EXPONENTS:
                s_cap.set(cv2.CAP_PROP_POS_FRAMES, idx - (2 ** exponent + 1))
                older = self.preprocess(self._read_frame(s_cap, s_path))
                seq2.append(random_translation(older))
                newer = self.preprocess(self._read_frame(s_cap, s_path))
                seq1.append(random_translation(newer))
            # The last iteration consumed frames idx-2 and idx-1, so the next
            # sequential read is the ground-truth stable frame at idx.
            Igt = self.preprocess(self._read_frame(s_cap, s_path), gray=False)

            u_cap.set(cv2.CAP_PROP_POS_FRAMES, idx - 1)
            It_prev = self.preprocess(self._read_frame(u_cap, u_path), gray=False)
            It_curr = self.preprocess(self._read_frame(u_cap, u_path), gray=False)
        finally:
            s_cap.release()
            u_cap.release()

        # Memory-map the arrays so only the requested entry is read from disk.
        flow = np.load(flow_path, mmap_mode='r')
        flo = torch.from_numpy(flow[idx - 1, ...].copy()).permute(-1, 0, 1).float()
        features = np.load(feature_path, mmap_mode='r')
        features = torch.from_numpy(features[idx, ...].copy()).float()

        return (
            self._stack_history(seq1),
            self._stack_history(seq2),
            self._to_chw(It_curr),
            self._to_chw(Igt),
            self._to_chw(It_prev),
            flo,
            features,
        )

    def __call__(self):
        """Shuffle the sample list and yield one sample tuple per line.

        Yields:
            ``(seq1, seq2, It_curr, Igt, It_prev, flo, features)`` where
            ``seq1``/``seq2`` are the randomly translated grayscale history
            frames of the stable video (most recent first), ``It_curr`` and
            ``It_prev`` are the unstable frames at ``idx`` and ``idx - 1``,
            ``Igt`` is the stable frame at ``idx``, ``flo`` is entry
            ``idx - 1`` of the flow array and ``features`` is entry ``idx`` of
            the feature array.

        Raises:
            RuntimeError: if a required video frame cannot be read.
        """
        self.trainlist = random.sample(self.trainlist, len(self.trainlist))
        for sample in self.trainlist:
            yield self._load_sample(sample)


def random_translation(img):
    """Shift *img* by a random offset of up to one eighth of its size per axis.

    Pixels shifted in from outside the image are filled by ``cv2.warpAffine``
    with its default border handling.
    """
    img = np.array(img)
    h, w = img.shape[:2]
    dx = np.random.randint(-w // 8, w // 8)
    dy = np.random.randint(-h // 8, h // 8)
    mat = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)
    return cv2.warpAffine(img, mat, (w, h))
def dense_warp(image, flow):
    """
    Densely warps an image using optical flow.

    Output pixel ``(x, y)`` is sampled bilinearly from the input at
    ``(x + flow_x, y + flow_y)``; a zero flow therefore returns the image
    unchanged.

    Args:
        image (torch.Tensor): Input image tensor of shape (batch_size, channels, height, width).
        flow (torch.Tensor): Optical flow tensor of shape (batch_size, 2, height, width),
            channel 0 holding the x displacement and channel 1 the y displacement, in pixels.

    Returns:
        torch.Tensor: Warped image tensor of shape (batch_size, channels, height, width).
    """
    _, _, height, width = image.size()
    device = image.device

    # Pixel coordinates of every output location, stored as (x, y).
    grid_y, grid_x = torch.meshgrid(
        torch.arange(height, device=device),
        torch.arange(width, device=device),
        indexing='ij',
    )
    base = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)  # (1, H, W, 2)
    sample_px = base + flow.permute(0, 2, 3, 1)  # broadcasts over the batch

    # Map pixel centres 0..W-1 / 0..H-1 onto [-1, 1]. This is the
    # align_corners=True convention, so grid_sample must use it as well;
    # max(..., 1) avoids a division by zero for one-pixel-wide inputs.
    scale = torch.tensor(
        [max(width - 1, 1), max(height - 1, 1)],
        dtype=sample_px.dtype,
        device=device,
    )
    grid = sample_px / scale * 2 - 1

    return F.grid_sample(image, grid, align_corners=True)


def find_homography_numpy(src_points, dst_points):
    """Estimate the homography mapping ``src_points`` onto ``dst_points`` (DLT).

    Args:
        src_points: iterable of ``(x, y)`` source points.
        dst_points: iterable of ``(x, y)`` destination points, paired with
            ``src_points`` in order; surplus points on either side are ignored.

    Returns:
        np.ndarray: 3x3 homography normalised so that ``H[2, 2] == 1``. The
        entries are not finite when the true ``H[2, 2]`` is zero.

    Raises:
        ValueError: if fewer than four point pairs are supplied.
    """
    pairs = list(zip(src_points, dst_points))
    if len(pairs) < 4:
        raise ValueError(
            f'at least 4 point correspondences are required, got {len(pairs)}'
        )
    src = np.array([p[0] for p in pairs], dtype=np.float64).reshape(-1, 2)
    dst = np.array([p[1] for p in pairs], dtype=np.float64).reshape(-1, 2)

    x, y = src[:, 0], src[:, 1]
    u, v = dst[:, 0], dst[:, 1]
    zeros = np.zeros_like(x)
    ones = np.ones_like(x)

    # Two homogeneous equations per correspondence: A @ h = 0.
    rows_u = np.stack([-x, -y, -ones, zeros, zeros, zeros, x * u, y * u, u], axis=1)
    rows_v = np.stack([zeros, zeros, zeros, -x, -y, -ones, x * v, y * v, v], axis=1)
    A = np.concatenate([rows_u, rows_v], axis=0)

    # The right singular vector of the smallest singular value spans the
    # (approximate) null space of A. Taking it from A directly avoids squaring
    # the condition number, which an eigen-decomposition of A.T @ A would do.
    _, _, vt = np.linalg.svd(A)
    homography_vector = vt[-1]
    return (homography_vector / homography_vector[-1]).reshape(3, 3)
def warp(img, mat):
    """Warp a batch of images with per-sample homographies about the image centre.

    Args:
        img (torch.Tensor): images of shape (batch_size, channels, height, width).
        mat (torch.Tensor): shape (batch_size, 8); the first eight row-major
            entries of each 3x3 homography (the ninth entry is fixed to 1).
            The homography maps output pixel coordinates, taken relative to
            the image centre, to the input coordinates that are sampled.

    Returns:
        torch.Tensor: warped images with the same shape as ``img``; locations
        that map outside the input are filled with zeros.
    """
    device = img.device
    dtype = img.dtype
    _, _, height, width = img.shape

    mat = mat.to(device=device, dtype=dtype)
    ones = torch.ones((mat.size(0), 1), device=device, dtype=dtype)
    homography = torch.cat([mat, ones], dim=-1).view(-1, 3, 3)

    # Conjugate with translations so the homography acts about the centre.
    cy, cx = height // 2, width // 2
    to_origin = torch.tensor([[1, 0, -cx],
                              [0, 1, -cy],
                              [0, 0, 1]], dtype=dtype, device=device)
    to_center = torch.tensor([[1, 0, cx],
                              [0, 1, cy],
                              [0, 0, 1]], dtype=dtype, device=device)
    transform = torch.matmul(to_center, torch.matmul(homography, to_origin))

    # Homogeneous pixel coordinates (x, y, 1) of every output location.
    ys, xs = torch.meshgrid(
        torch.arange(height, device=device),
        torch.arange(width, device=device),
        indexing='ij',
    )
    coords = torch.stack([xs, ys, torch.ones_like(xs)], dim=-1).to(dtype).view(1, -1, 3)

    moved = torch.matmul(coords, transform.transpose(1, 2))  # (B, H*W, 3)
    x_px = (moved[..., 0] / moved[..., 2]).view(-1, height, width)
    y_px = (moved[..., 1] / moved[..., 2]).view(-1, height, width)

    # Pixel centres 0..W-1 / 0..H-1 map onto [-1, 1]: the align_corners=True
    # convention, which grid_sample has to share for the identity homography
    # to reproduce the input. max(..., 1) guards one-pixel-wide inputs.
    x_norm = x_px / max(width - 1, 1) * 2 - 1
    y_norm = y_px / max(height - 1, 1) * 2 - 1
    grid = torch.stack([x_norm, y_norm], dim=-1)

    return torch.nn.functional.grid_sample(
        img, grid, mode='bilinear', padding_mode='zeros', align_corners=True
    )
def find_homography(src_points, dst_points):
    """Estimate the homography mapping ``src_points`` onto ``dst_points`` (DLT).

    Args:
        src_points: ``(N, 2)`` tensor or nested sequence of ``(x, y)`` source points.
        dst_points: ``(N, 2)`` tensor or nested sequence of ``(x, y)`` destination
            points, paired with ``src_points`` in order; surplus points on
            either side are ignored.

    Returns:
        torch.Tensor: float32 3x3 homography normalised so that
        ``H[2, 2] == 1``, on the device of ``src_points`` (CPU for
        non-tensor input). The result is detached from the autograd graph
        of the inputs.

    Raises:
        ValueError: if fewer than four point pairs are supplied.
    """
    if isinstance(src_points, torch.Tensor):
        device = src_points.device
    else:
        device = torch.device('cpu')

    # detach(): the estimate is returned as a constant, never as a
    # differentiable function of the input points.
    src = torch.as_tensor(src_points, dtype=torch.float32).detach().to(device).reshape(-1, 2)
    dst = torch.as_tensor(dst_points, dtype=torch.float32).detach().to(device).reshape(-1, 2)

    count = min(src.shape[0], dst.shape[0])
    if count < 4:
        raise ValueError(
            f'at least 4 point correspondences are required, got {count}'
        )
    src, dst = src[:count], dst[:count]

    x, y = src[:, 0], src[:, 1]
    u, v = dst[:, 0], dst[:, 1]
    zeros = torch.zeros_like(x)
    ones = torch.ones_like(x)

    # Two homogeneous equations per correspondence: A @ h = 0.
    rows_u = torch.stack([-x, -y, -ones, zeros, zeros, zeros, x * u, y * u, u], dim=1)
    rows_v = torch.stack([zeros, zeros, zeros, -x, -y, -ones, x * v, y * v, v], dim=1)
    A = torch.cat([rows_u, rows_v], dim=0)

    # The right singular vector of the smallest singular value spans the
    # (approximate) null space of A; working on A instead of A.T @ A keeps
    # the float32 conditioning manageable.
    _, _, vh = torch.linalg.svd(A)
    homography_vector = vh[-1]
    return (homography_vector / homography_vector[-1]).view(3, 3)
def _solve_cells_one_by_one(A, b):
    """Solve each 8x8 system separately, tolerating singular cells.

    Cells are visited in row-major order. A cell whose system cannot be
    inverted reuses the most recent successfully solved homography, or the
    identity homography when no earlier cell succeeded.
    """
    # First eight row-major entries of the 3x3 identity.
    last = torch.tensor(
        [1, 0, 0, 0, 1, 0, 0, 0], dtype=A.dtype, device=A.device
    ).view(8, 1)
    solved = []
    for k in range(A.shape[0]):
        try:
            last = torch.matmul(torch.inverse(A[k]), b[k])
        except RuntimeError:
            # Singular system (torch.linalg.LinAlgError is a RuntimeError):
            # keep the previous homography for this cell.
            pass
        solved.append(last)
    return torch.stack(solved)


def findHomography(grids, new_grids_loc):
    """
    @param: grids the location of origin grid vertices [2, H, W]
    @param: new_grids_loc the location of desired grid vertices [2, H, W]

    @return: homo_t homograph projection matrix for each grid [3, 3, H-1, W-1]

    Each grid cell (i, j) is the quad with vertices (i, j), (i+1, j),
    (i, j+1) and (i+1, j+1). Its homography maps the four original vertices
    onto the four desired ones and has its bottom-right entry fixed to 1.
    If the batched solve fails because some cell is singular, cells are
    solved one by one; a singular cell then reuses the previous cell's
    homography (identity if there is none).
    """
    _, H, W = grids.shape
    cells = (H - 1) * (W - 1)
    if cells == 0:
        return torch.zeros(3, 3, H - 1, W - 1, device=grids.device)

    # Vertex order used for both the equations and the right-hand side:
    # top-left, bottom-left, top-right, bottom-right.
    corners = (
        (slice(None, -1), slice(None, -1)),
        (slice(1, None), slice(None, -1)),
        (slice(None, -1), slice(1, None)),
        (slice(1, None), slice(1, None)),
    )

    def gather(field):
        # (H, W) vertex field -> (cells, 4): the four corner values per cell.
        return torch.stack([field[r, c].reshape(-1) for r, c in corners], dim=1)

    x, y = gather(grids[0]), gather(grids[1])
    u, v = gather(new_grids_loc[0]), gather(new_grids_loc[1])
    one = torch.ones_like(x)
    zero = torch.zeros_like(x)

    # For every vertex:  h0*x + h1*y + h2 - u*(h6*x + h7*y) = u
    #                    h3*x + h4*y + h5 - v*(h6*x + h7*y) = v
    rows_u = torch.stack([x, y, one, zero, zero, zero, -x * u, -y * u], dim=2)
    rows_v = torch.stack([zero, zero, zero, x, y, one, -x * v, -y * v], dim=2)
    A = torch.cat([rows_u, rows_v], dim=1)      # (cells, 8, 8)
    b = torch.cat([u, v], dim=1).unsqueeze(2)   # (cells, 8, 1)

    try:
        # Common case: every cell is solvable, so solve them all at once.
        h8 = torch.bmm(torch.inverse(A), b)
    except RuntimeError:
        h8 = _solve_cells_one_by_one(A, b)

    ninth = torch.ones_like(h8[:, 0:1, :])
    homo = torch.cat([h8, ninth], dim=1).view(cells, 3, 3)
    return homo.permute(1, 2, 0).reshape(3, 3, H - 1, W - 1)
time import time\n", 13 | "import os\n", 14 | "import datetime\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "import torch.nn as nn\n", 17 | "import torchvision\n", 18 | "import torch.nn.functional as F\n", 19 | "device = 'cuda'\n", 20 | "batch_size = 1\n", 21 | "grid_h,grid_w = 15,15\n", 22 | "H,W = height,width = 360,640" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "def get_warp(net_out,img):\n", 32 | " '''\n", 33 | " Inputs:\n", 34 | " net_out: torch.Size([batch_size,grid_h +1 ,grid_w +1,2])\n", 35 | " img: image to warp\n", 36 | " '''\n", 37 | " grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1, grid_h + 1),\n", 38 | " torch.linspace(-1,1, grid_h + 1),\n", 39 | " indexing='ij')\n", 40 | " src_grid = torch.stack([grid_x,grid_y],dim = -1).unsqueeze(0).repeat(batch_size,1,1,1).to(device)\n", 41 | " new_grid = src_grid + 1 * net_out\n", 42 | " grid_upscaled = F.interpolate(new_grid.permute(0,-1,1,2),size = (height,width), mode = 'bilinear',align_corners= True)\n", 43 | " warped = F.grid_sample(img, grid_upscaled.permute(0,2,3,1),align_corners=True)\n", 44 | " return warped" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "class StabNet(nn.Module):\n", 54 | " def __init__(self,trainable_layers = 10):\n", 55 | " super(StabNet, self).__init__()\n", 56 | " # Load the pre-trained ResNet model\n", 57 | " vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')\n", 58 | " # Extract conv1 pretrained weights for RGB input\n", 59 | " rgb_weights = vgg19.features[0].weight.clone() #torch.Size([64, 3, 3, 3])\n", 60 | " # Calculate the average across the RGB channels\n", 61 | " tiled_rgb_weights = rgb_weights.repeat(1,5,1,1) \n", 62 | " # Change size of the first layer from 3 to 9 channels\n", 63 | " vgg19.features[0] = nn.Conv2d(15,64, kernel_size=3, stride=1, padding=1, bias=False)\n", 
64 | " # set new weights\n", 65 | " vgg19.features[0].weight = nn.Parameter(tiled_rgb_weights)\n", 66 | " # Determine the total number of layers in the model\n", 67 | " total_layers = sum(1 for _ in vgg19.parameters())\n", 68 | " # Freeze the layers except the last 10\n", 69 | " for idx, param in enumerate(vgg19.parameters()):\n", 70 | " if idx > total_layers - trainable_layers:\n", 71 | " param.requires_grad = True\n", 72 | " else:\n", 73 | " param.requires_grad = False\n", 74 | " # Remove the last layer of ResNet\n", 75 | " self.encoder = nn.Sequential(*list(vgg19.children())[0][:-1])\n", 76 | " self.regressor = nn.Sequential(nn.Linear(512,2048),\n", 77 | " nn.ReLU(),\n", 78 | " nn.Linear(2048,1024),\n", 79 | " nn.ReLU(),\n", 80 | " nn.Linear(1024,512),\n", 81 | " nn.ReLU(),\n", 82 | " nn.Linear(512, ((grid_h + 1) * (grid_w + 1) * 2)))\n", 83 | " #self.regressor[-1].bias.data.fill_(0)\n", 84 | " total_resnet_params = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)\n", 85 | " total_regressor_params = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)\n", 86 | " print(\"Total Trainable mobilenet Parameters: \", total_resnet_params)\n", 87 | " print(\"Total Trainable regressor Parameters: \", total_regressor_params)\n", 88 | " print(\"Total Trainable parameters:\",total_regressor_params + total_resnet_params)\n", 89 | " \n", 90 | " def forward(self, x_tensor):\n", 91 | " x_batch_size = x_tensor.size()[0]\n", 92 | " x = x_tensor[:, :3, :, :]\n", 93 | "\n", 94 | " # summary 1, dismiss now\n", 95 | " x_tensor = self.encoder(x_tensor)\n", 96 | " x_tensor = torch.mean(x_tensor, dim=[2, 3])\n", 97 | " x = self.regressor(x_tensor)\n", 98 | " x = x.view(x_batch_size,grid_h + 1,grid_w + 1,2)\n", 99 | " return x" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 4, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "name": "stdout", 109 | "output_type": "stream", 110 | "text": [ 111 | "Total Trainable 
mobilenet Parameters: 2360320\n", 112 | "Total Trainable regressor Parameters: 3936256\n", 113 | "Total Trainable parameters: 6296576\n", 114 | "loaded weights ./ckpts/with_future_frames/stabnet_2023-11-02_23-03-39.pth\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "ckpt_dir = './ckpts/with_future_frames/'\n", 120 | "stabnet = StabNet().to(device).eval()\n", 121 | "ckpts = os.listdir(ckpt_dir)\n", 122 | "if ckpts:\n", 123 | " ckpts = sorted(ckpts, key=lambda x: datetime.datetime.strptime(x.split('_')[2].split('.')[0], \"%H-%M-%S\"), reverse=True)\n", 124 | " \n", 125 | " # Get the filename of the latest checkpoint\n", 126 | " latest = os.path.join(ckpt_dir, ckpts[0])\n", 127 | "\n", 128 | " state = torch.load(latest)\n", 129 | " stabnet.load_state_dict(state['model'])\n", 130 | " print('loaded weights',latest)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 5, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "path = 'E:/Datasets/DeepStab_Dataset/unstable/2.avi'\n", 140 | "cap = cv2.VideoCapture(path)\n", 141 | "mean = np.array([0.485, 0.456, 0.406],dtype = np.float32) \n", 142 | "std = np.array([0.229, 0.224, 0.225],dtype = np.float32)\n", 143 | "frames = []\n", 144 | "while True:\n", 145 | " ret, img = cap.read()\n", 146 | " if not ret: break\n", 147 | " img = cv2.resize(img, (W,H))\n", 148 | " img = (img / 255.0).astype(np.float32)\n", 149 | " img = (img - mean)/std\n", 150 | " frames.append(img)\n", 151 | "frames = np.array(frames,dtype = np.float32)\n", 152 | "frame_count,_,_,_ = frames.shape" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 6, 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "name": "stdout", 162 | "output_type": "stream", 163 | "text": [ 164 | "speed: 0.007669420583669505 seconds per frame\n" 165 | ] 166 | } 167 | ], 168 | "source": [ 169 | "frames_tensor = torch.from_numpy(frames).permute(0,3,1,2).float().to('cpu')\n", 170 | "stable_frames_tensor = 
frames_tensor.clone()\n", 171 | "\n", 172 | "SKIP = 32\n", 173 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n", 174 | "def get_batch(idx):\n", 175 | " batch = torch.zeros((5,3,H,W)).float()\n", 176 | " for i,j in enumerate(range(idx - SKIP, idx + SKIP + 1, SKIP//2)):\n", 177 | " batch[i,...] = frames_tensor[j,...]\n", 178 | " batch = batch.view(1,-1,H,W)\n", 179 | " return batch.to(device)\n", 180 | "start = time()\n", 181 | "for frame_idx in range(SKIP,frame_count - SKIP):\n", 182 | " batch = get_batch(frame_idx)\n", 183 | " with torch.no_grad():\n", 184 | " transform = stabnet(batch)\n", 185 | " warped = get_warp(transform, frames_tensor[frame_idx: frame_idx + 1,...].cuda())\n", 186 | " stable_frames_tensor[frame_idx] = warped\n", 187 | " img = warped.permute(0,2,3,1)[0,...].cpu().detach().numpy()\n", 188 | " img *= std\n", 189 | " img += mean\n", 190 | " img = np.clip(img * 255.0,0,255).astype(np.uint8)\n", 191 | " cv2.imshow('window', img)\n", 192 | " if cv2.waitKey(1) & 0xFF == ord('q'):\n", 193 | " break\n", 194 | "cv2.destroyAllWindows()\n", 195 | "total = time() - start\n", 196 | "speed = total / frame_count\n", 197 | "print(f'speed: {speed} seconds per frame')" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 18, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "stable_frames = np.clip(((stable_frames_tensor.permute(0,2,3,1).numpy() * std) + mean) * 255,0,255).astype(np.uint8)\n", 207 | "frames = np.clip(((frames_tensor.permute(0,2,3,1).numpy() * std) + mean) * 255,0,255).astype(np.uint8)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 19, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "from time import sleep\n", 217 | "fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n", 218 | "out = cv2.VideoWriter('2.avi', fourcc, 30.0, (W,H))\n", 219 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n", 220 | "for idx in range(frame_count):\n", 221 | " img = 
stable_frames[idx,...]\n", 222 | " out.write(img)\n", 223 | " cv2.imshow('window',img)\n", 224 | " #sleep(1/30)\n", 225 | " if cv2.waitKey(1) & 0xFF == ord(' '):\n", 226 | " break\n", 227 | "out.release()\n", 228 | "cv2.destroyAllWindows()" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 20, 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "name": "stdout", 238 | "output_type": "stream", 239 | "text": [ 240 | "Frame: 446/447\n", 241 | "cropping score:0.996\tdistortion score:0.982\tstability:0.666\tpixel:0.997\n" 242 | ] 243 | }, 244 | { 245 | "data": { 246 | "text/plain": [ 247 | "(0.9962942926265075, 0.9820633, 0.66624937150459, 0.9968942715786397)" 248 | ] 249 | }, 250 | "execution_count": 20, 251 | "metadata": {}, 252 | "output_type": "execute_result" 253 | } 254 | ], 255 | "source": [ 256 | "from metrics import metric\n", 257 | "metric('E:/Datasets/DeepStab_Dataset/unstable/2.avi','2.avi')" 258 | ] 259 | } 260 | ], 261 | "metadata": { 262 | "kernelspec": { 263 | "display_name": "DUTCode", 264 | "language": "python", 265 | "name": "python3" 266 | }, 267 | "language_info": { 268 | "codemirror_mode": { 269 | "name": "ipython", 270 | "version": 3 271 | }, 272 | "file_extension": ".py", 273 | "mimetype": "text/x-python", 274 | "name": "python", 275 | "nbconvert_exporter": "python", 276 | "pygments_lexer": "ipython3", 277 | "version": "3.9.18" 278 | }, 279 | "orig_nbformat": 4 280 | }, 281 | "nbformat": 4, 282 | "nbformat_minor": 2 283 | } 284 | -------------------------------------------------------------------------------- /inference_online.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import cv2\n", 11 | "import torch\n", 12 | "from time import time\n", 13 | "import os\n", 14 | "import datetime\n", 15 | "import pandas as 
pd\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "import torch.nn as nn\n", 18 | "import torchvision\n", 19 | "import torch.nn.functional as F\n", 20 | "from torch.utils import data\n", 21 | "from image_utils import dense_warp, warp\n", 22 | "device = 'cuda'\n", 23 | "height,width = 360,640\n", 24 | "batch_size = 1\n", 25 | "grid_h,grid_w = 15,15" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 5, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "def get_warp(net_out,img):\n", 35 | " '''\n", 36 | " Inputs:\n", 37 | " net_out: torch.Size([batch_size,grid_h +1 ,grid_w +1,2])\n", 38 | " img: image to warp\n", 39 | " '''\n", 40 | " grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1, grid_h + 1),\n", 41 | " torch.linspace(-1,1, grid_h + 1),\n", 42 | " indexing='ij')\n", 43 | " src_grid = torch.stack([grid_x,grid_y],dim = -1).unsqueeze(0).repeat(batch_size,1,1,1).to(device)\n", 44 | " new_grid = src_grid + net_out\n", 45 | " grid_upscaled = F.interpolate(new_grid.permute(0,-1,1,2),size = (height,width), mode = 'bilinear',align_corners= True)\n", 46 | " warped = F.grid_sample(img, grid_upscaled.permute(0,2,3,1),align_corners=False,padding_mode='zeros')\n", 47 | " return warped" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 3, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "class StabNet(nn.Module):\n", 57 | " def __init__(self,trainable_layers = 10):\n", 58 | " super(StabNet, self).__init__()\n", 59 | " # Load the pre-trained ResNet model\n", 60 | " vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')\n", 61 | " # Extract conv1 pretrained weights for RGB input\n", 62 | " rgb_weights = vgg19.features[0].weight.clone() #torch.Size([64, 3, 3, 3])\n", 63 | " # Calculate the average across the RGB channels\n", 64 | " average_rgb_weights = torch.mean(rgb_weights, dim=1, keepdim=True).repeat(1,6,1,1) #torch.Size([64, 5, 7, 7])\n", 65 | " # Change size of the first layer from 3 to 9 
channels\n", 66 | " vgg19.features[0] = nn.Conv2d(9,64, kernel_size=3, stride=1, padding=1, bias=False)\n", 67 | " # set new weights\n", 68 | " new_weights = torch.cat((rgb_weights, average_rgb_weights), dim=1)\n", 69 | " vgg19.features[0].weight = nn.Parameter(new_weights)\n", 70 | " # Determine the total number of layers in the model\n", 71 | " total_layers = sum(1 for _ in vgg19.parameters())\n", 72 | " # Freeze the layers except the last 10\n", 73 | " for idx, param in enumerate(vgg19.parameters()):\n", 74 | " if idx > total_layers - trainable_layers:\n", 75 | " param.requires_grad = True\n", 76 | " else:\n", 77 | " param.requires_grad = False\n", 78 | " # Remove the last layer of ResNet\n", 79 | " self.encoder = nn.Sequential(*list(vgg19.children())[0][:-1])\n", 80 | " self.regressor = nn.Sequential(nn.Linear(512,2048),\n", 81 | " nn.ReLU(),\n", 82 | " nn.Linear(2048,1024),\n", 83 | " nn.ReLU(),\n", 84 | " nn.Linear(1024,512),\n", 85 | " nn.ReLU(),\n", 86 | " nn.Linear(512, ((grid_h + 1) * (grid_w + 1) * 2)))\n", 87 | " total_resnet_params = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)\n", 88 | " total_regressor_params = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)\n", 89 | " print(\"Total Trainable encoder Parameters: \", total_resnet_params)\n", 90 | " print(\"Total Trainable regressor Parameters: \", total_regressor_params)\n", 91 | " print(\"Total Trainable parameters:\",total_regressor_params + total_resnet_params)\n", 92 | " \n", 93 | " def forward(self, x_tensor):\n", 94 | " x_batch_size = x_tensor.size()[0]\n", 95 | " x = x_tensor[:, :3, :, :]\n", 96 | "\n", 97 | " # summary 1, dismiss now\n", 98 | " x_tensor = self.encoder(x_tensor)\n", 99 | " x_tensor = torch.mean(x_tensor, dim=[2, 3])\n", 100 | " x = self.regressor(x_tensor)\n", 101 | " x = x.view(x_batch_size,grid_h + 1,grid_w + 1,2)\n", 102 | " return x" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 10, 108 | 
"metadata": {}, 109 | "outputs": [ 110 | { 111 | "name": "stdout", 112 | "output_type": "stream", 113 | "text": [ 114 | "Total Trainable encoder Parameters: 2360320\n", 115 | "Total Trainable regressor Parameters: 3936256\n", 116 | "Total Trainable parameters: 6296576\n", 117 | "loaded weights ./ckpts/original/stabnet_2023-10-26_13-42-14.pth\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "ckpt_dir = './ckpts/original/'\n", 123 | "stabnet = StabNet().to(device).eval()\n", 124 | "ckpts = os.listdir(ckpt_dir)\n", 125 | "if ckpts:\n", 126 | " ckpts = sorted(ckpts, key=lambda x: datetime.datetime.strptime(x.split('_')[2].split('.')[0], \"%H-%M-%S\"), reverse=True)\n", 127 | " \n", 128 | " # Get the filename of the latest checkpoint\n", 129 | " latest = os.path.join(ckpt_dir, ckpts[0])\n", 130 | "\n", 131 | " state = torch.load(latest)\n", 132 | " stabnet.load_state_dict(state['model'])\n", 133 | " print('loaded weights',latest)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 7, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "path = 'E:/Datasets/DeepStab_Dataset/unstable/2.avi'\n", 143 | "cap = cv2.VideoCapture(path)\n", 144 | "frames = []\n", 145 | "while True:\n", 146 | " ret,frame = cap.read()\n", 147 | " if not ret : break\n", 148 | " frame = cv2.resize(frame,(width,height))\n", 149 | " frames.append(frame)\n", 150 | "frames = np.array(frames)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 8, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "data": { 160 | "text/plain": [ 161 | "torch.Size([447, 3, 360, 640])" 162 | ] 163 | }, 164 | "execution_count": 8, 165 | "metadata": {}, 166 | "output_type": "execute_result" 167 | } 168 | ], 169 | "source": [ 170 | "frames_t = torch.from_numpy(frames/255.0).permute(0,3,1,2).float()\n", 171 | "frames_t.shape" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 9, 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 
| "ename": "KeyboardInterrupt", 181 | "evalue": "", 182 | "output_type": "error", 183 | "traceback": [ 184 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 185 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 186 | "Cell \u001b[1;32mIn[9], line 14\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m 13\u001b[0m trasnform \u001b[38;5;241m=\u001b[39m stabnet(net_in)\n\u001b[1;32m---> 14\u001b[0m warped \u001b[38;5;241m=\u001b[39m get_warp(trasnform \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m1\u001b[39m ,\u001b[43mcurr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[0;32m 15\u001b[0m warped_frames[idx:idx\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m] \u001b[38;5;241m=\u001b[39m warped\u001b[38;5;241m.\u001b[39mcpu()\n\u001b[0;32m 16\u001b[0m warped_gray \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mmean(warped,dim \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m,keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", 187 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: " 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | "num_frames,_,h,w = frames_t.shape\n", 193 | "warped_frames = frames_t.clone()\n", 194 | "buffer = torch.zeros((6,1,h,w)).float()\n", 195 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n", 196 | "start = time()\n", 197 | "for iter in range(1):\n", 198 | " for idx in range(33,num_frames):\n", 199 | " for i in range(6):\n", 200 | " buffer[i,...] = torch.mean(warped_frames[idx - 2**i,...],dim = 0,keepdim = True)\n", 201 | " curr = warped_frames[idx:idx+1,...] 
\n", 202 | " net_in = torch.cat([curr,buffer.permute(1,0,2,3)], dim = 1).to(device)\n", 203 | " with torch.no_grad():\n", 204 | " trasnform = stabnet(net_in)\n", 205 | " warped = get_warp(trasnform * 1 ,curr.to(device))\n", 206 | " warped_frames[idx:idx+1,...] = warped.cpu()\n", 207 | " warped_gray = torch.mean(warped,dim = 1,keepdim=True)\n", 208 | " buffer = torch.roll(buffer, shifts= 1, dims=1)\n", 209 | " buffer[:,:1,:,:] = warped_gray\n", 210 | " img = warped_frames[idx,...].permute(1,2,0).numpy()\n", 211 | " img = (img * 255).astype(np.uint8)\n", 212 | " cv2.imshow('window',img)\n", 213 | " if cv2.waitKey(1) & 0xFF == ord(' '):\n", 214 | " break\n", 215 | "cv2.destroyAllWindows()\n", 216 | "total = time() - start\n", 217 | "speed = total / num_frames\n", 218 | "print(f'speed: {speed} seconds per frame') " 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 17, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "from time import sleep\n", 228 | "fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n", 229 | "out = cv2.VideoWriter('./results/2.avi', fourcc, 30.0, (256,256))\n", 230 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n", 231 | "for idx in range(num_frames):\n", 232 | " img = warped_frames[idx,...].permute(1,2,0).numpy()\n", 233 | " img = (img * 255).astype(np.uint8)\n", 234 | " diff = cv2.absdiff(img,frames[idx,...])\n", 235 | " out.write(img)\n", 236 | " cv2.imshow('window',img)\n", 237 | " sleep(1/30)\n", 238 | " if cv2.waitKey(1) & 0xFF == ord(' '):\n", 239 | " break\n", 240 | "cv2.destroyAllWindows()\n", 241 | "out.release()" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n", 251 | "for idx in range(num_frames):\n", 252 | " img = warped_frames[idx,...].permute(1,2,0).numpy()\n", 253 | " img = (img * 255).astype(np.uint8)\n", 254 | " diff = 
def _sift_features(sift, frame_bgr):
    """Return ``(keypoints, descriptors)`` for the grayscale version of a BGR frame."""
    gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
    return sift.detectAndCompute(gray, None)


def _estimate_homography(bf, feats_src, feats_dst, ratio, min_matches, thresh):
    """Return the RANSAC homography from ``feats_src`` to ``feats_dst``, or None.

    None is returned whenever the estimate cannot be made: a frame without
    descriptors, too few matches surviving the ratio test, or a failed
    ``cv2.findHomography`` call.
    """
    kp_src, desc_src = feats_src
    kp_dst, desc_dst = feats_dst
    # The ratio test needs two neighbours per query descriptor.
    if desc_src is None or desc_dst is None or len(desc_dst) < 2:
        return None

    good = []
    for pair in bf.knnMatch(desc_src, desc_dst, k=2):
        if len(pair) == 2 and pair[0].distance < ratio * pair[1].distance:
            good.append(pair[0])
    if len(good) <= min_matches:
        return None

    src_pts = np.float32([kp_src[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp_dst[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
    M, _ = cv2.findHomography(src_pts, dst_pts, method=cv2.RANSAC, ransacReprojThreshold=thresh)
    return M


def _low_frequency_ratio(signal):
    """Return the share of spectral energy held by the five lowest non-DC frequencies.

    The DC term is dropped and only the first half of the remaining spectrum
    is used. NaN is returned when the signal is empty or carries no energy.
    """
    if len(signal) == 0:
        return float('nan')
    power = np.abs(np.fft.fft(signal)) ** 2
    power = np.delete(power, 0)
    power = power[:len(power) // 2]
    total = np.sum(power)
    if total == 0:
        return float('nan')
    return np.sum(power[:5]) / total


def metric(original_video_path, pred_video_path, shape=(128, 128)):
    ''' Compare a predicted video against the original one, frame by frame.

        Inputs:
            original_video_path: path to the original video
            pred_video_path: path to the generated (stabilized) video
            shape: (H, W) every frame is resized to before comparison
        Outputs (in return order):
            cropping_score
            distortion_score
            stability_score
            pixel_score
        A score that cannot be computed (no frames, or no usable homography)
        is returned as NaN instead of raising.
    '''
    sys.stdout.flush()
    # Brute-force matcher and SIFT detector shared by all frames
    bf = cv2.BFMatcher()
    sift = cv2.SIFT_create()

    # A homography is only estimated when more than MIN_MATCH_COUNT matches pass the ratio test
    MIN_MATCH_COUNT = 10
    ratio = 0.7
    thresh = 5.0

    # Video loading
    H, W = shape
    cap1 = cv2.VideoCapture(original_video_path)
    cap2 = cv2.VideoCapture(pred_video_path)
    try:
        declared = min(int(cap1.get(cv2.CAP_PROP_FRAME_COUNT)),
                       int(cap2.get(cv2.CAP_PROP_FRAME_COUNT)))
        declared = max(declared, 0)
        original_frames = np.zeros((declared, H, W, 3), np.uint8)
        pred_frames = np.zeros_like(original_frames)
        frame_count = 0
        for i in range(declared):
            ret1, img1 = cap1.read()
            ret2, img2 = cap2.read()
            if not ret1 or not ret2:
                break
            original_frames[i, ...] = cv2.resize(img1, (W, H))
            pred_frames[i, ...] = cv2.resize(img2, (W, H))
            frame_count += 1
    finally:
        cap1.release()
        cap2.release()
    # Keep only frames that were really decoded; the declared count can be too large.
    original_frames = original_frames[:frame_count]
    pred_frames = pred_frames[:frame_count]

    CR_seq = []
    DV_seq = []
    P_seq = []
    pixel_loss = []
    Pt = np.eye(3)

    # Features of the predicted frame are reused as "current" on the next iteration.
    pred_feats = _sift_features(sift, pred_frames[0]) if frame_count else None
    for i in range(frame_count):
        a = (original_frames[i, ...] / 255.0).astype(np.float32)
        b = (pred_frames[i, ...] / 255.0).astype(np.float32)
        pixel_loss.append(np.mean((a - b) ** 2))

        # Original -> predicted homography: cropping and distortion
        orig_feats = _sift_features(sift, original_frames[i, ...])
        M = _estimate_homography(bf, orig_feats, pred_feats, ratio, MIN_MATCH_COUNT, thresh)
        if M is not None:
            scaleRecovered = np.sqrt(M[0, 1] ** 2 + M[0, 0] ** 2)
            eigenvalues = sorted(np.abs(np.linalg.eigvals(M[0:2, 0:2])), reverse=True)
            # Skip degenerate estimates that would divide by zero
            if scaleRecovered > 0 and eigenvalues[0] > 0:
                CR_seq.append(1 / scaleRecovered)
                DV_seq.append((eigenvalues[1] / eigenvalues[0]).astype(np.float32))

        # Predicted frame i -> i+1 homography, accumulated for the stability score
        if i + 1 < frame_count:
            next_feats = _sift_features(sift, pred_frames[i + 1, ...])
            M = _estimate_homography(bf, pred_feats, next_feats, ratio, MIN_MATCH_COUNT, thresh)
            if M is not None:
                Pt = np.matmul(Pt, M)
                P_seq.append(Pt)
            pred_feats = next_feats

        sys.stdout.write('\rFrame: ' + str(i) + '/' + str(frame_count))

    # Make 1D temporal signals (translation magnitude and rotation angle in degrees)
    P_seq_t = [np.sqrt(Mp[0, 2] ** 2 + Mp[1, 2] ** 2) for Mp in P_seq]
    P_seq_r = [np.arctan2(Mp[1, 0], Mp[0, 0]) * 180 / np.pi for Mp in P_seq]

    SS_t = _low_frequency_ratio(P_seq_t)
    SS_r = _low_frequency_ratio(P_seq_r)

    nan = float('nan')
    cropping_score = np.min([np.mean(CR_seq), 1]) if CR_seq else nan
    distortion_score = np.min(DV_seq) if DV_seq else nan
    stability_score = (SS_t + SS_r) / 2
    pixel_score = 1 - np.mean(pixel_loss) if pixel_loss else nan
    out = f'\ncropping score:{cropping_score:.3f}\tdistortion score:{distortion_score:.3f}\tstability:{stability_score:.3f}\tpixel:{pixel_score:.3f}\n'
    sys.stdout.write(out)
    return cropping_score, distortion_score, stability_score, pixel_score
def save_video(frames, path, fps=30.0):
    '''
    Write a stack of frames to ``path`` as an mp4v-encoded video.

    Inputs:
        frames: array indexed as (frame, row, column, channel); each
                ``frames[idx]`` is passed to the writer unchanged
        path: destination file path
        fps: frame rate stored in the file (30.0 was previously hard-coded)
    Raises:
        IOError: if OpenCV cannot open ``path`` for writing
    '''
    frame_count, h, w, _ = frames.shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # VideoWriter takes the frame size as (width, height)
    out = cv2.VideoWriter(path, fourcc, fps, (w, h))
    if not out.isOpened():
        out.release()
        raise IOError(f"Could not open '{path}' for writing.")
    try:
        for idx in range(frame_count):
            out.write(frames[idx, ...])
    finally:
        out.release()


class StabNet(nn.Module):
    '''
    VGG19-based network that regresses grid-vertex offsets from a 15-channel input.

    The first convolution of a pretrained VGG19 is replaced by a 15-channel
    one whose kernels are the pretrained RGB kernels tiled five times. The
    encoder output is averaged over its spatial dimensions and passed to a
    fully connected regressor that emits (grid_h + 1) * (grid_w + 1) * 2 values.
    '''

    def __init__(self, trainable_layers=10):
        super().__init__()
        # Load the ImageNet-pretrained VGG19 model
        vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')
        # Pretrained first-layer kernels for RGB input: torch.Size([64, 3, 3, 3])
        rgb_weights = vgg19.features[0].weight.clone()
        # Tile the RGB kernels five times along the input-channel axis: [64, 15, 3, 3]
        tiled_rgb_weights = rgb_weights.repeat(1, 5, 1, 1)
        # Swap the first layer for a bias-free convolution taking 15 channels instead of 3
        vgg19.features[0] = nn.Conv2d(15, 64, kernel_size=3, stride=1, padding=1, bias=False)
        vgg19.features[0].weight = nn.Parameter(tiled_rgb_weights)
        # Number of parameter tensors in the model
        total_layers = sum(1 for _ in vgg19.parameters())
        # Only parameter tensors with index > total_layers - trainable_layers stay
        # trainable, i.e. the last (trainable_layers - 1) tensors.
        # NOTE(review): confirm whether the last `trainable_layers` tensors were intended.
        for idx, param in enumerate(vgg19.parameters()):
            param.requires_grad = idx > total_layers - trainable_layers
        # Encoder: the first child module of VGG19 without its final layer
        # (presumably the convolutional stack minus the last pooling layer).
        self.encoder = nn.Sequential(*list(vgg19.children())[0][:-1])
        self.regressor = nn.Sequential(nn.Linear(512, 2048),
                                       nn.ReLU(),
                                       nn.Linear(2048, 1024),
                                       nn.ReLU(),
                                       nn.Linear(1024, 512),
                                       nn.ReLU(),
                                       nn.Linear(512, ((grid_h + 1) * (grid_w + 1) * 2)))
        total_encoder_params = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)
        total_regressor_params = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)
        print("Total Trainable encoder Parameters: ", total_encoder_params)
        print("Total Trainable regressor Parameters: ", total_regressor_params)
        print("Total Trainable parameters:", total_regressor_params + total_encoder_params)

    def forward(self, x_tensor):
        '''
        Inputs:
            x_tensor: input batch with 15 channels
        Returns:
            tensor of shape (batch, grid_h + 1, grid_w + 1, 2)
        '''
        x_batch_size = x_tensor.size()[0]
        features = self.encoder(x_tensor)
        # Average over the two spatial dimensions -> one vector per sample
        pooled = torch.mean(features, dim=[2, 3])
        x = self.regressor(pooled)
        return x.view(x_batch_size, grid_h + 1, grid_w + 1, 2)
def latest_checkpoint(filenames):
    '''
    Return the newest checkpoint name in ``filenames``, or None if there is none.

    Names are expected to look like ``<prefix>_<YYYY-MM-DD>_<HH-MM-SS>.<ext>``.
    Both the date and the time fields are compared; names that do not match
    this pattern are ignored.
    '''
    stamped = []
    for name in filenames:
        parts = name.split('_')
        if len(parts) < 3:
            continue
        date_part = parts[1]
        time_part = parts[2].split('.')[0]
        try:
            stamp = datetime.datetime.strptime(date_part + '_' + time_part, "%Y-%m-%d_%H-%M-%S")
        except ValueError:
            continue
        stamped.append((stamp, name))
    if not stamped:
        return None
    return max(stamped)[1]


if __name__ == '__main__':
    args = parse_args()
    ckpt_dir = './ckpts/with_future_frames/'
    # `stabnet` is read as a module-level global by stabilize()
    stabnet = StabNet().to(device).eval()
    ckpts = os.listdir(ckpt_dir)
    newest = latest_checkpoint(ckpts)
    if newest is not None:
        latest = os.path.join(ckpt_dir, newest)
        state = torch.load(latest)
        stabnet.load_state_dict(state['model'])
        print('Loaded StabNet',latest)
    else:
        # Without a checkpoint the regressor keeps its random initialisation
        print(f"Warning: no checkpoint found in '{ckpt_dir}'; running without trained weights.")
    stabilize(args.in_path, args.out_path)
def parse_args():
    '''
    Parse the command-line options ``--in_path`` and ``--out_path``.
    '''
    parser = argparse.ArgumentParser(description='Video Stabilization using StabNet')
    parser.add_argument('--in_path', type=str, help='Input video file path')
    parser.add_argument('--out_path', type=str, help='Output stabilized video file path')
    return parser.parse_args()


def get_warp(net_out, img):
    '''
    Warp ``img`` with the grid-vertex offsets in ``net_out``.

    Inputs:
        net_out: torch.Size([batch_size, grid_h + 1, grid_w + 1, 2]) offsets, added
                 to a regular grid spanning [-1, 1] and stored in (x, y) order
        img: image batch to warp, shape (batch_size, C, H, W)
    Returns:
        the warped batch, with the same spatial size as ``img``

    The grid resolution, batch size and device are taken from ``net_out`` and
    the output size from ``img``. Rows and columns of the grid get their own
    vertex counts, so non-square grids work.
    '''
    rows, cols = net_out.shape[1:3]
    out_h, out_w = img.shape[-2:]
    grid_y, grid_x = torch.meshgrid(torch.linspace(-1, 1, rows, device=net_out.device),
                                    torch.linspace(-1, 1, cols, device=net_out.device),
                                    indexing='ij')
    # (1, rows, cols, 2); broadcasts against every sample of net_out
    src_grid = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)
    new_grid = src_grid + net_out
    # Upsample the coarse grid to one sampling location per output pixel
    grid_upscaled = F.interpolate(new_grid.permute(0, -1, 1, 2), size=(out_h, out_w),
                                  mode='bilinear', align_corners=True)
    warped = F.grid_sample(img, grid_upscaled.permute(0, 2, 3, 1),
                           align_corners=False, padding_mode='zeros')
    return warped


def save_video(frames, path, fps=30.0):
    '''
    Write a stack of frames to ``path`` as an mp4v-encoded video.

    Inputs:
        frames: array indexed as (frame, row, column, channel); each
                ``frames[idx]`` is passed to the writer unchanged
        path: destination file path
        fps: frame rate stored in the file (30.0 was previously hard-coded)
    Raises:
        IOError: if OpenCV cannot open ``path`` for writing
    '''
    frame_count, h, w, _ = frames.shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # VideoWriter takes the frame size as (width, height)
    out = cv2.VideoWriter(path, fourcc, fps, (w, h))
    if not out.isOpened():
        out.release()
        raise IOError(f"Could not open '{path}' for writing.")
    try:
        for idx in range(frame_count):
            out.write(frames[idx, ...])
    finally:
        out.release()
def latest_checkpoint(filenames):
    '''
    Return the newest checkpoint name in ``filenames``, or None if there is none.

    Names are expected to look like ``<prefix>_<YYYY-MM-DD>_<HH-MM-SS>.<ext>``.
    Both the date and the time fields are compared; names that do not match
    this pattern are ignored.
    '''
    stamped = []
    for name in filenames:
        parts = name.split('_')
        if len(parts) < 3:
            continue
        date_part = parts[1]
        time_part = parts[2].split('.')[0]
        try:
            stamp = datetime.datetime.strptime(date_part + '_' + time_part, "%Y-%m-%d_%H-%M-%S")
        except ValueError:
            continue
        stamped.append((stamp, name))
    if not stamped:
        return None
    return max(stamped)[1]


def stabilize(in_path, out_path):
    '''
    Stabilize the video at ``in_path`` frame by frame and save it to ``out_path``.

    Each frame from index 33 onwards is warped using the module-level
    ``stabnet``. The network input is the current RGB frame plus six
    channel-averaged earlier output frames taken 1, 2, 4, 8, 16 and 32 frames
    back. Earlier frames are written out unchanged. Frames are shown in an
    OpenCV window while processing; pressing space stops early.

    Inputs:
        in_path: path of the input video (.mp4 or .avi)
        out_path: path the stabilized video is written to
    Exits (via ``exit()``) when the input is missing, has an unsupported
    extension, or contains no readable frames.
    '''
    if not os.path.exists(in_path):
        print(f"The input file '{in_path}' does not exist.")
        exit()
    _, ext = os.path.splitext(in_path)
    if ext.lower() not in ['.mp4', '.avi']:
        print(f"The input file '{in_path}' is not a supported video file (only .mp4 and .avi are supported).")
        exit()

    # Load frames at the working resolution. They must be (H, W) because the
    # history buffer below is allocated with that size.
    cap = cv2.VideoCapture(in_path)
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.resize(frame, (W, H)))  # cv2.resize takes (width, height)
    finally:
        cap.release()
    if not frames:
        print(f"The input file '{in_path}' contains no readable frames.")
        exit()
    frames = np.array(frames)
    frame_count = frames.shape[0]

    # stabilize video
    frames_tensor = torch.from_numpy(frames / 255.0).permute(0, 3, 1, 2).float()
    stable_frames = frames_tensor.clone()
    buffer = torch.zeros((6, 1, H, W)).float()
    cv2.namedWindow('window', cv2.WINDOW_NORMAL)
    for idx in range(33, frame_count):
        # History is rebuilt every step from frames that are already stabilized
        for i in range(6):
            buffer[i, ...] = torch.mean(stable_frames[idx - 2 ** i, ...], dim=0, keepdim=True)
        curr = stable_frames[idx:idx + 1, ...]
        # 3 colour channels + 6 history channels
        net_in = torch.cat([curr, buffer.permute(1, 0, 2, 3)], dim=1).to(device)
        with torch.no_grad():
            transform = stabnet(net_in)
            # Only half of the predicted offset is applied
            warped = get_warp(transform * 0.5, curr.to(device))
        stable_frames[idx:idx + 1, ...] = warped.cpu()
        img = stable_frames[idx, ...].permute(1, 2, 0).numpy()
        img = (img * 255).astype(np.uint8)
        cv2.imshow('window', img)
        if cv2.waitKey(1) & 0xFF == ord(' '):
            break
    cv2.destroyAllWindows()

    # Back from [0, 1] floats to 8-bit images
    stable_frames = np.clip(stable_frames.permute(0, 2, 3, 1).numpy() * 255, 0, 255).astype(np.uint8)
    save_video(stable_frames, out_path)


if __name__ == '__main__':
    args = parse_args()
    ckpt_dir = './ckpts/original/'
    # `stabnet` is read as a module-level global by stabilize()
    stabnet = StabNet().to(device).eval()
    ckpts = os.listdir(ckpt_dir)
    newest = latest_checkpoint(ckpts)
    if newest is not None:
        latest = os.path.join(ckpt_dir, newest)
        state = torch.load(latest)
        stabnet.load_state_dict(state['model'])
        print('Loaded StabNet',latest)
    else:
        # Without a checkpoint the regressor keeps its random initialisation
        print(f"Warning: no checkpoint found in '{ckpt_dir}'; running without trained weights.")
    stabilize(args.in_path, args.out_path)
class StabNet(nn.Module):
    """VGG19 feature encoder followed by an MLP that regresses grid offsets.

    The first VGG19 convolution is replaced by a 15-channel, bias-free one whose
    kernel is the pretrained RGB kernel tiled five times along the input
    channels. ``forward`` returns a tensor of shape
    ``(batch, grid_h + 1, grid_w + 1, 2)``.
    """

    def __init__(self,trainable_layers = 10):
        super().__init__()
        backbone = torchvision.models.vgg19(weights='IMAGENET1K_V1')

        # Pretrained kernel is [64, 3, 3, 3]; tiling gives [64, 15, 3, 3].
        pretrained_kernel = backbone.features[0].weight.clone()
        first_conv = nn.Conv2d(15, 64, kernel_size=3, stride=1, padding=1, bias=False)
        first_conv.weight = nn.Parameter(pretrained_kernel.repeat(1, 5, 1, 1))
        backbone.features[0] = first_conv

        # A parameter tensor stays trainable only if its index is strictly
        # greater than (count - trainable_layers). The count covers the whole
        # VGG19, including the classifier that the encoder below discards.
        backbone_params = list(backbone.parameters())
        threshold = len(backbone_params) - trainable_layers
        for position, tensor in enumerate(backbone_params):
            tensor.requires_grad = position > threshold

        # Convolutional stack without its final module.
        self.encoder = nn.Sequential(*backbone.features[:-1])
        self.regressor = nn.Sequential(
            nn.Linear(512, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, (grid_h + 1) * (grid_w + 1) * 2),
        )

        encoder_count = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)
        regressor_count = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)
        print("Total Trainable mobilenet Parameters: ", encoder_count)
        print("Total Trainable regressor Parameters: ", regressor_count)
        print("Total Trainable parameters:", regressor_count + encoder_count)

    def forward(self, x_tensor):
        """Encode, average over the spatial dimensions, and regress the offsets."""
        batch = x_tensor.size(0)
        feature_maps = self.encoder(x_tensor)
        pooled = feature_maps.mean(dim=[2, 3])
        offsets = self.regressor(pooled)
        return offsets.view(batch, grid_h + 1, grid_w + 1, 2)
def get_warp(net_out,img):
    '''
    Warp ``img`` with the grid predicted by the network.

    Inputs:
        net_out: vertex offsets added to a regular [-1, 1] grid,
                 torch.Size([batch, grid_h + 1, grid_w + 1, 2]), last dim is (x, y)
        img:     image batch to warp, torch.Size([batch, channels, height, width])
    Returns:
        warped:  ``img`` resampled through the dense grid (same size as ``img``)
        grid:    the dense sampling grid, torch.Size([batch, height, width, 2])
    '''
    # Everything is derived from the inputs, so partial batches, non-square
    # grids and any image size work without relying on notebook globals.
    batch, rows, cols = net_out.shape[:3]
    out_h, out_w = img.shape[-2:]

    ys = torch.linspace(-1, 1, rows, device=net_out.device)
    xs = torch.linspace(-1, 1, cols, device=net_out.device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    src_grid = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0).expand(batch, -1, -1, -1)
    new_grid = src_grid + net_out

    # Upsample the coarse vertex grid to one sampling location per output pixel.
    grid_upscaled = F.interpolate(new_grid.permute(0, 3, 1, 2), size=(out_h, out_w),
                                  mode='bilinear', align_corners=True).permute(0, 2, 3, 1)
    warped = F.grid_sample(img, grid_upscaled, align_corners=True)
    return warped, grid_upscaled
def feature_loss(features, warp_field):
    """Mean distance between stable points and displaced unstable points.

    Args:
        features: torch.Size([batch, n_points, 2, 2]). Index 0 on dim 2 is the
            stable point, index 1 the matching unstable point; the last dim is
            (x, y) in normalized [-1, 1] coordinates.
        warp_field: torch.Size([batch, height, width, 2]); the value at each
            unstable point's pixel (coordinates truncated to integers) is added
            to that point.

    Returns:
        Scalar tensor: Euclidean distance averaged over all points and samples.
    """
    height, width = warp_field.shape[1:3]
    # Largest valid pixel coordinate per axis (255 for the 256x256 training size).
    scale = features.new_tensor([width - 1, height - 1])

    stable = ((features[:, :, 0, :] + 1) / 2) * scale
    unstable = ((features[:, :, 1, :] + 1) / 2) * scale
    # Clip to valid pixel coordinates so the lookup below cannot go out of range.
    stable = torch.minimum(stable.clamp(min=0), scale)
    unstable = torch.minimum(unstable.clamp(min=0), scale)

    # Pair every sample with its own field; indexing with ':' on the batch axis
    # would gather every sample's points from every sample's field.
    cols = unstable[..., 0].long()
    rows = unstable[..., 1].long()
    batch_idx = torch.arange(features.shape[0], device=features.device).unsqueeze(1)
    # NOTE(review): the training cell passes the grid returned by get_warp, which
    # holds absolute normalized sampling coordinates, yet it is added to pixel
    # coordinates here -- confirm the intended units.
    warped_unstable = unstable + warp_field[batch_idx, rows, cols]

    # torch.norm instead of sqrt(sum(...)): same value, defined gradient at 0.
    return torch.norm(stable - warped_unstable, dim=-1).mean()
,temp:4.310582160949707, shape:0.2874298393726349353" 322 | ] 323 | }, 324 | { 325 | "ename": "KeyboardInterrupt", 326 | "evalue": "", 327 | "output_type": "error", 328 | "traceback": [ 329 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 330 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 331 | "\u001b[1;32mc:\\Users\\VINY\\Desktop\\Stabnet multigrid2-DeepStab Modded - Future Frames\\Stabnet_vgg19_16x16.ipynb Cell 12\u001b[0m line \u001b[0;36m6\n\u001b[0;32m 3\u001b[0m \u001b[39mfor\u001b[39;00m epoch \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(starting_epoch,EPOCHS): \n\u001b[0;32m 4\u001b[0m \u001b[39m# Generate the data for each iteration\u001b[39;00m\n\u001b[0;32m 5\u001b[0m running_loss \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m----> 6\u001b[0m \u001b[39mfor\u001b[39;00m idx,batch \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(data_loader):\n\u001b[0;32m 7\u001b[0m start \u001b[39m=\u001b[39m time()\n\u001b[0;32m 8\u001b[0m St, St_1, Igt , flow,features \u001b[39m=\u001b[39m batch\n", 332 | "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:633\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 630\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 631\u001b[0m \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m 632\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset() \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[1;32m--> 633\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[0;32m 634\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m 
\u001b[39m1\u001b[39m\n\u001b[0;32m 635\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[0;32m 636\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[0;32m 637\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n", 333 | "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:677\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 675\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[0;32m 676\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m--> 677\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m 678\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[0;32m 679\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n", 334 | "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:32\u001b[0m, in \u001b[0;36m_IterableDatasetFetcher.fetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m 30\u001b[0m \u001b[39mfor\u001b[39;00m _ \u001b[39min\u001b[39;00m possibly_batched_index:\n\u001b[0;32m 31\u001b[0m 
\u001b[39mtry\u001b[39;00m:\n\u001b[1;32m---> 32\u001b[0m data\u001b[39m.\u001b[39mappend(\u001b[39mnext\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset_iter))\n\u001b[0;32m 33\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mStopIteration\u001b[39;00m:\n\u001b[0;32m 34\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mended \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n", 335 | "File \u001b[1;32mc:\\Users\\VINY\\Desktop\\Stabnet multigrid2-DeepStab Modded - Future Frames\\datagen.py:42\u001b[0m, in \u001b[0;36mDatagen.__call__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 40\u001b[0m u_cap\u001b[39m.\u001b[39mset(cv2\u001b[39m.\u001b[39mCAP_PROP_POS_FRAMES, frame_idx)\n\u001b[0;32m 41\u001b[0m \u001b[39mfor\u001b[39;00m i,pos \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(\u001b[39mrange\u001b[39m(frame_idx \u001b[39m-\u001b[39m SKIP, frame_idx \u001b[39m+\u001b[39m SKIP \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m, SKIP \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m \u001b[39m2\u001b[39m)):\n\u001b[1;32m---> 42\u001b[0m u_cap\u001b[39m.\u001b[39;49mset(cv2\u001b[39m.\u001b[39;49mCAP_PROP_POS_FRAMES,pos \u001b[39m-\u001b[39;49m \u001b[39m1\u001b[39;49m)\n\u001b[0;32m 43\u001b[0m _,temp \u001b[39m=\u001b[39m u_cap\u001b[39m.\u001b[39mread()\n\u001b[0;32m 44\u001b[0m temp \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpreprocess(temp)\n", 336 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: " 337 | ] 338 | } 339 | ], 340 | "source": [ 341 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n", 342 | "# Training loop\n", 343 | "for epoch in range(starting_epoch,EPOCHS): \n", 344 | " # Generate the data for each iteration\n", 345 | " running_loss = 0\n", 346 | " for idx,batch in enumerate(data_loader):\n", 347 | " start = time()\n", 348 | " St, St_1, Igt , flow,features = batch\n", 349 | " # Move the data to GPU if available\n", 350 | " St = St.to(device)\n", 351 | " St_1 = St_1.to(device)\n", 352 | " Igt = Igt.to(device)\n", 
353 | " It = St[:,6:9,...].to(device)\n", 354 | " It_1 = St_1[:,6:9,...].to(device)\n", 355 | " flow = flow.to(device)\n", 356 | " features = features.to(device)\n", 357 | " # Forward pass through the Siamese Network\n", 358 | " \n", 359 | " transform0 = stabnet(St)\n", 360 | " transform1 = stabnet(St_1)\n", 361 | "\n", 362 | " warped0,warp_field = get_warp(transform0,It)\n", 363 | " warped1,_ = get_warp(transform1,It_1)\n", 364 | " # Compute the losses\n", 365 | " #stability loss\n", 366 | " pixel_loss = 10 * F.mse_loss(warped0, Igt)\n", 367 | " feat_loss = feature_loss(features,warp_field)\n", 368 | " stability_loss = pixel_loss + feat_loss\n", 369 | " #shape_loss\n", 370 | " sh_loss = shape_loss(transform0)\n", 371 | " #temporal loss\n", 372 | " warped2 = dense_warp(warped1, flow)\n", 373 | " temp_loss = 10 * F.mse_loss(warped0,warped2)\n", 374 | " # Perform backpropagation and update the model parameters\n", 375 | " optimizer.zero_grad()\n", 376 | " total_loss = stability_loss + sh_loss + temp_loss\n", 377 | " total_loss.backward()\n", 378 | " optimizer.step()\n", 379 | "\n", 380 | " \n", 381 | " means = np.array([0.485, 0.456, 0.406],dtype = np.float32)\n", 382 | " stds = np.array([0.229, 0.224, 0.225],dtype = np.float32)\n", 383 | "\n", 384 | " img = warped0[0,...].detach().cpu().permute(1,2,0).numpy()\n", 385 | " img *= stds\n", 386 | " img += means\n", 387 | " img = np.clip(img * 255.0,0,255).astype(np.uint8)\n", 388 | "\n", 389 | " img1 = Igt[0,...].cpu().permute(1,2,0).numpy()\n", 390 | " img1 *= stds\n", 391 | " img1 += means\n", 392 | " img1 = np.clip(img1 * 255.0,0,255).astype(np.uint8)\n", 393 | " #draw features\n", 394 | " stable_features = ((features[:, :, 0, :] + 1) / 2) * torch.tensor([255, 255], dtype=torch.float).to(device)\n", 395 | " unstable_features = ((features[:, :, 1, :] + 1) / 2) * torch.tensor([255, 255], dtype=torch.float).to(device)\n", 396 | " # Clip the features to the range [0, 255]\n", 397 | " stable_features = 
torch.clamp(stable_features, min=0, max=255).cpu().numpy()\n", 398 | " unstable_features = torch.clamp(unstable_features, min=0, max=255).cpu().numpy()\n", 399 | " img = draw_features(img,unstable_features[0,...],color = (0,255,0))\n", 400 | " img1 = draw_features(img1,stable_features[0,...],color = (0,0,255))\n", 401 | " conc = cv2.hconcat([img,img1])\n", 402 | " cv2.imshow('window',conc)\n", 403 | " if cv2.waitKey(1) & 0xFF == ord('9'):\n", 404 | " break\n", 405 | " \n", 406 | " running_loss += total_loss.item()\n", 407 | " print(f\"\\rEpoch: {epoch}, Batch:{idx}, loss:{running_loss / (idx % 100 + 1)}\\\n", 408 | " pixel_loss:{pixel_loss.item()}, feature_loss:{feat_loss.item()} ,temp:{temp_loss.item()}, shape:{sh_loss.item()}\",end = \"\")\n", 409 | " if idx % 100 == 99:\n", 410 | " writer.add_scalar('training_loss',\n", 411 | " running_loss / 100, \n", 412 | " epoch * 41328 // batch_size + idx)\n", 413 | " running_loss = 0.0\n", 414 | " # Get current date and time as a formatted string\n", 415 | " current_datetime = datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n", 416 | "\n", 417 | " # Append the formatted date-time string to the model filename\n", 418 | " model_path = os.path.join(ckpt_dir, f'stabnet_{current_datetime}.pth')\n", 419 | " \n", 420 | " torch.save({'model': stabnet.state_dict(),\n", 421 | " 'optimizer' : optimizer.state_dict(),\n", 422 | " 'epoch' : epoch}\n", 423 | " ,model_path)\n", 424 | " del St, St_1, It, Igt, It_1,flow,features\n", 425 | "cv2.destroyAllWindows()" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": null, 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [] 434 | } 435 | ], 436 | "metadata": { 437 | "kernelspec": { 438 | "display_name": "DUTCode", 439 | "language": "python", 440 | "name": "python3" 441 | }, 442 | "language_info": { 443 | "codemirror_mode": { 444 | "name": "ipython", 445 | "version": 3 446 | }, 447 | "file_extension": ".py", 448 | "mimetype": "text/x-python", 449 | 
"name": "python", 450 | "nbconvert_exporter": "python", 451 | "pygments_lexer": "ipython3", 452 | "version": "3.9.16" 453 | }, 454 | "orig_nbformat": 4 455 | }, 456 | "nbformat": 4, 457 | "nbformat_minor": 2 458 | } 459 | -------------------------------------------------------------------------------- /train_vgg19_16x16_online.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import cv2\n", 11 | "import torch\n", 12 | "from time import time\n", 13 | "import os\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import torch.nn as nn\n", 16 | "import torchvision\n", 17 | "import torch.nn.functional as F\n", 18 | "from torch.utils import data\n", 19 | "from datagen import Datagen\n", 20 | "from image_utils import *\n", 21 | "from v2_93 import *\n", 22 | "import math\n", 23 | "import datetime\n", 24 | "\n", 25 | "device = 'cuda'\n", 26 | "ckpt_dir = 'E:/ModelCkpts/stabnet_multigrid_3/vgg_16x16'\n", 27 | "starting_epoch = 0\n", 28 | "H,W,C = shape = (256,256,3)\n", 29 | "batch_size = 4\n", 30 | "EPOCHS = 10\n", 31 | "grid_h, grid_w = 15,15" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "from models import transformer\n", 41 | "class StabNet(nn.Module):\n", 42 | " def __init__(self,trainable_layers = 10):\n", 43 | " super(StabNet, self).__init__()\n", 44 | " # Load the pre-trained ResNet model\n", 45 | " vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')\n", 46 | " # Extract conv1 pretrained weights for RGB input\n", 47 | " rgb_weights = vgg19.features[0].weight.clone() #torch.Size([64, 3, 3, 3])\n", 48 | " # Calculate the average across the RGB channels\n", 49 | " average_rgb_weights = torch.mean(rgb_weights, dim=1, keepdim=True).repeat(1,6,1,1) #torch.Size([64, 5, 7, 
7])\n", 50 | " # Change size of the first layer from 3 to 9 channels\n", 51 | " vgg19.features[0] = nn.Conv2d(9,64, kernel_size=3, stride=1, padding=1, bias=False)\n", 52 | " # set new weights\n", 53 | " new_weights = torch.cat((rgb_weights, average_rgb_weights), dim=1)\n", 54 | " vgg19.features[0].weight = nn.Parameter(new_weights)\n", 55 | " # Determine the total number of layers in the model\n", 56 | " total_layers = sum(1 for _ in vgg19.parameters())\n", 57 | " # Freeze the layers except the last 10\n", 58 | " for idx, param in enumerate(vgg19.parameters()):\n", 59 | " if idx > total_layers - trainable_layers:\n", 60 | " param.requires_grad = True\n", 61 | " else:\n", 62 | " param.requires_grad = False\n", 63 | " # Remove the last layer of ResNet\n", 64 | " self.encoder = nn.Sequential(*list(vgg19.children())[0][:-1])\n", 65 | " self.regressor = nn.Sequential(nn.Linear(512,2048),\n", 66 | " nn.ReLU(),\n", 67 | " nn.Linear(2048,1024),\n", 68 | " nn.ReLU(),\n", 69 | " nn.Linear(1024,512),\n", 70 | " nn.ReLU(),\n", 71 | " nn.Linear(512, ((grid_h + 1) * (grid_w + 1) * 2)))\n", 72 | " #self.regressor[-1].bias.data.fill_(0)\n", 73 | " total_resnet_params = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)\n", 74 | " total_regressor_params = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)\n", 75 | " print(\"Total Trainable mobilenet Parameters: \", total_resnet_params)\n", 76 | " print(\"Total Trainable regressor Parameters: \", total_regressor_params)\n", 77 | " print(\"Total Trainable parameters:\",total_regressor_params + total_resnet_params)\n", 78 | " \n", 79 | " def forward(self, x_tensor):\n", 80 | " x_batch_size = x_tensor.size()[0]\n", 81 | " x = x_tensor[:, :3, :, :]\n", 82 | "\n", 83 | " # summary 1, dismiss now\n", 84 | " x_tensor = self.encoder(x_tensor)\n", 85 | " x_tensor = torch.mean(x_tensor, dim=[2, 3])\n", 86 | " x = self.regressor(x_tensor)\n", 87 | " x = x.view(x_batch_size,grid_h + 1,grid_w + 1,2)\n", 88 
| " return x" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 3, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "Total Trainable mobilenet Parameters: 9439232\n", 101 | "Total Trainable regressor Parameters: 3936256\n", 102 | "Total Trainable parameters: 13375488\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "stabnet = StabNet(trainable_layers=15).to(device).train()\n", 108 | "optimizer = torch.optim.Adam(stabnet.parameters(),lr = 2e-5)\n" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 4, 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "data": { 118 | "text/plain": [ 119 | "torch.Size([1, 16, 16, 2])" 120 | ] 121 | }, 122 | "execution_count": 4, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | } 126 | ], 127 | "source": [ 128 | "out = stabnet(torch.randn(1,9,256,256).float().to(device))\n", 129 | "out.shape" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 5, 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | "Loaded weights from E:/ModelCkpts/stabnet_multigrid_3/vgg_16x16\\stabnet_2023-10-26_06-34-42.pth\n" 142 | ] 143 | } 144 | ], 145 | "source": [ 146 | "ckpts = os.listdir(ckpt_dir)\n", 147 | "if ckpts:\n", 148 | " ckpts = sorted(ckpts, key=lambda x: datetime.datetime.strptime(x.split('_')[2].split('.')[0], \"%H-%M-%S\"), reverse=True)\n", 149 | " # Get the filename of the latest checkpoint\n", 150 | " latest = os.path.join(ckpt_dir, ckpts[0])\n", 151 | " # Load the latest checkpoint\n", 152 | " state = torch.load(latest)\n", 153 | " stabnet.load_state_dict(state['model'])\n", 154 | " optimizer.load_state_dict(state['optimizer'])\n", 155 | " starting_epoch = state['epoch'] + 1\n", 156 | " print('Loaded weights from', latest)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 6, 
def shape_loss(net_out):
    """Shape-preserving regularizer on the predicted grid.

    Args:
        net_out: vertex offsets added to a regular [-1, 1] grid,
            torch.Size([batch, rows, cols, 2]), last dim is (x, y).

    Returns:
        Scalar tensor ``Lintra + 20 * Linter``. ``Lintra`` penalizes cells whose
        two edges at a corner stop being a 90-degree rotated, length-scaled copy
        of each other; ``Linter`` penalizes bending between neighbouring edges
        along rows and columns. Both are zero for an undeformed grid.
    """
    batch, rows, cols = net_out.shape[:3]
    ys = torch.linspace(-1, 1, rows, device=net_out.device)
    xs = torch.linspace(-1, 1, cols, device=net_out.device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    grid_tensor = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0).expand(batch, -1, -1, -1)
    new_grid = grid_tensor + net_out

    # Lintra: per cell, compare the edge to the vertex below with the rotated
    # edge to the vertex on the right.
    corner, corner_ref = new_grid[:, :-1, :-1, :], grid_tensor[:, :-1, :-1, :]
    right, right_ref = new_grid[:, :-1, 1:, :], grid_tensor[:, :-1, 1:, :]
    below, below_ref = new_grid[:, 1:, :-1, :], grid_tensor[:, 1:, :-1, :]
    down_edge = below - corner
    right_edge = right - corner
    # Edge-length ratio of the undeformed cell.
    s = torch.norm(below_ref - corner_ref, dim=-1) / torch.norm(right_ref - corner_ref, dim=-1)
    # Rotation (x, y) -> (-y, x) maps the +x edge onto the +y edge, so an
    # undeformed (or merely rotated/scaled) cell contributes nothing.
    rotated = s[..., None] * torch.stack([-right_edge[..., 1], right_edge[..., 0]], dim=-1)
    # Normalized by the vertex count, as before, to keep the term's scale.
    Lintra = torch.norm(down_edge - rotated, dim=-1).sum() / (rows * cols * batch)

    # Linter: consecutive edges along a column / row should be equal.
    up, mid_v, down = new_grid[:, :-2, :, :], new_grid[:, 1:-1, :, :], new_grid[:, 2:, :, :]
    left, mid_h, right_v = new_grid[:, :, :-2, :], new_grid[:, :, 1:-1, :], new_grid[:, :, 2:, :]
    Linter_vertical = torch.mean(torch.norm((up - mid_v) - (mid_v - down), dim=-1))
    Linter_horizontal = torch.mean(torch.norm((left - mid_h) - (mid_h - right_v), dim=-1))
    Linter = Linter_vertical + Linter_horizontal

    return Lintra + 20 * Linter
def draw_features(image,features,color):
    """Return a copy of ``image`` with a filled circle drawn at every feature.

    Args:
        image: array accepted by ``cv2.circle``; it is copied, not modified.
        features: iterable of ``(x, y)`` pairs; coordinates are truncated to int.
        color: colour tuple passed to ``cv2.circle`` for every marker.
    """
    drawn = image.copy()
    for x, y in features:
        # Honour the caller's colour so different point sets can be told apart.
        cv2.circle(drawn, (int(x), int(y)), 5, color, -1)
    return drawn
"output_type": "stream", 320 | "text": [ 321 | "Epoch: 1, Batch:21, loss:5.301943264224312 pixel_loss:0.03195870295166969, feature_loss:7.603750228881836 ,temp:0.19699248671531677, shape:0.25114700198173523593" 322 | ] 323 | }, 324 | { 325 | "ename": "KeyboardInterrupt", 326 | "evalue": "", 327 | "output_type": "error", 328 | "traceback": [ 329 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 330 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 331 | "\u001b[1;32mc:\\Users\\VINY\\Desktop\\Stabnet multigrid2-DeepStab Modded\\Stabnet_vgg19_16x16.ipynb Cell 12\u001b[0m line \u001b[0;36m8\n\u001b[0;32m 6\u001b[0m \u001b[39m# Generate the data for each iteration\u001b[39;00m\n\u001b[0;32m 7\u001b[0m running_loss \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m----> 8\u001b[0m \u001b[39mfor\u001b[39;00m idx,batch \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(data_loader):\n\u001b[0;32m 9\u001b[0m start \u001b[39m=\u001b[39m time()\n\u001b[0;32m 10\u001b[0m St, St_1, It, Igt, It_1,flow,features \u001b[39m=\u001b[39m batch\n", 332 | "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:633\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 630\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 631\u001b[0m \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m 632\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset() \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[1;32m--> 633\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[0;32m 634\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m 
\u001b[39m1\u001b[39m\n\u001b[0;32m 635\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[0;32m 636\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[0;32m 637\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n", 333 | "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:677\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 675\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[0;32m 676\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m--> 677\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m 678\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[0;32m 679\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n", 334 | "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:32\u001b[0m, in \u001b[0;36m_IterableDatasetFetcher.fetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m 30\u001b[0m \u001b[39mfor\u001b[39;00m _ \u001b[39min\u001b[39;00m possibly_batched_index:\n\u001b[0;32m 31\u001b[0m 
# Training script: optimises `stabnet` on batches from `data_loader`, shows a
# live preview of the first sample, logs the running loss to TensorBoard and
# writes one checkpoint per epoch.
cv2.namedWindow('window', cv2.WINDOW_NORMAL)
try:
    for epoch in range(starting_epoch, EPOCHS):
        # Halve the learning rate at the start of every epoch (this also
        # applies to the first epoch that is run).
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.5

        running_loss = 0
        for idx, batch in enumerate(data_loader):
            # Each batch carries exactly seven tensors; move all to `device`.
            St, St_1, It, Igt, It_1, flow, features = (t.to(device) for t in batch)

            # Same network applied to the current and the previous step.
            transform0 = stabnet(torch.cat([It, St], dim=1))
            transform1 = stabnet(torch.cat([It_1, St_1], dim=1))

            warped0, warp_field = get_warp(transform0, It)
            warped1, _ = get_warp(transform1, It_1)

            # Stability loss: weighted pixel MSE plus feature alignment.
            pixel_loss = F.mse_loss(warped0, Igt)
            feat_loss = feature_loss(features, warp_field)
            stability_loss = 50 * pixel_loss + feat_loss
            # Shape loss on the predicted transform.
            sh_loss = shape_loss(transform0)
            # Temporal loss: previous output warped by `flow` vs current output.
            warped2 = dense_warp(warped1, flow)
            temp_loss = 10 * F.mse_loss(warped0, warped2)

            # Backpropagation and parameter update.
            optimizer.zero_grad()
            total_loss = stability_loss + sh_loss + temp_loss
            total_loss.backward()
            optimizer.step()

            # Preview of the first sample: network output (left) next to the
            # ground truth (right), with feature points drawn on top.
            img = warped0[0, ...].detach().cpu().permute(1, 2, 0).numpy()
            img = (img * 255).astype(np.uint8)
            img1 = Igt[0, ...].cpu().permute(1, 2, 0).numpy()
            img1 = (img1 * 255).astype(np.uint8)
            # Only sample 0 is displayed, so only its features are converted.
            points = torch.clamp((features[0] + 1) / 2 * 255, min=0, max=255).cpu().numpy()
            stable_points = points[:, 0, :]
            unstable_points = points[:, 1, :]
            img = draw_features(img, unstable_points, color=(0, 255, 0))
            img1 = draw_features(img1, stable_points, color=(0, 0, 255))
            cv2.imshow('window', cv2.hconcat([img, img1]))
            # NOTE(review): pressing '9' only leaves the batch loop; the epoch
            # is still checkpointed and the next epoch starts. Confirm intent.
            if cv2.waitKey(1) & 0xFF == ord('9'):
                break

            running_loss += total_loss.item()
            # running_loss is reset every 100 batches, hence the modulo.
            print(
                f"\rEpoch: {epoch}, Batch:{idx}, loss:{running_loss / (idx % 100 + 1)} "
                f"pixel_loss:{pixel_loss.item()}, feature_loss:{feat_loss.item()} ,"
                f"temp:{temp_loss.item()}, shape:{sh_loss.item()}",
                end="",
            )
            if idx % 100 == 99:
                # NOTE(review): 41328 is presumably the number of training
                # samples per epoch -- confirm against the dataset.
                writer.add_scalar('training_loss',
                                  running_loss / 100,
                                  epoch * 41328 // batch_size + idx)
                running_loss = 0.0

        # Timestamped checkpoint at the end of every epoch.
        current_datetime = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        model_path = os.path.join(ckpt_dir, f'stabnet_{current_datetime}.pth')
        torch.save({'model': stabnet.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch},
                   model_path)
        # Drop references to the last batch before the next epoch starts.
        del St, St_1, It, Igt, It_1, flow, features
finally:
    # Always close the preview window, including on KeyboardInterrupt.
    cv2.destroyAllWindows()
--------------------------------------------------------------------------------