├── Flows_dataset_raft.ipynb
├── README.md
├── create_txt.ipynb
├── datagen.py
├── image_utils.py
├── inference_future_frames.ipynb
├── inference_online.ipynb
├── matched_features_dataset.ipynb
├── metrics.py
├── requirements.txt
├── stabilize_future_frames.py
├── stabilize_online.py
├── train_vgg19_16x16_future_frames.ipynb
├── train_vgg19_16x16_online.ipynb
└── trainlist.txt
/Flows_dataset_raft.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch.nn as nn\n",
11 | "import torch.nn.functional as F\n",
12 | "import warnings\n",
13 | "import numpy as np\n",
14 | "import os\n",
15 | "import cv2\n",
16 | "import tqdm\n",
17 | "warnings.filterwarnings(\"ignore\")\n",
18 | "device = 'cuda'\n",
19 | "shape = (H,W,C) = (256,256,3)"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 7,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "from torchvision import models\n",
29 | "raft = models.optical_flow.raft_small(weights = 'Raft_Small_Weights.C_T_V2').eval().to(device)\n",
30 | "\n",
31 | "def get_flow(img1, img2):\n",
32 | " img1_t = torch.from_numpy(img1/255.0).permute(-1,0,1).unsqueeze(0).float().to(device)\n",
33 | " img2_t = torch.from_numpy(img2/255.0).permute(-1,0,1).unsqueeze(0).float().to(device)\n",
34 | " flow = raft(img1_t,img2_t)[-1].detach().cpu().permute(0,2,3,1).squeeze(0).numpy()\n",
35 | " return flow\n",
36 | "\n",
37 | "def show_flow(flow):\n",
38 | " hsv_mask = np.zeros(shape= flow.shape[:-1] +(3,),dtype = np.uint8)\n",
39 | " hsv_mask[...,1] = 255\n",
40 | " mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1],angleInDegrees=True)\n",
41 | " hsv_mask[:,:,0] = ang /2 \n",
42 | " hsv_mask[:,:,2] = cv2.normalize(mag,None,0,255,cv2.NORM_MINMAX)\n",
43 | " rgb = cv2.cvtColor(hsv_mask,cv2.COLOR_HSV2RGB)\n",
44 | " return(rgb)"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 8,
50 | "metadata": {},
51 | "outputs": [
52 | {
53 | "name": "stderr",
54 | "output_type": "stream",
55 | "text": [
56 | " 9%|▊ | 2/23 [00:25<04:22, 12.52s/it]\n"
57 | ]
58 | },
59 | {
60 | "ename": "KeyboardInterrupt",
61 | "evalue": "",
62 | "output_type": "error",
63 | "traceback": [
64 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
65 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
66 | "Cell \u001b[1;32mIn[8], line 11\u001b[0m\n\u001b[0;32m 9\u001b[0m flows \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m 10\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[1;32m---> 11\u001b[0m ret,curr \u001b[38;5;241m=\u001b[39m \u001b[43mcap\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m ret: \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[0;32m 13\u001b[0m curr \u001b[38;5;241m=\u001b[39m cv2\u001b[38;5;241m.\u001b[39mresize(curr,(W,H))\n",
67 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
68 | ]
69 | }
70 | ],
71 | "source": [
72 | "videos_path = 'E:/Datasets/DeepStab_Dataset/stable/'\n",
73 | "flows_path = 'E:/Datasets/Flows/'\n",
74 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
75 | "videos = os.listdir(videos_path)\n",
76 | "for video in tqdm.tqdm(videos):\n",
77 | " cap = cv2.VideoCapture(os.path.join(videos_path, video))\n",
78 | " ret,prev = cap.read()\n",
79 | " prev = cv2.resize(prev,(W,H))\n",
80 | " flows = []\n",
81 | " while True:\n",
82 | " ret,curr = cap.read()\n",
83 | " if not ret: break\n",
84 | " curr = cv2.resize(curr,(W,H))\n",
85 | " flow = get_flow(prev,curr)\n",
86 | " flows.append(flow)\n",
87 | " prev = curr\n",
88 | " cv2.imshow('window',show_flow(flow))\n",
89 | " if cv2.waitKey(1) & 0xFF == ord('9'):\n",
90 | " break\n",
91 | " flows = np.array(flows).astype(np.float32)\n",
92 | " output_path = os.path.join(flows_path,video.split('.')[0] + '.npy')\n",
93 | " np.save(output_path,flows)\n",
94 | "cv2.destroyAllWindows()"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": []
103 | }
104 | ],
105 | "metadata": {
106 | "kernelspec": {
107 | "display_name": "DUTCode",
108 | "language": "python",
109 | "name": "python3"
110 | },
111 | "language_info": {
112 | "codemirror_mode": {
113 | "name": "ipython",
114 | "version": 3
115 | },
116 | "file_extension": ".py",
117 | "mimetype": "text/x-python",
118 | "name": "python",
119 | "nbconvert_exporter": "python",
120 | "pygments_lexer": "ipython3",
121 | "version": "3.9.18"
122 | }
123 | },
124 | "nbformat": 4,
125 | "nbformat_minor": 2
126 | }
127 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Online Video Stabilization With Multi-Grid Warping Transformation Learning
2 |
3 | This is a PyTorch implementation of the [paper](https://cg.cs.tsinghua.edu.cn/papers/TIP-2019-VideoStab.pdf).
4 |
5 |
6 |
7 | I provide the original online algorithm described in the paper and a second implementation that uses a buffer of future frames. The latter can no longer be categorized as an online algorithm, but it achieves better stabilization results.
8 |
9 | ## Inference Instructions
10 |
11 | Follow these instructions to perform video stabilization using the pretrained model:
12 |
13 | 1. **Download the pretrained models:**
14 |    - Download the pretrained model [weights](https://drive.google.com/drive/folders/1K8HfenNEr_0Joi6RdX4SfKVnCg-GjhvW?usp=drive_link).
15 |    - Place the downloaded weights folder in the main folder of your project; the expected checkpoint layout is sketched after these steps.
16 |
17 | 2. **Run the Stabilization Script:**
18 | - For the original model run:
19 | ```bash
20 | python stabilize_online.py --in_path unstable_video_path --out_path result_path
21 | ```
22 | - Replace `unstable_video_path` with the path to your input unstable video.
23 | - Replace `result_path` with the desired path for the stabilized output video.
24 | - For the second model with future frames:
25 | ```bash
26 | python stabilize_future_frames.py --in_path unstable_video_path --out_path result_path
27 | ```
28 |
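The inference code loads the most recent checkpoint (by the timestamp in its filename) from `./ckpts/original/` for the online model and `./ckpts/with_future_frames/` for the future-frame model (see `stabilize_future_frames.py` and `inference_online.ipynb`). Assuming the downloaded weights keep those folder names, the project root should look roughly like this sketch (the checkpoint filenames are only examples):

```
StabNet/
├── ckpts/
│   ├── original/
│   │   └── stabnet_2023-10-26_13-42-14.pth
│   └── with_future_frames/
│       └── stabnet_2023-11-02_23-03-39.pth
├── stabilize_online.py
└── stabilize_future_frames.py
```
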
29 | Make sure the necessary dependencies are installed and your environment is set up correctly before running the stabilization scripts:
30 | ```bash
31 | pip install numpy opencv-python torch==2.1.2 matplotlib
32 | ```
33 |
34 |
35 |
36 | ## Training Instructions
37 |
38 | Follow these instructions to train the model:
39 |
40 | 1. **Download Datasets:**
41 | - Download the training dataset: [DeepStab](https://cg.cs.tsinghua.edu.cn/people/~miao/stabnet/data.zip).
42 | - Extract the contents of the downloaded dataset to a location on your machine.
43 |
44 | 2. **Create Datasets for Loss Functions:**
45 |    - Create the optical flow and matched feature datasets used in the loss functions described in the paper:
46 | - [Flows_dataset_raft.ipynb](https://github.com/btxviny/StabNet/blob/main/Flows_dataset_raft.ipynb) for optical flow dataset.
47 | - [matched_features_dataset.ipynb](https://github.com/btxviny/StabNet/blob/main/matched_features_dataset.ipynb) for matched feature dataset.
48 |    - Create a `trainlist.txt` containing the file paths for each training sample, using [create_txt.ipynb](https://github.com/btxviny/StabNet/blob/main/create_txt.ipynb) (adjust the paths as needed); an example line is shown after these steps.
49 |
50 | 3. **Training Notebooks:**
51 | - Online version: [train_vgg19_16x16_online.ipynb](https://github.com/btxviny/StabNet/blob/main/train_vgg19_16x16_online.ipynb)
52 |    - Future frame version: [train_vgg19_16x16_future_frames.ipynb](https://github.com/btxviny/StabNet/blob/main/train_vgg19_16x16_future_frames.ipynb)
53 |    - Make sure to set `ckpt_dir` to the directory where you want the model checkpoints to be saved.
54 |
55 | 4. **Metrics Calculation:**
56 |    - Use `metrics.py` to compute the cropping, distortion, stability, and pixel scores for the generated results; a usage sketch follows these steps.
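
As noted in step 2, each line of `trainlist.txt` written by [create_txt.ipynb](https://github.com/btxviny/StabNet/blob/main/create_txt.ipynb) holds the stable video path, the matching unstable video path, the precomputed optical-flow and matched-feature `.npy` paths, and a frame index, separated by commas. The paths below are only illustrative:

```
E:/Datasets/DeepStab Modded/Stable_60/1.avi,E:/Datasets/DeepStab Modded/Unstable/1.avi,E:/Datasets/DeepStab Modded/Flows_256x256/1.npy,E:/Datasets/DeepStab Modded/matched_features/1.npy,33
```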
57 |
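A minimal usage sketch (the video paths are illustrative; `metric` returns the cropping, distortion, stability, and pixel scores in that order):

```python
from metrics import metric

# compare the original (unstable) input video against the stabilized result
cropping, distortion, stability, pixel = metric('unstable.avi', 'stabilized.avi')
print(f'cropping={cropping:.3f} distortion={distortion:.3f} stability={stability:.3f} pixel={pixel:.3f}')
```
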
--------------------------------------------------------------------------------
/create_txt.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 6,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import cv2\n",
11 | "import os \n",
12 | "import random\n",
13 | "import torch\n",
14 | "from torchvision import transforms"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 9,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "stable_prefix = 'E:/Datasets/DeepStab Modded/Stable_60/'\n",
24 | "unstable_prefix = 'E:/Datasets/DeepStab Modded/Unstable/'\n",
25 | "flow_prefix = 'E:/Datasets/DeepStab Modded/Flows_256x256'\n",
26 | "feature_prefix = 'E:/Datasets/DeepStab Modded/matched_features'\n",
27 | "videos = os.listdir(feature_prefix)\n",
28 | "with open('./trainlist.txt','w') as f:\n",
29 | " for vid in videos:\n",
30 | " video_name = vid.split('.')[0] + '.avi'\n",
31 | " s_path = os.path.join(stable_prefix, video_name)\n",
32 | " u_path = os.path.join(unstable_prefix, video_name)\n",
33 | " flow_path = os.path.join(flow_prefix,vid)\n",
34 | " feature_path = os.path.join(feature_prefix,vid)\n",
35 | " s_cap = cv2.VideoCapture(s_path)\n",
36 | " u_cap = cv2.VideoCapture(u_path)\n",
37 | " num_frames = min(int(s_cap.get(cv2.CAP_PROP_FRAME_COUNT)),int(u_cap.get(cv2.CAP_PROP_FRAME_COUNT)))\n",
38 | " for idx in range(33,num_frames):\n",
39 | " line = f'{s_path},{u_path},{flow_path},{feature_path},{idx}\\n'\n",
40 | " f.write(line)\n"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": []
49 | }
50 | ],
51 | "metadata": {
52 | "kernelspec": {
53 | "display_name": "DUTCode",
54 | "language": "python",
55 | "name": "python3"
56 | },
57 | "language_info": {
58 | "codemirror_mode": {
59 | "name": "ipython",
60 | "version": 3
61 | },
62 | "file_extension": ".py",
63 | "mimetype": "text/x-python",
64 | "name": "python",
65 | "nbconvert_exporter": "python",
66 | "pygments_lexer": "ipython3",
67 | "version": "3.9.16"
68 | },
69 | "orig_nbformat": 4
70 | },
71 | "nbformat": 4,
72 | "nbformat_minor": 2
73 | }
74 |
--------------------------------------------------------------------------------
/datagen.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 | import random
5 | import torch
6 |
7 | class Datagen:
8 | def __init__(self,shape = (256,256),txt_path = './trainlist.txt'):
9 | self.shape = shape
10 | with open(txt_path,'r') as f:
11 | self.trainlist = f.read().splitlines()
12 | def preprocess(self,img,gray = True):
13 | h,w = self.shape
14 | if gray:
15 | img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
16 | img = cv2.resize(img,(w,h))
17 | img = img / 255.0
18 | return img
19 | def __call__(self):
20 | self.trainlist = random.sample(self.trainlist, len(self.trainlist))
21 | for sample in self.trainlist:
22 | s_path = sample.split(',')[0]
23 | u_path = sample.split(',')[1]
24 | flow_path = sample.split(',')[2]
25 | feature_path = sample.split(',')[3]
26 | idx = int(sample.split(',')[4])
27 | s_cap = cv2.VideoCapture(s_path)
28 | u_cap = cv2.VideoCapture(u_path)
29 | seq1 = []
30 | seq2 = []
31 | s_cap.set(cv2.CAP_PROP_POS_FRAMES, idx - 33)
32 | for i in range(5,-1,-1):
33 | pos = 2 ** i + 1
34 | s_cap.set(cv2.CAP_PROP_POS_FRAMES, idx - pos)
35 | _,temp1 = s_cap.read()
36 | temp1 = self.preprocess(temp1) # -33
37 | temp1 = random_translation(temp1)
38 | seq2.append(temp1)
39 | _,temp2 = s_cap.read()
40 | temp2 = self.preprocess(temp2) # -32
41 | temp2 = random_translation(temp2)
42 | seq1.append(temp2)
43 | seq1 = np.array(seq1,dtype=np.uint8)
44 | seq2 = np.array(seq2,dtype=np.uint8)
45 | _,Igt = s_cap.read()
46 | Igt = self.preprocess(Igt, gray= False)
47 | u_cap.set(cv2.CAP_PROP_POS_FRAMES, idx - 1)
48 | _,It_prev = u_cap.read()
49 | It_prev = self.preprocess(It_prev, gray= False)
50 | _,It_curr = u_cap.read()
51 | It_curr = self.preprocess(It_curr, gray= False)
52 | flow = np.load(flow_path,mmap_mode='r')
53 | flo = torch.from_numpy(flow[idx - 1,...].copy()).permute(-1,0,1).float()
54 | features = np.load(feature_path,mmap_mode='r')
55 | features = torch.from_numpy(features[idx,...].copy()).float()
56 | seq1 = np.flip(seq1,axis = 0)
57 | seq1 = torch.from_numpy(seq1.copy()).float()
58 | seq2 = np.flip(seq2,axis = 0)
59 | seq2 = torch.from_numpy(seq2.copy()).float()
60 | Igt = torch.from_numpy(Igt).permute(-1,0,1).float()
61 | It_prev = torch.from_numpy(It_prev).permute(-1,0,1).float()
62 | It_curr = torch.from_numpy(It_curr).permute(-1,0,1).float()
63 |
64 | yield seq1, seq2, It_curr, Igt, It_prev, flo, features
65 |
66 | def random_translation(img):
67 | img = np.array(img)
68 | (h,w) = img.shape
69 | dx = np.random.randint(-w//8,w//8)
70 | dy = np.random.randint(-h//8,h//8)
71 | mat = np.array([[1,0,dx],[0,1,dy]],dtype=np.float32)
72 | img = cv2.warpAffine(img, mat, (w,h))
73 | return img
--------------------------------------------------------------------------------
/image_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | def dense_warp(image, flow):
7 | """
8 | Densely warps an image using optical flow.
9 |
10 | Args:
11 | image (torch.Tensor): Input image tensor of shape (batch_size, channels, height, width).
12 | flow (torch.Tensor): Optical flow tensor of shape (batch_size, 2, height, width).
13 |
14 | Returns:
15 | torch.Tensor: Warped image tensor of shape (batch_size, channels, height, width).
16 | """
17 | batch_size, channels, height, width = image.size()
18 |
19 | # Generate a grid of pixel coordinates based on the optical flow
20 | grid_y, grid_x = torch.meshgrid(torch.arange(height), torch.arange(width),indexing='ij')
21 | grid = torch.stack((grid_x, grid_y), dim=-1).to(image.device)
22 | grid = grid.unsqueeze(0).expand(batch_size, -1, -1, -1)
23 | new_grid = grid + flow.permute(0, 2, 3, 1)
24 |
25 | # Normalize the grid coordinates between -1 and 1
26 | new_grid /= torch.tensor([width - 1, height - 1], dtype=torch.float32, device=image.device)
27 | new_grid = new_grid * 2 - 1
28 | # Perform the dense warp using grid_sample
29 | warped_image = F.grid_sample(image, new_grid, align_corners=False)
30 |
31 | return warped_image
32 |
33 |
34 | def find_homography_numpy(src_points, dst_points):
35 | A = []
36 | B = []
37 | for src, dst in zip(src_points, dst_points):
38 | x, y = src
39 | x_prime, y_prime = dst
40 | A.append([-x, -y, -1, 0, 0, 0, x * x_prime, y * x_prime, x_prime])
41 | A.append([0, 0, 0, -x, -y, -1, x * y_prime, y * y_prime, y_prime])
42 | B.extend([-x_prime, -y_prime])
43 | A = np.array(A)
44 | B = np.array(B)
45 | ATA = np.dot(A.T, A)
46 | eigenvalues, eigenvectors = np.linalg.eigh(ATA)
47 | min_eigenvalue_index = np.argmin(eigenvalues)
48 | homography_vector = eigenvectors[:, min_eigenvalue_index]
49 | homography_vector /= homography_vector[-1]
50 | homography_matrix = np.reshape(homography_vector,(3, 3))
51 |
52 | return homography_matrix
53 |
54 |
55 |
56 | def warp(img, mat):
57 | device = img.device
58 | mat = torch.cat([mat,torch.ones((mat.size(0),1),device = device)], axis = -1).view(-1,3,3)
59 | batch_size, channels, height, width = img.shape
60 | cy, cx = height // 2, width // 2
61 |
62 | # Compute the translation matrix to shift the center to the origin
63 | translation_matrix1 = torch.tensor([[1, 0, -cx],
64 | [0, 1, -cy],
65 | [0, 0, 1]], dtype=torch.float32, device=device)
66 | translation_matrix1 = translation_matrix1.repeat(batch_size, 1, 1)
67 |
68 | # Compute the translation matrix to shift the origin back to the center
69 | translation_matrix2 = torch.tensor([[1, 0, cx],
70 | [0, 1, cy],
71 | [0, 0, 1]], dtype=torch.float32, device=device)
72 | translation_matrix2 = translation_matrix2.repeat(batch_size, 1, 1)
73 | transformation_matrix = torch.matmul(translation_matrix2, torch.matmul(mat, translation_matrix1))
74 |
75 | # Compute the grid coordinates
76 | y_coords, x_coords = torch.meshgrid(torch.arange(height, device=device), torch.arange(width, device=device), indexing='ij')
77 | coords = torch.stack([x_coords, y_coords, torch.ones_like(x_coords)], dim=-1).float()
78 | coords = coords.view(1, -1, 3).repeat(batch_size, 1, 1)
79 |
80 | # Apply the transformation matrix to the grid coordinates
81 | transformed_coords = torch.matmul(coords, transformation_matrix.transpose(1, 2))
82 |
83 | # Normalize the transformed coordinates
84 | x_transformed = transformed_coords[:, :, 0] / transformed_coords[:, :, 2]
85 | y_transformed = transformed_coords[:, :, 1] / transformed_coords[:, :, 2]
86 |
87 | # Reshape the transformed coordinates to match the image size
88 | x_transformed = x_transformed.view(batch_size, height, width)
89 | y_transformed = y_transformed.view(batch_size, height, width)
90 |
91 | # Normalize the grid coordinates to the range [-1, 1]
92 | x_normalized = (x_transformed / (width - 1)) * 2 - 1
93 | y_normalized = (y_transformed / (height - 1)) * 2 - 1
94 |
95 | # Perform bilinear interpolation using grid_sample
96 | grid = torch.stack([x_normalized, y_normalized], dim=-1)
97 | warped_image = torch.nn.functional.grid_sample(img, grid, mode='bilinear', align_corners=False ,padding_mode='zeros')
98 |
99 | return warped_image
100 |
101 |
102 | def find_homography(src_points, dst_points):
103 | device = src_points.device
104 | A = []
105 | B = []
106 | # Convert input lists to PyTorch tensors
107 | src_points = torch.tensor(src_points, dtype=torch.float32)
108 | dst_points = torch.tensor(dst_points, dtype=torch.float32)
109 | for src, dst in zip(src_points, dst_points):
110 | x, y = src
111 | x_prime, y_prime = dst
112 | A.append([-x, -y, -1, 0, 0, 0, x * x_prime, y * x_prime, x_prime])
113 | A.append([0, 0, 0, -x, -y, -1, x * y_prime, y * y_prime, y_prime])
114 | B.extend([-x_prime, -y_prime])
115 | A = torch.tensor(A, dtype=torch.float32)
116 | B = torch.tensor(B, dtype=torch.float32)
117 | # Calculate ATA matrix
118 | ATA = torch.matmul(A.T, A)
119 | # Eigenvalue decomposition
120 | eigenvalues, eigenvectors = torch.linalg.eigh(ATA)
121 | # Find the index of the smallest eigenvalue
122 | min_eigenvalue_index = torch.argmin(eigenvalues)
123 | # Extract the corresponding eigenvector
124 | homography_vector = eigenvectors[:, min_eigenvalue_index]
125 | # Normalize homography vector
126 | homography_vector = homography_vector / homography_vector[-1]
127 | # Reshape to obtain the homography matrix
128 | homography_matrix = homography_vector.view(3, 3)
129 | return homography_matrix.to(device)
130 |
131 |
132 | def findHomography(grids, new_grids_loc):
133 | """
134 | @param: grids the location of origin grid vertices [2, H, W]
135 | @param: new_grids_loc the location of desired grid vertices [2, H, W]
136 |
137 | @return: homo_t homograph projection matrix for each grid [3, 3, H-1, W-1]
138 | """
139 |
140 | _, H, W = grids.shape
141 |
142 | new_grids = new_grids_loc.unsqueeze(0)
143 |
144 | Homo = torch.zeros(1, 3, 3, H-1, W-1).to(grids.device)
145 |
146 | grids = grids.unsqueeze(0)
147 |
148 | try:
149 | # for common cases if all the homograph can be calculated
150 | one = torch.ones_like(grids[:, 0:1, :-1, :-1], device=grids.device)
151 | zero = torch.zeros_like(grids[:, 1:2, :-1, :-1], device=grids.device)
152 |
153 | A = torch.cat([
154 | torch.stack([grids[:, 0:1, :-1, :-1], grids[:, 1:2, :-1, :-1], one, zero, zero, zero,
155 | -1 * grids[:, 0:1, :-1, :-1] * new_grids[:, 0:1, :-1, :-1], -1 * grids[:, 1:2, :-1, :-1] * new_grids[:, 0:1, :-1, :-1]], 2), # 1, 1, 8, h-1, w-1
156 | torch.stack([grids[:, 0:1, 1:, :-1], grids[:, 1:2, 1:, :-1], one, zero, zero, zero,
157 | -1 * grids[:, 0:1, 1:, :-1] * new_grids[:, 0:1, 1:, :-1], -1 * grids[:, 1:2, 1:, :-1] * new_grids[:, 0:1, 1:, :-1]], 2),
158 | torch.stack([grids[:, 0:1, :-1, 1:], grids[:, 1:2, :-1, 1:], one, zero, zero, zero,
159 | -1 * grids[:, 0:1, :-1, 1:] * new_grids[:, 0:1, :-1, 1:], -1 * grids[:, 1:2, :-1, 1:] * new_grids[:, 0:1, :-1, 1:]], 2),
160 | torch.stack([grids[:, 0:1, 1:, 1:], grids[:, 1:2, 1:, 1:], one, zero, zero, zero,
161 | -1 * grids[:, 0:1, 1:, 1:] * new_grids[:, 0:1, 1:, 1:], -1 * grids[:, 1:2, 1:, 1:] * new_grids[:, 0:1, 1:, 1:]], 2),
162 | torch.stack([zero, zero, zero, grids[:, 0:1, :-1, :-1], grids[:, 1:2, :-1, :-1], one,
163 | -1 * grids[:, 0:1, :-1, :-1] * new_grids[:, 1:2, :-1, :-1], -1 * grids[:, 1:2, :-1, :-1] * new_grids[:, 1:2, :-1, :-1]], 2),
164 | torch.stack([zero, zero, zero, grids[:, 0:1, 1:, :-1], grids[:, 1:2, 1:, :-1], one,
165 | -1 * grids[:, 0:1, 1:, :-1] * new_grids[:, 1:2, 1:, :-1], -1 * grids[:, 1:2, 1:, :-1] * new_grids[:, 1:2, 1:, :-1]], 2),
166 | torch.stack([zero, zero, zero, grids[:, 0:1, :-1, 1:], grids[:, 1:2, :-1, 1:], one,
167 | -1 * grids[:, 0:1, :-1, 1:] * new_grids[:, 1:2, :-1, 1:], -1 * grids[:, 1:2, :-1, 1:] * new_grids[:, 1:2, :-1, 1:]], 2),
168 | torch.stack([zero, zero, zero, grids[:, 0:1, 1:, 1:], grids[:, 1:2, 1:, 1:], one,
169 | -1 * grids[:, 0:1, 1:, 1:] * new_grids[:, 1:2, 1:, 1:], -1 * grids[:, 1:2, 1:, 1:] * new_grids[:, 1:2, 1:, 1:]], 2),
170 | ], 1).view(8, 8, -1).permute(2, 0, 1) # 1, 8, 8, h-1, w-1
171 | B_ = torch.stack([
172 | new_grids[:, 0, :-1, :-1],
173 | new_grids[:, 0, 1:, :-1],
174 | new_grids[:, 0, :-1, 1:],
175 | new_grids[:, 0, 1:, 1:],
176 | new_grids[:, 1, :-1, :-1],
177 | new_grids[:, 1, 1:, :-1],
178 | new_grids[:, 1, :-1, 1:],
179 | new_grids[:, 1, 1:, 1:],
180 | ], 1).view(8, -1).permute(1, 0) # B, 8, h-1, w-1 ==> A @ H = B ==> H = A^-1 @ B
181 | A_inverse = torch.inverse(A)
182 | # B, 8, 8 @ B, 8, 1 --> B, 8, 1
183 | H_recovered = torch.bmm(A_inverse, B_.unsqueeze(2))
184 |
185 | H_ = torch.cat([H_recovered, torch.ones_like(
186 | H_recovered[:, 0:1, :], device=H_recovered.device)], 1).view(H_recovered.shape[0], 3, 3)
187 |
188 | H_ = H_.permute(1, 2, 0)
189 | H_ = H_.view(Homo.shape)
190 | Homo = H_
191 | except:
192 | # if some of the homography can not be calculated
193 | one = torch.ones_like(grids[:, 0:1, 0, 0], device=grids.device)
194 | zero = torch.zeros_like(grids[:, 1:2, 0, 0], device=grids.device)
195 | H_ = torch.eye(3, device=grids.device)
196 | for i in range(H - 1):
197 | for j in range(W - 1):
198 | A = torch.cat([
199 | torch.stack([grids[:, 0:1, i, j], grids[:, 1:2, i, j], one, zero, zero, zero,
200 | -1 * grids[:, 0:1, i, j] * new_grids[:, 0:1, i, j], -1 * grids[:, 1:2, i, j] * new_grids[:, 0:1, i, j]], 2),
201 | torch.stack([grids[:, 0:1, i+1, j], grids[:, 1:2, i+1, j], one, zero, zero, zero,
202 | -1 * grids[:, 0:1, i+1, j] * new_grids[:, 0:1, i+1, j], -1 * grids[:, 1:2, i+1, j] * new_grids[:, 0:1, i+1, j]], 2),
203 | torch.stack([grids[:, 0:1, i, j+1], grids[:, 1:2, i, j+1], one, zero, zero, zero,
204 | -1 * grids[:, 0:1, i, j+1] * new_grids[:, 0:1, i, j+1], -1 * grids[:, 1:2, i, j+1] * new_grids[:, 0:1, i, j+1]], 2),
205 | torch.stack([grids[:, 0:1, i+1, j+1], grids[:, 1:2, i+1, j+1], one, zero, zero, zero,
206 | -1 * grids[:, 0:1, i+1, j+1] * new_grids[:, 0:1, i+1, j+1], -1 * grids[:, 1:2, i+1, j+1] * new_grids[:, 0:1, i+1, j+1]], 2),
207 | torch.stack([zero, zero, zero, grids[:, 0:1, i, j], grids[:, 1:2, i, j], one,
208 | -1 * grids[:, 0:1, i, j] * new_grids[:, 1:2, i, j], -1 * grids[:, 1:2, i, j] * new_grids[:, 1:2, i, j]], 2),
209 | torch.stack([zero, zero, zero, grids[:, 0:1, i+1, j], grids[:, 1:2, i+1, j], one,
210 | -1 * grids[:, 0:1, i+1, j] * new_grids[:, 1:2, i+1, j], -1 * grids[:, 1:2, i+1, j] * new_grids[:, 1:2, i+1, j]], 2),
211 | torch.stack([zero, zero, zero, grids[:, 0:1, i, j+1], grids[:, 1:2, i, j+1], one,
212 | -1 * grids[:, 0:1, i, j+1] * new_grids[:, 1:2, i, j+1], -1 * grids[:, 1:2, i, j+1] * new_grids[:, 1:2, i, j+1]], 2),
213 | torch.stack([zero, zero, zero, grids[:, 0:1, i+1, j+1], grids[:, 1:2, i+1, j+1], one,
214 | -1 * grids[:, 0:1, i+1, j+1] * new_grids[:, 1:2, i+1, j+1], -1 * grids[:, 1:2, i+1, j+1] * new_grids[:, 1:2, i+1, j+1]], 2),
215 | ], 1) # B, 8, 8
216 | B_ = torch.stack([
217 | new_grids[:, 0, i, j],
218 | new_grids[:, 0, i+1, j],
219 | new_grids[:, 0, i, j+1],
220 | new_grids[:, 0, i+1, j+1],
221 | new_grids[:, 1, i, j],
222 | new_grids[:, 1, i+1, j],
223 | new_grids[:, 1, i, j+1],
224 | new_grids[:, 1, i+1, j+1],
225 | ], 1) # B, 8 ==> A @ H = B ==> H = A^-1 @ B
226 | try:
227 | A_inverse = torch.inverse(A)
228 |
229 | # B, 8, 8 @ B, 8, 1 --> B, 8, 1
230 | H_recovered = torch.bmm(A_inverse, B_.unsqueeze(2))
231 |
232 | H_ = torch.cat([H_recovered, torch.ones_like(H_recovered[:, 0:1, :]).to(
233 | H_recovered.device)], 1).view(H_recovered.shape[0], 3, 3)
234 | except:
235 | pass
236 | Homo[:, :, :, i, j] = H_
237 |
238 | homo_t = Homo.view(3, 3, H-1, W-1)
239 |
240 | return homo_t
--------------------------------------------------------------------------------
/inference_future_frames.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import cv2\n",
11 | "import torch\n",
12 | "from time import time\n",
13 | "import os\n",
14 | "import datetime\n",
15 | "import matplotlib.pyplot as plt\n",
16 | "import torch.nn as nn\n",
17 | "import torchvision\n",
18 | "import torch.nn.functional as F\n",
19 | "device = 'cuda'\n",
20 | "batch_size = 1\n",
21 | "grid_h,grid_w = 15,15\n",
22 | "H,W = height,width = 360,640"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "def get_warp(net_out,img):\n",
32 | " '''\n",
33 | " Inputs:\n",
34 | " net_out: torch.Size([batch_size,grid_h +1 ,grid_w +1,2])\n",
35 | " img: image to warp\n",
36 | " '''\n",
37 | " grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1, grid_h + 1),\n",
38 | " torch.linspace(-1,1, grid_h + 1),\n",
39 | " indexing='ij')\n",
40 | " src_grid = torch.stack([grid_x,grid_y],dim = -1).unsqueeze(0).repeat(batch_size,1,1,1).to(device)\n",
41 | " new_grid = src_grid + 1 * net_out\n",
42 | " grid_upscaled = F.interpolate(new_grid.permute(0,-1,1,2),size = (height,width), mode = 'bilinear',align_corners= True)\n",
43 | " warped = F.grid_sample(img, grid_upscaled.permute(0,2,3,1),align_corners=True)\n",
44 | " return warped"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 3,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "class StabNet(nn.Module):\n",
54 | " def __init__(self,trainable_layers = 10):\n",
55 | " super(StabNet, self).__init__()\n",
56 | " # Load the pre-trained VGG19 backbone\n",
57 | " vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')\n",
58 | " # Extract conv1 pretrained weights for RGB input\n",
59 | " rgb_weights = vgg19.features[0].weight.clone() #torch.Size([64, 3, 3, 3])\n",
60 | " # Tile the pretrained RGB weights to cover the 15-channel input\n",
61 | " tiled_rgb_weights = rgb_weights.repeat(1,5,1,1) \n",
62 | " # Change the size of the first layer from 3 to 15 channels\n",
63 | " vgg19.features[0] = nn.Conv2d(15,64, kernel_size=3, stride=1, padding=1, bias=False)\n",
64 | " # set new weights\n",
65 | " vgg19.features[0].weight = nn.Parameter(tiled_rgb_weights)\n",
66 | " # Determine the total number of layers in the model\n",
67 | " total_layers = sum(1 for _ in vgg19.parameters())\n",
68 | " # Freeze the layers except the last 10\n",
69 | " for idx, param in enumerate(vgg19.parameters()):\n",
70 | " if idx > total_layers - trainable_layers:\n",
71 | " param.requires_grad = True\n",
72 | " else:\n",
73 | " param.requires_grad = False\n",
74 | " # Use the VGG19 feature extractor without its final max-pooling layer\n",
75 | " self.encoder = nn.Sequential(*list(vgg19.children())[0][:-1])\n",
76 | " self.regressor = nn.Sequential(nn.Linear(512,2048),\n",
77 | " nn.ReLU(),\n",
78 | " nn.Linear(2048,1024),\n",
79 | " nn.ReLU(),\n",
80 | " nn.Linear(1024,512),\n",
81 | " nn.ReLU(),\n",
82 | " nn.Linear(512, ((grid_h + 1) * (grid_w + 1) * 2)))\n",
83 | " #self.regressor[-1].bias.data.fill_(0)\n",
84 | " total_resnet_params = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)\n",
85 | " total_regressor_params = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)\n",
86 | " print(\"Total Trainable encoder Parameters: \", total_resnet_params)\n",
87 | " print(\"Total Trainable regressor Parameters: \", total_regressor_params)\n",
88 | " print(\"Total Trainable parameters:\",total_regressor_params + total_resnet_params)\n",
89 | " \n",
90 | " def forward(self, x_tensor):\n",
91 | " x_batch_size = x_tensor.size()[0]\n",
92 | " x = x_tensor[:, :3, :, :]\n",
93 | "\n",
94 | " # summary 1, dismiss now\n",
95 | " x_tensor = self.encoder(x_tensor)\n",
96 | " x_tensor = torch.mean(x_tensor, dim=[2, 3])\n",
97 | " x = self.regressor(x_tensor)\n",
98 | " x = x.view(x_batch_size,grid_h + 1,grid_w + 1,2)\n",
99 | " return x"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 4,
105 | "metadata": {},
106 | "outputs": [
107 | {
108 | "name": "stdout",
109 | "output_type": "stream",
110 | "text": [
111 | "Total Trainable encoder Parameters: 2360320\n",
112 | "Total Trainable regressor Parameters: 3936256\n",
113 | "Total Trainable parameters: 6296576\n",
114 | "loaded weights ./ckpts/with_future_frames/stabnet_2023-11-02_23-03-39.pth\n"
115 | ]
116 | }
117 | ],
118 | "source": [
119 | "ckpt_dir = './ckpts/with_future_frames/'\n",
120 | "stabnet = StabNet().to(device).eval()\n",
121 | "ckpts = os.listdir(ckpt_dir)\n",
122 | "if ckpts:\n",
123 | " ckpts = sorted(ckpts, key=lambda x: datetime.datetime.strptime(x.split('_')[2].split('.')[0], \"%H-%M-%S\"), reverse=True)\n",
124 | " \n",
125 | " # Get the filename of the latest checkpoint\n",
126 | " latest = os.path.join(ckpt_dir, ckpts[0])\n",
127 | "\n",
128 | " state = torch.load(latest)\n",
129 | " stabnet.load_state_dict(state['model'])\n",
130 | " print('loaded weights',latest)"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 5,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "path = 'E:/Datasets/DeepStab_Dataset/unstable/2.avi'\n",
140 | "cap = cv2.VideoCapture(path)\n",
141 | "mean = np.array([0.485, 0.456, 0.406],dtype = np.float32) \n",
142 | "std = np.array([0.229, 0.224, 0.225],dtype = np.float32)\n",
143 | "frames = []\n",
144 | "while True:\n",
145 | " ret, img = cap.read()\n",
146 | " if not ret: break\n",
147 | " img = cv2.resize(img, (W,H))\n",
148 | " img = (img / 255.0).astype(np.float32)\n",
149 | " img = (img - mean)/std\n",
150 | " frames.append(img)\n",
151 | "frames = np.array(frames,dtype = np.float32)\n",
152 | "frame_count,_,_,_ = frames.shape"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 6,
158 | "metadata": {},
159 | "outputs": [
160 | {
161 | "name": "stdout",
162 | "output_type": "stream",
163 | "text": [
164 | "speed: 0.007669420583669505 seconds per frame\n"
165 | ]
166 | }
167 | ],
168 | "source": [
169 | "frames_tensor = torch.from_numpy(frames).permute(0,3,1,2).float().to('cpu')\n",
170 | "stable_frames_tensor = frames_tensor.clone()\n",
171 | "\n",
172 | "SKIP = 32\n",
173 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
174 | "def get_batch(idx):\n",
175 | " batch = torch.zeros((5,3,H,W)).float()\n",
176 | " for i,j in enumerate(range(idx - SKIP, idx + SKIP + 1, SKIP//2)):\n",
177 | " batch[i,...] = frames_tensor[j,...]\n",
178 | " batch = batch.view(1,-1,H,W)\n",
179 | " return batch.to(device)\n",
180 | "start = time()\n",
181 | "for frame_idx in range(SKIP,frame_count - SKIP):\n",
182 | " batch = get_batch(frame_idx)\n",
183 | " with torch.no_grad():\n",
184 | " transform = stabnet(batch)\n",
185 | " warped = get_warp(transform, frames_tensor[frame_idx: frame_idx + 1,...].cuda())\n",
186 | " stable_frames_tensor[frame_idx] = warped\n",
187 | " img = warped.permute(0,2,3,1)[0,...].cpu().detach().numpy()\n",
188 | " img *= std\n",
189 | " img += mean\n",
190 | " img = np.clip(img * 255.0,0,255).astype(np.uint8)\n",
191 | " cv2.imshow('window', img)\n",
192 | " if cv2.waitKey(1) & 0xFF == ord('q'):\n",
193 | " break\n",
194 | "cv2.destroyAllWindows()\n",
195 | "total = time() - start\n",
196 | "speed = total / frame_count\n",
197 | "print(f'speed: {speed} seconds per frame')"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": 18,
203 | "metadata": {},
204 | "outputs": [],
205 | "source": [
206 | "stable_frames = np.clip(((stable_frames_tensor.permute(0,2,3,1).numpy() * std) + mean) * 255,0,255).astype(np.uint8)\n",
207 | "frames = np.clip(((frames_tensor.permute(0,2,3,1).numpy() * std) + mean) * 255,0,255).astype(np.uint8)"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": 19,
213 | "metadata": {},
214 | "outputs": [],
215 | "source": [
216 | "from time import sleep\n",
217 | "fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n",
218 | "out = cv2.VideoWriter('2.avi', fourcc, 30.0, (W,H))\n",
219 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
220 | "for idx in range(frame_count):\n",
221 | " img = stable_frames[idx,...]\n",
222 | " out.write(img)\n",
223 | " cv2.imshow('window',img)\n",
224 | " #sleep(1/30)\n",
225 | " if cv2.waitKey(1) & 0xFF == ord(' '):\n",
226 | " break\n",
227 | "out.release()\n",
228 | "cv2.destroyAllWindows()"
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": 20,
234 | "metadata": {},
235 | "outputs": [
236 | {
237 | "name": "stdout",
238 | "output_type": "stream",
239 | "text": [
240 | "Frame: 446/447\n",
241 | "cropping score:0.996\tdistortion score:0.982\tstability:0.666\tpixel:0.997\n"
242 | ]
243 | },
244 | {
245 | "data": {
246 | "text/plain": [
247 | "(0.9962942926265075, 0.9820633, 0.66624937150459, 0.9968942715786397)"
248 | ]
249 | },
250 | "execution_count": 20,
251 | "metadata": {},
252 | "output_type": "execute_result"
253 | }
254 | ],
255 | "source": [
256 | "from metrics import metric\n",
257 | "metric('E:/Datasets/DeepStab_Dataset/unstable/2.avi','2.avi')"
258 | ]
259 | }
260 | ],
261 | "metadata": {
262 | "kernelspec": {
263 | "display_name": "DUTCode",
264 | "language": "python",
265 | "name": "python3"
266 | },
267 | "language_info": {
268 | "codemirror_mode": {
269 | "name": "ipython",
270 | "version": 3
271 | },
272 | "file_extension": ".py",
273 | "mimetype": "text/x-python",
274 | "name": "python",
275 | "nbconvert_exporter": "python",
276 | "pygments_lexer": "ipython3",
277 | "version": "3.9.18"
278 | },
279 | "orig_nbformat": 4
280 | },
281 | "nbformat": 4,
282 | "nbformat_minor": 2
283 | }
284 |
--------------------------------------------------------------------------------
/inference_online.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 4,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import cv2\n",
11 | "import torch\n",
12 | "from time import time\n",
13 | "import os\n",
14 | "import datetime\n",
15 | "import pandas as pd\n",
16 | "import matplotlib.pyplot as plt\n",
17 | "import torch.nn as nn\n",
18 | "import torchvision\n",
19 | "import torch.nn.functional as F\n",
20 | "from torch.utils import data\n",
21 | "from image_utils import dense_warp, warp\n",
22 | "device = 'cuda'\n",
23 | "height,width = 360,640\n",
24 | "batch_size = 1\n",
25 | "grid_h,grid_w = 15,15"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 5,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "def get_warp(net_out,img):\n",
35 | " '''\n",
36 | " Inputs:\n",
37 | " net_out: torch.Size([batch_size,grid_h +1 ,grid_w +1,2])\n",
38 | " img: image to warp\n",
39 | " '''\n",
40 | " grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1, grid_h + 1),\n",
41 | " torch.linspace(-1,1, grid_h + 1),\n",
42 | " indexing='ij')\n",
43 | " src_grid = torch.stack([grid_x,grid_y],dim = -1).unsqueeze(0).repeat(batch_size,1,1,1).to(device)\n",
44 | " new_grid = src_grid + net_out\n",
45 | " grid_upscaled = F.interpolate(new_grid.permute(0,-1,1,2),size = (height,width), mode = 'bilinear',align_corners= True)\n",
46 | " warped = F.grid_sample(img, grid_upscaled.permute(0,2,3,1),align_corners=False,padding_mode='zeros')\n",
47 | " return warped"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 3,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "class StabNet(nn.Module):\n",
57 | " def __init__(self,trainable_layers = 10):\n",
58 | " super(StabNet, self).__init__()\n",
59 | " # Load the pre-trained VGG19 backbone\n",
60 | " vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')\n",
61 | " # Extract conv1 pretrained weights for RGB input\n",
62 | " rgb_weights = vgg19.features[0].weight.clone() #torch.Size([64, 3, 3, 3])\n",
63 | " # Calculate the average across the RGB channels\n",
64 | " average_rgb_weights = torch.mean(rgb_weights, dim=1, keepdim=True).repeat(1,6,1,1) #torch.Size([64, 6, 3, 3])\n",
65 | " # Change size of the first layer from 3 to 9 channels\n",
66 | " vgg19.features[0] = nn.Conv2d(9,64, kernel_size=3, stride=1, padding=1, bias=False)\n",
67 | " # set new weights\n",
68 | " new_weights = torch.cat((rgb_weights, average_rgb_weights), dim=1)\n",
69 | " vgg19.features[0].weight = nn.Parameter(new_weights)\n",
70 | " # Determine the total number of layers in the model\n",
71 | " total_layers = sum(1 for _ in vgg19.parameters())\n",
72 | " # Freeze the layers except the last 10\n",
73 | " for idx, param in enumerate(vgg19.parameters()):\n",
74 | " if idx > total_layers - trainable_layers:\n",
75 | " param.requires_grad = True\n",
76 | " else:\n",
77 | " param.requires_grad = False\n",
78 | " # Use the VGG19 feature extractor without its final max-pooling layer\n",
79 | " self.encoder = nn.Sequential(*list(vgg19.children())[0][:-1])\n",
80 | " self.regressor = nn.Sequential(nn.Linear(512,2048),\n",
81 | " nn.ReLU(),\n",
82 | " nn.Linear(2048,1024),\n",
83 | " nn.ReLU(),\n",
84 | " nn.Linear(1024,512),\n",
85 | " nn.ReLU(),\n",
86 | " nn.Linear(512, ((grid_h + 1) * (grid_w + 1) * 2)))\n",
87 | " total_resnet_params = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)\n",
88 | " total_regressor_params = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)\n",
89 | " print(\"Total Trainable encoder Parameters: \", total_resnet_params)\n",
90 | " print(\"Total Trainable regressor Parameters: \", total_regressor_params)\n",
91 | " print(\"Total Trainable parameters:\",total_regressor_params + total_resnet_params)\n",
92 | " \n",
93 | " def forward(self, x_tensor):\n",
94 | " x_batch_size = x_tensor.size()[0]\n",
95 | " x = x_tensor[:, :3, :, :]\n",
96 | "\n",
97 | " # summary 1, dismiss now\n",
98 | " x_tensor = self.encoder(x_tensor)\n",
99 | " x_tensor = torch.mean(x_tensor, dim=[2, 3])\n",
100 | " x = self.regressor(x_tensor)\n",
101 | " x = x.view(x_batch_size,grid_h + 1,grid_w + 1,2)\n",
102 | " return x"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 10,
108 | "metadata": {},
109 | "outputs": [
110 | {
111 | "name": "stdout",
112 | "output_type": "stream",
113 | "text": [
114 | "Total Trainable encoder Parameters: 2360320\n",
115 | "Total Trainable regressor Parameters: 3936256\n",
116 | "Total Trainable parameters: 6296576\n",
117 | "loaded weights ./ckpts/original/stabnet_2023-10-26_13-42-14.pth\n"
118 | ]
119 | }
120 | ],
121 | "source": [
122 | "ckpt_dir = './ckpts/original/'\n",
123 | "stabnet = StabNet().to(device).eval()\n",
124 | "ckpts = os.listdir(ckpt_dir)\n",
125 | "if ckpts:\n",
126 | " ckpts = sorted(ckpts, key=lambda x: datetime.datetime.strptime(x.split('_')[2].split('.')[0], \"%H-%M-%S\"), reverse=True)\n",
127 | " \n",
128 | " # Get the filename of the latest checkpoint\n",
129 | " latest = os.path.join(ckpt_dir, ckpts[0])\n",
130 | "\n",
131 | " state = torch.load(latest)\n",
132 | " stabnet.load_state_dict(state['model'])\n",
133 | " print('loaded weights',latest)"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": 7,
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "path = 'E:/Datasets/DeepStab_Dataset/unstable/2.avi'\n",
143 | "cap = cv2.VideoCapture(path)\n",
144 | "frames = []\n",
145 | "while True:\n",
146 | " ret,frame = cap.read()\n",
147 | " if not ret : break\n",
148 | " frame = cv2.resize(frame,(width,height))\n",
149 | " frames.append(frame)\n",
150 | "frames = np.array(frames)"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": 8,
156 | "metadata": {},
157 | "outputs": [
158 | {
159 | "data": {
160 | "text/plain": [
161 | "torch.Size([447, 3, 360, 640])"
162 | ]
163 | },
164 | "execution_count": 8,
165 | "metadata": {},
166 | "output_type": "execute_result"
167 | }
168 | ],
169 | "source": [
170 | "frames_t = torch.from_numpy(frames/255.0).permute(0,3,1,2).float()\n",
171 | "frames_t.shape"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 9,
177 | "metadata": {},
178 | "outputs": [
179 | {
180 | "ename": "KeyboardInterrupt",
181 | "evalue": "",
182 | "output_type": "error",
183 | "traceback": [
184 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
185 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
186 | "Cell \u001b[1;32mIn[9], line 14\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m 13\u001b[0m trasnform \u001b[38;5;241m=\u001b[39m stabnet(net_in)\n\u001b[1;32m---> 14\u001b[0m warped \u001b[38;5;241m=\u001b[39m get_warp(trasnform \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m1\u001b[39m ,\u001b[43mcurr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[0;32m 15\u001b[0m warped_frames[idx:idx\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m] \u001b[38;5;241m=\u001b[39m warped\u001b[38;5;241m.\u001b[39mcpu()\n\u001b[0;32m 16\u001b[0m warped_gray \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mmean(warped,dim \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m,keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
187 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
188 | ]
189 | }
190 | ],
191 | "source": [
192 | "num_frames,_,h,w = frames_t.shape\n",
193 | "warped_frames = frames_t.clone()\n",
194 | "buffer = torch.zeros((6,1,h,w)).float()\n",
195 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
196 | "start = time()\n",
197 | "for iter in range(1):\n",
198 | " for idx in range(33,num_frames):\n",
199 | " for i in range(6):\n",
200 | " buffer[i,...] = torch.mean(warped_frames[idx - 2**i,...],dim = 0,keepdim = True)\n",
201 | " curr = warped_frames[idx:idx+1,...] \n",
202 | " net_in = torch.cat([curr,buffer.permute(1,0,2,3)], dim = 1).to(device)\n",
203 | " with torch.no_grad():\n",
204 | " trasnform = stabnet(net_in)\n",
205 | " warped = get_warp(trasnform * 1 ,curr.to(device))\n",
206 | " warped_frames[idx:idx+1,...] = warped.cpu()\n",
207 | " warped_gray = torch.mean(warped,dim = 1,keepdim=True)\n",
208 | " buffer = torch.roll(buffer, shifts= 1, dims=1)\n",
209 | " buffer[:,:1,:,:] = warped_gray\n",
210 | " img = warped_frames[idx,...].permute(1,2,0).numpy()\n",
211 | " img = (img * 255).astype(np.uint8)\n",
212 | " cv2.imshow('window',img)\n",
213 | " if cv2.waitKey(1) & 0xFF == ord(' '):\n",
214 | " break\n",
215 | "cv2.destroyAllWindows()\n",
216 | "total = time() - start\n",
217 | "speed = total / num_frames\n",
218 | "print(f'speed: {speed} seconds per frame') "
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": 17,
224 | "metadata": {},
225 | "outputs": [],
226 | "source": [
227 | "from time import sleep\n",
228 | "fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n",
229 | "out = cv2.VideoWriter('./results/2.avi', fourcc, 30.0, (256,256))\n",
230 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
231 | "for idx in range(num_frames):\n",
232 | " img = warped_frames[idx,...].permute(1,2,0).numpy()\n",
233 | " img = (img * 255).astype(np.uint8)\n",
234 | " diff = cv2.absdiff(img,frames[idx,...])\n",
235 | " out.write(img)\n",
236 | " cv2.imshow('window',img)\n",
237 | " sleep(1/30)\n",
238 | " if cv2.waitKey(1) & 0xFF == ord(' '):\n",
239 | " break\n",
240 | "cv2.destroyAllWindows()\n",
241 | "out.release()"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": null,
247 | "metadata": {},
248 | "outputs": [],
249 | "source": [
250 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
251 | "for idx in range(num_frames):\n",
252 | " img = warped_frames[idx,...].permute(1,2,0).numpy()\n",
253 | " img = (img * 255).astype(np.uint8)\n",
254 | " diff = cv2.absdiff(img,frames[idx,...])\n",
255 | " cv2.imshow('window',diff)\n",
256 | " sleep(1/30)\n",
257 | " if cv2.waitKey(1) & 0xFF == ord(' '):\n",
258 | " break\n",
259 | "cv2.destroyAllWindows()"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": 18,
265 | "metadata": {},
266 | "outputs": [
267 | {
268 | "name": "stdout",
269 | "output_type": "stream",
270 | "text": [
271 | "Frame: 446/447\n",
272 | "cropping score:1.000\tdistortion score:0.989\tstability:0.639\tpixel:0.997\n"
273 | ]
274 | },
275 | {
276 | "data": {
277 | "text/plain": [
278 | "(1.0, 0.98863894, 0.6389072784994596, 0.99749401723966)"
279 | ]
280 | },
281 | "execution_count": 18,
282 | "metadata": {},
283 | "output_type": "execute_result"
284 | }
285 | ],
286 | "source": [
287 | "from metrics import metric\n",
288 | "metric('E:/Datasets/DeepStab_Dataset/unstable/2.avi','./results/Regular_2.avi')"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": null,
294 | "metadata": {},
295 | "outputs": [],
296 | "source": []
297 | }
298 | ],
299 | "metadata": {
300 | "kernelspec": {
301 | "display_name": "DUTCode",
302 | "language": "python",
303 | "name": "python3"
304 | },
305 | "language_info": {
306 | "codemirror_mode": {
307 | "name": "ipython",
308 | "version": 3
309 | },
310 | "file_extension": ".py",
311 | "mimetype": "text/x-python",
312 | "name": "python",
313 | "nbconvert_exporter": "python",
314 | "pygments_lexer": "ipython3",
315 | "version": "3.9.18"
316 | },
317 | "orig_nbformat": 4
318 | },
319 | "nbformat": 4,
320 | "nbformat_minor": 2
321 | }
322 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import sys
4 |
5 | def metric(original_video_path, pred_video_path,shape=(128,128)):
6 | ''' Inputs:
7 | original_video_path: path to the original (unstable) input video
8 | pred_video_path: path to the generated stabilized video
9 | Outputs:
10 | cropping_score
11 | distortion_score
12 | stability_score
13 | pixel_score
14 | '''
15 | # Create brute-force matcher object
16 | sys.stdout.flush()
17 | bf = cv2.BFMatcher()
18 | sift = cv2.SIFT_create()
19 |
20 | # Apply the homography transformation if we have enough good matches
21 | MIN_MATCH_COUNT = 10
22 | ratio = 0.7
23 | thresh = 5.0
24 |
25 | CR_seq = []
26 | DV_seq = []
27 | Pt = np.eye(3)
28 | P_seq = []
29 | pixel_loss = []
30 |
31 | # Video loading
32 | H,W = shape
33 | #load both videos
34 | cap1 = cv2.VideoCapture(original_video_path)
35 | cap2 = cv2.VideoCapture(pred_video_path)
36 | frame_count = min(int(cap1.get(cv2.CAP_PROP_FRAME_COUNT)),int(cap2.get(cv2.CAP_PROP_FRAME_COUNT)))
37 | original_frames = np.zeros((frame_count,H,W,3),np.uint8)
38 | pred_frames = np.zeros_like(original_frames)
39 | for i in range(frame_count):
40 | ret1,img1 = cap1.read()
41 | ret2,img2 = cap2.read()
42 | if not ret1 or not ret2:
43 | break
44 | img1 = cv2.resize(img1,(W,H))
45 | img2 = cv2.resize(img2,(W,H))
46 | original_frames[i,...] = img1
47 | pred_frames[i,...] = img2
48 |
49 | for i in range(frame_count):
50 | img1 = original_frames[i,...]
51 | img1o = pred_frames[i,...]
52 |
53 | # Convert frames to grayscale
54 | a = (img1 / 255.0).astype(np.float32)
55 | b = (img1o/ 255.0).astype(np.float32)
56 | img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
57 | img1o = cv2.cvtColor(img1o, cv2.COLOR_BGR2GRAY)
58 | pixel_loss.append(np.mean((a-b)**2))
59 | # Detect the SIFT key points and compute the descriptors for the two images
60 | keyPoints1, descriptors1 = sift.detectAndCompute(img1, None)
61 | keyPoints1o, descriptors1o = sift.detectAndCompute(img1o, None)
62 |
63 | # Match the descriptors
64 | matches = bf.knnMatch(descriptors1, descriptors1o, k=2)
65 |
66 | # Select the good matches using the ratio test
67 | goodMatches = []
68 |
69 | for m, n in matches:
70 | if m.distance < ratio * n.distance:
71 | goodMatches.append(m)
72 |
73 | if len(goodMatches) > MIN_MATCH_COUNT:
74 | # Get the good key points positions
75 | sourcePoints = np.float32([keyPoints1[m.queryIdx].pt for m in goodMatches]).reshape(-1, 1, 2)
76 | destinationPoints = np.float32([keyPoints1o[m.trainIdx].pt for m in goodMatches]).reshape(-1, 1, 2)
77 |
78 | # Obtain the homography matrix
79 | M, _ = cv2.findHomography(sourcePoints, destinationPoints, method=cv2.RANSAC, ransacReprojThreshold=thresh)
80 |
81 | # Obtain Scale, Translation, Rotation, Distortion value
82 | scaleRecovered = np.sqrt(M[0, 1] ** 2 + M[0, 0] ** 2)
83 |
84 | eigenvalues = np.abs(np.linalg.eigvals(M[0:2, 0:2]))
85 | eigenvalues = sorted(eigenvalues,reverse= True)
86 | DV = (eigenvalues[1] / eigenvalues[0]).astype(np.float32)
87 |
88 | CR_seq.append(1 / scaleRecovered)
89 | DV_seq.append(DV)
90 |
91 | # For Stability score calculation
92 | if i + 1 < frame_count:
93 | img2o = pred_frames[i+1,...]
94 | # Convert frame to grayscale
95 | img2o = cv2.cvtColor(img2o, cv2.COLOR_BGR2GRAY)
96 |
97 | keyPoints2o, descriptors2o = sift.detectAndCompute(img2o, None)
98 | matches = bf.knnMatch(descriptors1o, descriptors2o, k=2)
99 | goodMatches = []
100 |
101 | for m, n in matches:
102 | if m.distance < ratio * n.distance:
103 | goodMatches.append(m)
104 |
105 | if len(goodMatches) > MIN_MATCH_COUNT:
106 | # Get the good key points positions
107 | sourcePoints = np.float32([keyPoints1o[m.queryIdx].pt for m in goodMatches]).reshape(-1, 1, 2)
108 | destinationPoints = np.float32([keyPoints2o[m.trainIdx].pt for m in goodMatches]).reshape(-1, 1, 2)
109 |
110 | # Obtain the homography matrix
111 | M, _ = cv2.findHomography(sourcePoints, destinationPoints, method=cv2.RANSAC, ransacReprojThreshold=thresh)
112 |
113 | P_seq.append(np.matmul(Pt, M))
114 | Pt = np.matmul(Pt, M)
115 |
116 | sys.stdout.write('\rFrame: ' + str(i) + '/' + str(frame_count))
117 |
118 | cap1.release()
119 | cap2.release()
120 |
121 | # Make 1D temporal signals
122 | P_seq_t = []
123 | P_seq_r = []
124 |
125 | for Mp in P_seq:
126 | transRecovered = np.sqrt(Mp[0, 2] ** 2 + Mp[1, 2] ** 2)
127 | thetaRecovered = np.arctan2(Mp[1, 0], Mp[0, 0]) * 180 / np.pi
128 | P_seq_t.append(transRecovered)
129 | P_seq_r.append(thetaRecovered)
130 |
131 | # FFT
132 | fft_t = np.fft.fft(P_seq_t)
133 | fft_r = np.fft.fft(P_seq_r)
134 | fft_t = np.abs(fft_t) ** 2
135 | fft_r = np.abs(fft_r) ** 2
136 |
137 | fft_t = np.delete(fft_t, 0)
138 | fft_r = np.delete(fft_r, 0)
139 | fft_t = fft_t[:len(fft_t) // 2]
140 | fft_r = fft_r[:len(fft_r) // 2]
141 |
142 | SS_t = np.sum(fft_t[:5]) / np.sum(fft_t)
143 | SS_r = np.sum(fft_r[:5]) / np.sum(fft_r)
144 |
145 | cropping_score = np.min([np.mean(CR_seq), 1])
146 | distortion_score = np.min(DV_seq)
147 | stability_score = (SS_t+SS_r)/2
148 | pixel_score = 1 - np.mean(pixel_loss)
149 | out = f'\ncropping score:{cropping_score:.3f}\tdistortion score:{distortion_score:.3f}\tstability:{stability_score:.3f}\tpixel:{pixel_score:.3f}\n'
150 | sys.stdout.write(out)
151 | return cropping_score, distortion_score, stability_score, pixel_score
152 |
153 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.1
2 | torchvision==0.14.1
3 | opencv-python==4.10.0.84
4 | matplotlib==3.9.2
5 | pandas==2.2.3
6 | numpy==1.26.4
7 | scikit-learn==1.5.2
8 | scipy==1.13.1
--------------------------------------------------------------------------------
/stabilize_future_frames.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import argparse
4 | from time import time
5 | import os
6 | import datetime
7 | import matplotlib.pyplot as plt
8 | import torch
9 | import torch.nn as nn
10 | import torchvision
11 | import torch.nn.functional as F
12 |
13 | device = 'cuda'
14 | batch_size = 1
15 | grid_h,grid_w = 15,15
16 | H,W = height,width = 360,640
17 |
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser(description='Video Stabilization using StabNet')
21 | parser.add_argument('--in_path', type=str, help='Input video file path')
22 | parser.add_argument('--out_path', type=str, help='Output stabilized video file path')
23 | return parser.parse_args()
24 |
25 | def get_warp(net_out,img):
26 | '''
27 | Inputs:
28 | net_out: torch.Size([batch_size,grid_h +1 ,grid_w +1,2])
29 | img: image to warp
30 | '''
31 | grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1, grid_h + 1),
32 | torch.linspace(-1,1, grid_h + 1),
33 | indexing='ij')
34 | src_grid = torch.stack([grid_x,grid_y],dim = -1).unsqueeze(0).repeat(batch_size,1,1,1).to(device)
35 | new_grid = src_grid + 1 * net_out
36 | grid_upscaled = F.interpolate(new_grid.permute(0,-1,1,2),size = (height,width), mode = 'bilinear',align_corners= True)
37 | warped = F.grid_sample(img, grid_upscaled.permute(0,2,3,1),align_corners=False,padding_mode='zeros')
38 | return warped
39 |
40 | def save_video(frames, path):
41 | frame_count,h,w,_ = frames.shape
42 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
43 | out = cv2.VideoWriter(path, fourcc, 30.0, (w,h))
44 | for idx in range(frame_count):
45 | out.write(frames[idx,...])
46 | out.release()
47 |
48 | class StabNet(nn.Module):
49 | def __init__(self,trainable_layers = 10):
50 | super(StabNet, self).__init__()
51 | # Load the pre-trained VGG19 backbone
52 | vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')
53 | # Extract conv1 pretrained weights for RGB input
54 | rgb_weights = vgg19.features[0].weight.clone() #torch.Size([64, 3, 3, 3])
55 | # Tile the pretrained RGB weights to cover the 15-channel input
56 | tiled_rgb_weights = rgb_weights.repeat(1,5,1,1)
57 | # Change the size of the first layer from 3 to 15 channels
58 | vgg19.features[0] = nn.Conv2d(15,64, kernel_size=3, stride=1, padding=1, bias=False)
59 | # set new weights
60 | vgg19.features[0].weight = nn.Parameter(tiled_rgb_weights)
61 | # Determine the total number of layers in the model
62 | total_layers = sum(1 for _ in vgg19.parameters())
63 | # Freeze the layers except the last 10
64 | for idx, param in enumerate(vgg19.parameters()):
65 | if idx > total_layers - trainable_layers:
66 | param.requires_grad = True
67 | else:
68 | param.requires_grad = False
69 | # Use the VGG19 feature extractor without its final max-pooling layer
70 | self.encoder = nn.Sequential(*list(vgg19.children())[0][:-1])
71 | self.regressor = nn.Sequential(nn.Linear(512,2048),
72 | nn.ReLU(),
73 | nn.Linear(2048,1024),
74 | nn.ReLU(),
75 | nn.Linear(1024,512),
76 | nn.ReLU(),
77 | nn.Linear(512, ((grid_h + 1) * (grid_w + 1) * 2)))
78 | #self.regressor[-1].bias.data.fill_(0)
79 | total_resnet_params = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)
80 | total_regressor_params = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)
81 | print("Total Trainable encoder Parameters: ", total_resnet_params)
82 | print("Total Trainable regressor Parameters: ", total_regressor_params)
83 | print("Total Trainable parameters:",total_regressor_params + total_resnet_params)
84 |
85 | def forward(self, x_tensor):
86 | x_batch_size = x_tensor.size()[0]
87 | x = x_tensor[:, :3, :, :]
88 |
89 | # summary 1, dismiss now
90 | x_tensor = self.encoder(x_tensor)
91 | x_tensor = torch.mean(x_tensor, dim=[2, 3])
92 | x = self.regressor(x_tensor)
93 | x = x.view(x_batch_size,grid_h + 1,grid_w + 1,2)
94 | return x
95 |
96 | def stabilize(in_path,out_path):
97 |
98 | if not os.path.exists(in_path):
99 | print(f"The input file '{in_path}' does not exist.")
100 | exit()
101 | _,ext = os.path.splitext(in_path)
102 | if ext not in ['.mp4','.avi']:
103 | print(f"The input file '{in_path}' is not a supported video file (only .mp4 and .avi are supported).")
104 | exit()
105 |
106 |     #Load frames and standardize
107 | cap = cv2.VideoCapture(in_path)
108 | mean = np.array([0.485, 0.456, 0.406],dtype = np.float32)
109 | std = np.array([0.229, 0.224, 0.225],dtype = np.float32)
110 | frames = []
111 | while True:
112 | ret, img = cap.read()
113 | if not ret: break
114 | img = cv2.resize(img, (W,H))
115 | img = (img / 255.0).astype(np.float32)
116 | img = (img - mean)/std
117 | frames.append(img)
118 | frames = np.array(frames,dtype = np.float32)
119 | frame_count,_,_,_ = frames.shape
120 |
121 | # stabilize video
122 | frames_tensor = torch.from_numpy(frames).permute(0,3,1,2).float().to('cpu')
123 | stable_frames_tensor = frames_tensor.clone()
124 | SKIP = 16
125 | cv2.namedWindow('window',cv2.WINDOW_NORMAL)
126 | def get_batch(idx):
127 | batch = torch.zeros((5,3,H,W)).float()
128 | for i,j in enumerate(range(idx - SKIP, idx + SKIP + 1, SKIP//2)):
129 | batch[i,...] = frames_tensor[j,...]
130 | batch = batch.view(1,-1,H,W)
131 | return batch.to(device)
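    # get_batch stacks 5 frames sampled at offsets {-SKIP, -SKIP//2, 0, +SKIP//2, +SKIP}
    # (i.e. -16, -8, 0, +8, +16 for SKIP = 16) channel-wise into a single
    # (1, 15, H, W) tensor, matching the 15-channel first layer of StabNet above.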
132 |
133 | for frame_idx in range(SKIP,frame_count - SKIP):
134 | batch = get_batch(frame_idx)
135 | with torch.no_grad():
136 | transform = stabnet(batch)
137 | warped = get_warp(transform, frames_tensor[frame_idx: frame_idx + 1,...].cuda())
138 | stable_frames_tensor[frame_idx] = warped
139 | img = warped.permute(0,2,3,1)[0,...].cpu().detach().numpy()
140 | img *= std
141 | img += mean
142 | img = np.clip(img * 255.0,0,255).astype(np.uint8)
143 | cv2.imshow('window', img)
144 | if cv2.waitKey(1) & 0xFF == ord('q'):
145 | break
146 | cv2.destroyAllWindows()
147 |
148 | #undo standardization
149 | stable_frames = np.clip(((stable_frames_tensor.permute(0,2,3,1).numpy() * std) + mean) * 255,0,255).astype(np.uint8)
150 | frames = np.clip(((frames_tensor.permute(0,2,3,1).numpy() * std) + mean) * 255,0,255).astype(np.uint8)
151 | save_video(stable_frames,out_path)
152 |
153 |
154 | if __name__ == '__main__':
155 | args = parse_args()
156 | ckpt_dir = './ckpts/with_future_frames/'
157 | stabnet = StabNet().to(device).eval()
158 | ckpts = os.listdir(ckpt_dir)
159 | if ckpts:
160 |         ckpts = sorted(ckpts, key=lambda x: datetime.datetime.strptime('_'.join(x.split('_')[1:3]).rsplit('.', 1)[0], "%Y-%m-%d_%H-%M-%S"), reverse=True)  # sort by full date + time so the newest checkpoint comes first
161 | latest = os.path.join(ckpt_dir, ckpts[0])
162 | state = torch.load(latest)
163 | stabnet.load_state_dict(state['model'])
164 | print('Loaded StabNet',latest)
165 | stabilize(args.in_path, args.out_path)
--------------------------------------------------------------------------------
/stabilize_online.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import argparse
4 | from time import time
5 | import os
6 | import datetime
7 | import matplotlib.pyplot as plt
8 | import torch
9 | import torch.nn as nn
10 | import torchvision
11 | import torch.nn.functional as F
12 |
13 | device = 'cuda'
14 | batch_size = 1
15 | grid_h,grid_w = 15,15
16 | H,W = height,width = 360,640
17 |
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser(description='Video Stabilization using StabNet')
21 | parser.add_argument('--in_path', type=str, help='Input video file path')
22 | parser.add_argument('--out_path', type=str, help='Output stabilized video file path')
23 | return parser.parse_args()
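# Example invocation (a sketch; the file paths below are placeholders):
#   python stabilize_online.py --in_path ./unstable.mp4 --out_path ./stable.mp4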
24 |
25 | def get_warp(net_out,img):
26 | '''
27 | Inputs:
28 | net_out: torch.Size([batch_size,grid_h +1 ,grid_w +1,2])
29 | img: image to warp
30 | '''
31 | grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1, grid_h + 1),
32 |                             torch.linspace(-1,1, grid_w + 1),
33 | indexing='ij')
34 | src_grid = torch.stack([grid_x,grid_y],dim = -1).unsqueeze(0).repeat(batch_size,1,1,1).to(device)
35 | new_grid = src_grid + 1 * net_out
36 | grid_upscaled = F.interpolate(new_grid.permute(0,-1,1,2),size = (height,width), mode = 'bilinear',align_corners= True)
37 | warped = F.grid_sample(img, grid_upscaled.permute(0,2,3,1),align_corners=False,padding_mode='zeros')
38 | return warped
39 |
40 | def save_video(frames, path):
41 | frame_count,h,w,_ = frames.shape
42 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
43 | out = cv2.VideoWriter(path, fourcc, 30.0, (w,h))
44 | for idx in range(frame_count):
45 | out.write(frames[idx,...])
46 | out.release()
47 |
48 | class StabNet(nn.Module):
49 | def __init__(self,trainable_layers = 10):
50 | super(StabNet, self).__init__()
51 |         # Load the pre-trained VGG19 model
52 | vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')
53 | # Extract conv1 pretrained weights for RGB input
54 | rgb_weights = vgg19.features[0].weight.clone() #torch.Size([64, 3, 3, 3])
55 |         # Average the RGB weights into one grayscale channel and tile it for the 6 history channels
56 |         average_rgb_weights = torch.mean(rgb_weights, dim=1, keepdim=True).repeat(1,6,1,1) #torch.Size([64, 6, 3, 3])
57 | # Change size of the first layer from 3 to 9 channels
58 | vgg19.features[0] = nn.Conv2d(9,64, kernel_size=3, stride=1, padding=1, bias=False)
59 | # set new weights
60 | new_weights = torch.cat((rgb_weights, average_rgb_weights), dim=1)
61 | vgg19.features[0].weight = nn.Parameter(new_weights)
62 |         # Count the parameter tensors in the model
63 |         total_layers = sum(1 for _ in vgg19.parameters())
64 |         # Freeze everything except the last `trainable_layers` parameter tensors
65 | for idx, param in enumerate(vgg19.parameters()):
66 | if idx > total_layers - trainable_layers:
67 | param.requires_grad = True
68 | else:
69 | param.requires_grad = False
70 |         # Use the VGG19 convolutional features (minus the final max-pool) as the encoder
71 | self.encoder = nn.Sequential(*list(vgg19.children())[0][:-1])
72 | self.regressor = nn.Sequential(nn.Linear(512,2048),
73 | nn.ReLU(),
74 | nn.Linear(2048,1024),
75 | nn.ReLU(),
76 | nn.Linear(1024,512),
77 | nn.ReLU(),
78 | nn.Linear(512, ((grid_h + 1) * (grid_w + 1) * 2)))
79 | total_resnet_params = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)
80 | total_regressor_params = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)
81 | print("Total Trainable encoder Parameters: ", total_resnet_params)
82 | print("Total Trainable regressor Parameters: ", total_regressor_params)
83 | print("Total Trainable parameters:",total_regressor_params + total_resnet_params)
84 |
85 | def forward(self, x_tensor):
86 | x_batch_size = x_tensor.size()[0]
87 | x = x_tensor[:, :3, :, :]
88 |
89 | # summary 1, dismiss now
90 | x_tensor = self.encoder(x_tensor)
91 | x_tensor = torch.mean(x_tensor, dim=[2, 3])
92 | x = self.regressor(x_tensor)
93 | x = x.view(x_batch_size,grid_h + 1,grid_w + 1,2)
94 | return x
95 |
96 | def stabilize(in_path,out_path):
97 |
98 | if not os.path.exists(in_path):
99 | print(f"The input file '{in_path}' does not exist.")
100 | exit()
101 | _,ext = os.path.splitext(in_path)
102 | if ext not in ['.mp4','.avi']:
103 | print(f"The input file '{in_path}' is not a supported video file (only .mp4 and .avi are supported).")
104 | exit()
105 |
106 |     #Load and resize frames
107 | cap = cv2.VideoCapture(in_path)
108 | frames = []
109 | while True:
110 | ret,frame = cap.read()
111 | if not ret : break
112 |         frame = cv2.resize(frame,(W,H))  # resize to (H, W) so the frames match the buffer and warp sizes used below
113 | frames.append(frame)
114 | frames = np.array(frames)
115 | frame_count,_,_,_ = frames.shape
116 |
117 | # stabilize video
118 | frames_tensor = torch.from_numpy(frames/255.0).permute(0,3,1,2).float()
119 | stable_frames = frames_tensor.clone()
120 | buffer = torch.zeros((6,1,H,W)).float()
121 | cv2.namedWindow('window',cv2.WINDOW_NORMAL)
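    # Online input layout (as assembled below): the current RGB frame (3 channels) is
    # concatenated with 6 grayscale copies of previously stabilized frames at lags
    # 1, 2, 4, 8, 16 and 32, giving the 9-channel input expected by StabNet; the loop
    # therefore starts at frame 33 so that the full history is available.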
122 | for idx in range(33,frame_count):
123 | for i in range(6):
124 | buffer[i,...] = torch.mean(stable_frames[idx - 2**i,...],dim = 0,keepdim = True)
125 | curr = stable_frames[idx:idx+1,...]
126 | net_in = torch.cat([curr,buffer.permute(1,0,2,3)], dim = 1).to(device)
127 | with torch.no_grad():
128 |             transform = stabnet(net_in)
129 |         warped = get_warp(transform * 0.5 ,curr.to(device))
130 | stable_frames[idx:idx+1,...] = warped.cpu()
131 | warped_gray = torch.mean(warped,dim = 1,keepdim=True)
132 | buffer = torch.roll(buffer, shifts= 1, dims=1)
133 | buffer[:,:1,:,:] = warped_gray
134 | img = stable_frames[idx,...].permute(1,2,0).numpy()
135 | img = (img * 255).astype(np.uint8)
136 | cv2.imshow('window',img)
137 | if cv2.waitKey(1) & 0xFF == ord(' '):
138 | break
139 | cv2.destroyAllWindows()
140 |
141 |     #convert back to uint8 frames
142 | stable_frames = np.clip(stable_frames.permute(0,2,3,1).numpy() * 255,0,255).astype(np.uint8)
143 | save_video(stable_frames,out_path)
144 |
145 |
146 | if __name__ == '__main__':
147 | args = parse_args()
148 | ckpt_dir = './ckpts/original/'
149 | stabnet = StabNet().to(device).eval()
150 | ckpts = os.listdir(ckpt_dir)
151 | if ckpts:
152 |         ckpts = sorted(ckpts, key=lambda x: datetime.datetime.strptime('_'.join(x.split('_')[1:3]).rsplit('.', 1)[0], "%Y-%m-%d_%H-%M-%S"), reverse=True)  # sort by full date + time so the newest checkpoint comes first
153 | latest = os.path.join(ckpt_dir, ckpts[0])
154 | state = torch.load(latest)
155 | stabnet.load_state_dict(state['model'])
156 | print('Loaded StabNet',latest)
157 | stabilize(args.in_path, args.out_path)
--------------------------------------------------------------------------------
/train_vgg19_16x16_future_frames.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import cv2\n",
11 | "import torch\n",
12 | "from time import time\n",
13 | "import os\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "import torch.nn as nn\n",
16 | "import torchvision\n",
17 | "import torch.nn.functional as F\n",
18 | "from torch.utils import data\n",
19 | "from datagen import Datagen\n",
20 | "from image_utils import *\n",
21 | "from v2_93 import *\n",
22 | "import math\n",
23 | "import datetime\n",
24 | "\n",
25 | "device = 'cuda'\n",
26 | "ckpt_dir = 'E:/ModelCkpts/StabNet-multigrid-future_frames'\n",
27 | "starting_epoch = 0\n",
28 | "H,W,C = shape = (256,256,3)\n",
29 | "batch_size = 4\n",
30 | "EPOCHS = 10\n",
31 | "grid_h, grid_w = 15,15"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 2,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "from models import transformer\n",
41 | "class StabNet(nn.Module):\n",
42 | " def __init__(self,trainable_layers = 10):\n",
43 | " super(StabNet, self).__init__()\n",
44 |     "        # Load the pre-trained VGG19 model\n",
45 | " vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')\n",
46 | " # Extract conv1 pretrained weights for RGB input\n",
47 | " rgb_weights = vgg19.features[0].weight.clone() #torch.Size([64, 3, 3, 3])\n",
48 |     "        # Tile the RGB weights across the 5 stacked input frames (3 x 5 = 15 channels)\n",
49 |     "        tiled_rgb_weights = rgb_weights.repeat(1,5,1,1) \n",
50 |     "        # Change the first layer from 3 to 15 input channels\n",
51 | " vgg19.features[0] = nn.Conv2d(15,64, kernel_size=3, stride=1, padding=1, bias=False)\n",
52 | " # set new weights\n",
53 | " vgg19.features[0].weight = nn.Parameter(tiled_rgb_weights)\n",
54 | " # Determine the total number of layers in the model\n",
55 | " total_layers = sum(1 for _ in vgg19.parameters())\n",
56 | " # Freeze the layers except the last 10\n",
57 | " for idx, param in enumerate(vgg19.parameters()):\n",
58 | " if idx > total_layers - trainable_layers:\n",
59 | " param.requires_grad = True\n",
60 | " else:\n",
61 | " param.requires_grad = False\n",
62 |     "        # Use the VGG19 convolutional features (minus the final max-pool) as the encoder\n",
63 | " self.encoder = nn.Sequential(*list(vgg19.children())[0][:-1])\n",
64 | " self.regressor = nn.Sequential(nn.Linear(512,2048),\n",
65 | " nn.ReLU(),\n",
66 | " nn.Linear(2048,1024),\n",
67 | " nn.ReLU(),\n",
68 | " nn.Linear(1024,512),\n",
69 | " nn.ReLU(),\n",
70 | " nn.Linear(512, ((grid_h + 1) * (grid_w + 1) * 2)))\n",
71 | " #self.regressor[-1].bias.data.fill_(0)\n",
72 | " total_resnet_params = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)\n",
73 | " total_regressor_params = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)\n",
74 | " print(\"Total Trainable mobilenet Parameters: \", total_resnet_params)\n",
75 | " print(\"Total Trainable regressor Parameters: \", total_regressor_params)\n",
76 | " print(\"Total Trainable parameters:\",total_regressor_params + total_resnet_params)\n",
77 | " \n",
78 | " def forward(self, x_tensor):\n",
79 | " x_batch_size = x_tensor.size()[0]\n",
80 | " x = x_tensor[:, :3, :, :]\n",
81 | "\n",
82 | " # summary 1, dismiss now\n",
83 | " x_tensor = self.encoder(x_tensor)\n",
84 | " x_tensor = torch.mean(x_tensor, dim=[2, 3])\n",
85 | " x = self.regressor(x_tensor)\n",
86 | " x = x.view(x_batch_size,grid_h + 1,grid_w + 1,2)\n",
87 | " return x"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 3,
93 | "metadata": {},
94 | "outputs": [
95 | {
96 | "name": "stdout",
97 | "output_type": "stream",
98 | "text": [
99 | "Total Trainable mobilenet Parameters: 2360320\n",
100 | "Total Trainable regressor Parameters: 3936256\n",
101 | "Total Trainable parameters: 6296576\n"
102 | ]
103 | }
104 | ],
105 | "source": [
106 | "stabnet = StabNet(trainable_layers=10).to(device).train()\n",
107 | "optimizer = torch.optim.Adam(stabnet.parameters(),lr = 2e-5)\n"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 4,
113 | "metadata": {},
114 | "outputs": [
115 | {
116 | "data": {
117 | "text/plain": [
118 | "torch.Size([1, 16, 16, 2])"
119 | ]
120 | },
121 | "execution_count": 4,
122 | "metadata": {},
123 | "output_type": "execute_result"
124 | }
125 | ],
126 | "source": [
127 | "out = stabnet(torch.randn(1,15,256,256).float().to(device))\n",
128 | "out.shape"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 5,
134 | "metadata": {},
135 | "outputs": [
136 | {
137 | "name": "stdout",
138 | "output_type": "stream",
139 | "text": [
140 | "Loaded weights from E:/ModelCkpts/StabNet-multigrid-future_frames\\stabnet_2023-11-01_23-17-37.pth\n",
141 | "Reduced learning rate to 5.000000000000001e-07\n"
142 | ]
143 | }
144 | ],
145 | "source": [
146 | "ckpts = os.listdir(ckpt_dir)\n",
147 | "if ckpts:\n",
148 | " ckpts = sorted(ckpts, key=lambda x: datetime.datetime.strptime(x.split('_')[2].split('.')[0], \"%H-%M-%S\"), reverse=True)\n",
149 | " # Get the filename of the latest checkpoint\n",
150 | " latest = os.path.join(ckpt_dir, ckpts[0])\n",
151 | " # Load the latest checkpoint\n",
152 | " state = torch.load(latest)\n",
153 | " stabnet.load_state_dict(state['model'])\n",
154 | " optimizer.load_state_dict(state['optimizer'])\n",
155 | " starting_epoch = state['epoch'] + 1\n",
156 | " optimizer.param_groups[0]['lr'] *= 0.1\n",
157 | " print('Loaded weights from', latest)\n",
158 | " print('Reduced learning rate to', optimizer.param_groups[0]['lr'])"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 6,
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "def get_warp(net_out,img):\n",
168 | " '''\n",
169 | " Inputs:\n",
170 | " net_out: torch.Size([batch_size,grid_h +1 ,grid_w +1,2])\n",
171 | " img: image to warp\n",
172 | " '''\n",
173 | " grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1, grid_h + 1),\n",
174 |     "                            torch.linspace(-1,1, grid_w + 1),\n",
175 | " indexing='ij')\n",
176 | " src_grid = torch.stack([grid_x,grid_y],dim = -1).unsqueeze(0).repeat(batch_size,1,1,1).to(device)\n",
177 | " new_grid = src_grid + net_out\n",
178 | " grid_upscaled = F.interpolate(new_grid.permute(0,-1,1,2),size = (height,width), mode = 'bilinear',align_corners= True)\n",
179 | " warped = F.grid_sample(img, grid_upscaled.permute(0,2,3,1),align_corners=True)\n",
180 | " return warped ,grid_upscaled.permute(0,2,3,1)"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": 7,
186 | "metadata": {},
187 | "outputs": [],
188 | "source": [
189 | "def temp_loss(warped0,warped1, flow):\n",
190 | " #prev warped1\n",
191 | " #curr warped0\n",
192 | " temp = dense_warp(warped1, flow)\n",
193 | " return F.l1_loss(warped0,temp)\n",
194 | "\n",
195 | "def shape_loss(net_out):\n",
196 | " grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1, grid_h + 1),\n",
197 |     "                            torch.linspace(-1,1, grid_w + 1),\n",
198 | " indexing='ij')\n",
199 | " grid_tensor = torch.stack([grid_x,grid_y],dim = -1).unsqueeze(0).repeat(batch_size,1,1,1).to(device)\n",
200 | " new_grid = grid_tensor + net_out\n",
201 | "\n",
202 | " #Lintra\n",
203 | " vt0 = new_grid[:, :-1, 1:, :]\n",
204 | " vt0_original = grid_tensor[:, :-1, 1:, :]\n",
205 | " vt1 = new_grid[:, :-1, :-1, :]\n",
206 | " vt1_original = grid_tensor[:, :-1, :-1, :]\n",
207 | " vt = new_grid[:, 1:, :-1, :]\n",
208 | " vt_original = grid_tensor[:, 1:, :-1, :]\n",
209 | " alpha = vt - vt1\n",
210 | " s = torch.norm(vt_original - vt1_original, dim=-1) / torch.norm(vt0_original - vt1_original, dim=-1)\n",
211 | " vt01 = vt0 - vt1\n",
212 | " beta = s[..., None] * torch.stack([vt01[..., 1], -vt01[..., 0]], dim=-1)\n",
213 | " norm = torch.norm(alpha - beta, dim=-1, keepdim=True)\n",
214 | " Lintra = torch.sum(norm) / (((grid_h + 1) * (grid_w + 1)) * batch_size)\n",
215 | "\n",
216 | " # Extract the vertices for computation\n",
217 | " vt1_vertical = new_grid[:, :-2, :, :]\n",
218 | " vt_vertical = new_grid[:, 1:-1, :, :]\n",
219 | " vt0_vertical = new_grid[:, 2:, :, :]\n",
220 | "\n",
221 | " vt1_horizontal = new_grid[:, :, :-2, :]\n",
222 | " vt_horizontal = new_grid[:, :, 1:-1, :]\n",
223 | " vt0_horizontal = new_grid[:, :, 2:, :]\n",
224 | "\n",
225 | " # Compute the differences\n",
226 | " vt_diff_vertical = vt1_vertical - vt_vertical\n",
227 | " vt_diff_horizontal = vt1_horizontal - vt_horizontal\n",
228 | "\n",
229 | " # Compute Linter for vertical direction\n",
230 | " Linter_vertical = torch.mean(torch.norm(vt_diff_vertical - (vt_vertical - vt0_vertical), dim=-1))\n",
231 | "\n",
232 | " # Compute Linter for horizontal direction\n",
233 | " Linter_horizontal = torch.mean(torch.norm(vt_diff_horizontal - (vt_horizontal - vt0_horizontal), dim=-1))\n",
234 | "\n",
235 | " # Combine Linter for both directions\n",
236 | " Linter = Linter_vertical + Linter_horizontal\n",
237 | "\n",
238 | " # Compute the shape loss\n",
239 | " shape_loss = Lintra + 20 * Linter\n",
240 | "\n",
241 | " return shape_loss\n",
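    "\n",
    "# Note (a reading of the code above): Lintra compares each mesh edge with the\n",
    "# 90-degree rotation of its neighbouring edge, so it is small when a grid cell\n",
    "# undergoes a similarity transform; Linter penalizes second differences of the\n",
    "# grid vertices along rows and columns, so neighbouring cells deform consistently.\n",
    "# The 20x weight on Linter is the value hard-coded above.\n",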
242 | "\n",
243 | "def feature_loss(features, warp_field):\n",
244 | " stable_features = ((features[:, :, 0, :] + 1) / 2) * torch.tensor([255, 255], dtype=torch.float).to(device)\n",
245 | " unstable_features = ((features[:, :, 1, :] + 1) / 2) * torch.tensor([255, 255], dtype=torch.float).to(device)\n",
246 | " \n",
247 | " # Clip the features to the range [0, 255]\n",
248 | " stable_features = torch.clamp(stable_features, min=0, max=255)\n",
249 | " unstable_features = torch.clamp(unstable_features, min=0, max=255)\n",
250 | " \n",
251 | " warped_unstable_features = unstable_features + warp_field[:, unstable_features[:, :, 1].long(), unstable_features[:, :, 0].long(), :]\n",
252 | " loss = torch.mean(torch.sqrt(torch.sum((stable_features - warped_unstable_features) ** 2,dim = -1)))\n",
253 | " \n",
254 | " return loss"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": 8,
260 | "metadata": {},
261 | "outputs": [],
262 | "source": [
263 | "def draw_features(image,features,color):\n",
264 | " drawn = image.copy()\n",
265 | " for point in features:\n",
266 | " x,y = point\n",
267 | " cv2.circle(drawn, (int(x),int(y)), 2, color, -1)\n",
268 | " return drawn"
269 | ]
270 | },
271 | {
272 | "attachments": {
273 | "image.png": {
274 |        "image/png": "<base64-encoded PNG attachment omitted>"
275 | }
276 | },
277 | "cell_type": "markdown",
278 | "metadata": {},
279 | "source": [
280 | ""
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": 9,
286 | "metadata": {},
287 | "outputs": [],
288 | "source": [
289 | "class IterDataset(data.IterableDataset):\n",
290 | " def __init__(self, data_generator):\n",
291 | " super(IterDataset, self).__init__()\n",
292 | " self.data_generator = data_generator\n",
293 | "\n",
294 | " def __iter__(self):\n",
295 | " return iter(self.data_generator())\n",
296 | "generator = Datagen(shape = (H,W),txt_path = './trainlist.txt')\n",
297 | "iter_dataset = IterDataset(generator)\n",
298 | "data_loader = data.DataLoader(iter_dataset, batch_size=batch_size)\n"
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": 10,
304 | "metadata": {},
305 | "outputs": [],
306 | "source": [
307 | "from torch.utils.tensorboard import SummaryWriter\n",
308 | "# default `log_dir` is \"runs\" - we'll be more specific here\n",
309 | "writer = SummaryWriter('runs/vgg_16x16/')"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": 11,
315 | "metadata": {},
316 | "outputs": [
317 | {
318 | "name": "stdout",
319 | "output_type": "stream",
320 | "text": [
321 | "Epoch: 2, Batch:1193, loss:8.925530743091665 pixel_loss:4.0964155197143555, feature_loss:3.904881715774536 ,temp:4.310582160949707, shape:0.2874298393726349353"
322 | ]
323 | },
324 | {
325 | "ename": "KeyboardInterrupt",
326 | "evalue": "",
327 | "output_type": "error",
328 | "traceback": [
329 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
330 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
331 | "\u001b[1;32mc:\\Users\\VINY\\Desktop\\Stabnet multigrid2-DeepStab Modded - Future Frames\\Stabnet_vgg19_16x16.ipynb Cell 12\u001b[0m line \u001b[0;36m6\n\u001b[0;32m 3\u001b[0m \u001b[39mfor\u001b[39;00m epoch \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(starting_epoch,EPOCHS): \n\u001b[0;32m 4\u001b[0m \u001b[39m# Generate the data for each iteration\u001b[39;00m\n\u001b[0;32m 5\u001b[0m running_loss \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m----> 6\u001b[0m \u001b[39mfor\u001b[39;00m idx,batch \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(data_loader):\n\u001b[0;32m 7\u001b[0m start \u001b[39m=\u001b[39m time()\n\u001b[0;32m 8\u001b[0m St, St_1, Igt , flow,features \u001b[39m=\u001b[39m batch\n",
332 | "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:633\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 630\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 631\u001b[0m \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m 632\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset() \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[1;32m--> 633\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[0;32m 634\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[0;32m 635\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[0;32m 636\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[0;32m 637\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n",
333 | "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:677\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 675\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[0;32m 676\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m--> 677\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m 678\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[0;32m 679\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n",
334 | "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:32\u001b[0m, in \u001b[0;36m_IterableDatasetFetcher.fetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m 30\u001b[0m \u001b[39mfor\u001b[39;00m _ \u001b[39min\u001b[39;00m possibly_batched_index:\n\u001b[0;32m 31\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m---> 32\u001b[0m data\u001b[39m.\u001b[39mappend(\u001b[39mnext\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset_iter))\n\u001b[0;32m 33\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mStopIteration\u001b[39;00m:\n\u001b[0;32m 34\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mended \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
335 | "File \u001b[1;32mc:\\Users\\VINY\\Desktop\\Stabnet multigrid2-DeepStab Modded - Future Frames\\datagen.py:42\u001b[0m, in \u001b[0;36mDatagen.__call__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 40\u001b[0m u_cap\u001b[39m.\u001b[39mset(cv2\u001b[39m.\u001b[39mCAP_PROP_POS_FRAMES, frame_idx)\n\u001b[0;32m 41\u001b[0m \u001b[39mfor\u001b[39;00m i,pos \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(\u001b[39mrange\u001b[39m(frame_idx \u001b[39m-\u001b[39m SKIP, frame_idx \u001b[39m+\u001b[39m SKIP \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m, SKIP \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m \u001b[39m2\u001b[39m)):\n\u001b[1;32m---> 42\u001b[0m u_cap\u001b[39m.\u001b[39;49mset(cv2\u001b[39m.\u001b[39;49mCAP_PROP_POS_FRAMES,pos \u001b[39m-\u001b[39;49m \u001b[39m1\u001b[39;49m)\n\u001b[0;32m 43\u001b[0m _,temp \u001b[39m=\u001b[39m u_cap\u001b[39m.\u001b[39mread()\n\u001b[0;32m 44\u001b[0m temp \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpreprocess(temp)\n",
336 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
337 | ]
338 | }
339 | ],
340 | "source": [
341 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
342 | "# Training loop\n",
343 | "for epoch in range(starting_epoch,EPOCHS): \n",
344 | " # Generate the data for each iteration\n",
345 | " running_loss = 0\n",
346 | " for idx,batch in enumerate(data_loader):\n",
347 | " start = time()\n",
348 | " St, St_1, Igt , flow,features = batch\n",
349 | " # Move the data to GPU if available\n",
350 | " St = St.to(device)\n",
351 | " St_1 = St_1.to(device)\n",
352 | " Igt = Igt.to(device)\n",
353 | " It = St[:,6:9,...].to(device)\n",
354 | " It_1 = St_1[:,6:9,...].to(device)\n",
355 | " flow = flow.to(device)\n",
356 | " features = features.to(device)\n",
357 | " # Forward pass through the Siamese Network\n",
358 | " \n",
359 | " transform0 = stabnet(St)\n",
360 | " transform1 = stabnet(St_1)\n",
361 | "\n",
362 | " warped0,warp_field = get_warp(transform0,It)\n",
363 | " warped1,_ = get_warp(transform1,It_1)\n",
364 | " # Compute the losses\n",
365 | " #stability loss\n",
366 | " pixel_loss = 10 * F.mse_loss(warped0, Igt)\n",
367 | " feat_loss = feature_loss(features,warp_field)\n",
368 | " stability_loss = pixel_loss + feat_loss\n",
369 | " #shape_loss\n",
370 | " sh_loss = shape_loss(transform0)\n",
371 | " #temporal loss\n",
372 | " warped2 = dense_warp(warped1, flow)\n",
373 | " temp_loss = 10 * F.mse_loss(warped0,warped2)\n",
374 | " # Perform backpropagation and update the model parameters\n",
375 | " optimizer.zero_grad()\n",
376 | " total_loss = stability_loss + sh_loss + temp_loss\n",
377 | " total_loss.backward()\n",
378 | " optimizer.step()\n",
379 | "\n",
380 | " \n",
381 | " means = np.array([0.485, 0.456, 0.406],dtype = np.float32)\n",
382 | " stds = np.array([0.229, 0.224, 0.225],dtype = np.float32)\n",
383 | "\n",
384 | " img = warped0[0,...].detach().cpu().permute(1,2,0).numpy()\n",
385 | " img *= stds\n",
386 | " img += means\n",
387 | " img = np.clip(img * 255.0,0,255).astype(np.uint8)\n",
388 | "\n",
389 | " img1 = Igt[0,...].cpu().permute(1,2,0).numpy()\n",
390 | " img1 *= stds\n",
391 | " img1 += means\n",
392 | " img1 = np.clip(img1 * 255.0,0,255).astype(np.uint8)\n",
393 | " #draw features\n",
394 | " stable_features = ((features[:, :, 0, :] + 1) / 2) * torch.tensor([255, 255], dtype=torch.float).to(device)\n",
395 | " unstable_features = ((features[:, :, 1, :] + 1) / 2) * torch.tensor([255, 255], dtype=torch.float).to(device)\n",
396 | " # Clip the features to the range [0, 255]\n",
397 | " stable_features = torch.clamp(stable_features, min=0, max=255).cpu().numpy()\n",
398 | " unstable_features = torch.clamp(unstable_features, min=0, max=255).cpu().numpy()\n",
399 | " img = draw_features(img,unstable_features[0,...],color = (0,255,0))\n",
400 | " img1 = draw_features(img1,stable_features[0,...],color = (0,0,255))\n",
401 | " conc = cv2.hconcat([img,img1])\n",
402 | " cv2.imshow('window',conc)\n",
403 | " if cv2.waitKey(1) & 0xFF == ord('9'):\n",
404 | " break\n",
405 | " \n",
406 | " running_loss += total_loss.item()\n",
407 | " print(f\"\\rEpoch: {epoch}, Batch:{idx}, loss:{running_loss / (idx % 100 + 1)}\\\n",
408 | " pixel_loss:{pixel_loss.item()}, feature_loss:{feat_loss.item()} ,temp:{temp_loss.item()}, shape:{sh_loss.item()}\",end = \"\")\n",
409 | " if idx % 100 == 99:\n",
410 | " writer.add_scalar('training_loss',\n",
411 | " running_loss / 100, \n",
412 | " epoch * 41328 // batch_size + idx)\n",
413 | " running_loss = 0.0\n",
414 | " # Get current date and time as a formatted string\n",
415 | " current_datetime = datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n",
416 | "\n",
417 | " # Append the formatted date-time string to the model filename\n",
418 | " model_path = os.path.join(ckpt_dir, f'stabnet_{current_datetime}.pth')\n",
419 | " \n",
420 | " torch.save({'model': stabnet.state_dict(),\n",
421 | " 'optimizer' : optimizer.state_dict(),\n",
422 | " 'epoch' : epoch}\n",
423 | " ,model_path)\n",
424 | " del St, St_1, It, Igt, It_1,flow,features\n",
425 | "cv2.destroyAllWindows()"
426 | ]
427 | },
428 | {
429 | "cell_type": "code",
430 | "execution_count": null,
431 | "metadata": {},
432 | "outputs": [],
433 | "source": []
434 | }
435 | ],
436 | "metadata": {
437 | "kernelspec": {
438 | "display_name": "DUTCode",
439 | "language": "python",
440 | "name": "python3"
441 | },
442 | "language_info": {
443 | "codemirror_mode": {
444 | "name": "ipython",
445 | "version": 3
446 | },
447 | "file_extension": ".py",
448 | "mimetype": "text/x-python",
449 | "name": "python",
450 | "nbconvert_exporter": "python",
451 | "pygments_lexer": "ipython3",
452 | "version": "3.9.16"
453 | },
454 | "orig_nbformat": 4
455 | },
456 | "nbformat": 4,
457 | "nbformat_minor": 2
458 | }
459 |
--------------------------------------------------------------------------------
/train_vgg19_16x16_online.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import cv2\n",
11 | "import torch\n",
12 | "from time import time\n",
13 | "import os\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "import torch.nn as nn\n",
16 | "import torchvision\n",
17 | "import torch.nn.functional as F\n",
18 | "from torch.utils import data\n",
19 | "from datagen import Datagen\n",
20 | "from image_utils import *\n",
21 | "from v2_93 import *\n",
22 | "import math\n",
23 | "import datetime\n",
24 | "\n",
25 | "device = 'cuda'\n",
26 | "ckpt_dir = 'E:/ModelCkpts/stabnet_multigrid_3/vgg_16x16'\n",
27 | "starting_epoch = 0\n",
28 | "H,W,C = shape = (256,256,3)\n",
29 | "batch_size = 4\n",
30 | "EPOCHS = 10\n",
31 | "grid_h, grid_w = 15,15"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 2,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "from models import transformer\n",
41 | "class StabNet(nn.Module):\n",
42 | " def __init__(self,trainable_layers = 10):\n",
43 | " super(StabNet, self).__init__()\n",
44 |     "        # Load the pre-trained VGG19 model\n",
45 | " vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')\n",
46 | " # Extract conv1 pretrained weights for RGB input\n",
47 | " rgb_weights = vgg19.features[0].weight.clone() #torch.Size([64, 3, 3, 3])\n",
48 |     "        # Average the RGB weights into one grayscale channel and tile it for the 6 history channels\n",
49 |     "        average_rgb_weights = torch.mean(rgb_weights, dim=1, keepdim=True).repeat(1,6,1,1) #torch.Size([64, 6, 3, 3])\n",
50 | " # Change size of the first layer from 3 to 9 channels\n",
51 | " vgg19.features[0] = nn.Conv2d(9,64, kernel_size=3, stride=1, padding=1, bias=False)\n",
52 | " # set new weights\n",
53 | " new_weights = torch.cat((rgb_weights, average_rgb_weights), dim=1)\n",
54 | " vgg19.features[0].weight = nn.Parameter(new_weights)\n",
55 | " # Determine the total number of layers in the model\n",
56 | " total_layers = sum(1 for _ in vgg19.parameters())\n",
57 | " # Freeze the layers except the last 10\n",
58 | " for idx, param in enumerate(vgg19.parameters()):\n",
59 | " if idx > total_layers - trainable_layers:\n",
60 | " param.requires_grad = True\n",
61 | " else:\n",
62 | " param.requires_grad = False\n",
63 |     "        # Use the VGG19 convolutional features (minus the final max-pool) as the encoder\n",
64 | " self.encoder = nn.Sequential(*list(vgg19.children())[0][:-1])\n",
65 | " self.regressor = nn.Sequential(nn.Linear(512,2048),\n",
66 | " nn.ReLU(),\n",
67 | " nn.Linear(2048,1024),\n",
68 | " nn.ReLU(),\n",
69 | " nn.Linear(1024,512),\n",
70 | " nn.ReLU(),\n",
71 | " nn.Linear(512, ((grid_h + 1) * (grid_w + 1) * 2)))\n",
72 | " #self.regressor[-1].bias.data.fill_(0)\n",
73 | " total_resnet_params = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)\n",
74 | " total_regressor_params = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)\n",
75 | " print(\"Total Trainable mobilenet Parameters: \", total_resnet_params)\n",
76 | " print(\"Total Trainable regressor Parameters: \", total_regressor_params)\n",
77 | " print(\"Total Trainable parameters:\",total_regressor_params + total_resnet_params)\n",
78 | " \n",
79 | " def forward(self, x_tensor):\n",
80 | " x_batch_size = x_tensor.size()[0]\n",
81 | " x = x_tensor[:, :3, :, :]\n",
82 | "\n",
83 | " # summary 1, dismiss now\n",
84 | " x_tensor = self.encoder(x_tensor)\n",
85 | " x_tensor = torch.mean(x_tensor, dim=[2, 3])\n",
86 | " x = self.regressor(x_tensor)\n",
87 | " x = x.view(x_batch_size,grid_h + 1,grid_w + 1,2)\n",
88 | " return x"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 3,
94 | "metadata": {},
95 | "outputs": [
96 | {
97 | "name": "stdout",
98 | "output_type": "stream",
99 | "text": [
100 | "Total Trainable mobilenet Parameters: 9439232\n",
101 | "Total Trainable regressor Parameters: 3936256\n",
102 | "Total Trainable parameters: 13375488\n"
103 | ]
104 | }
105 | ],
106 | "source": [
107 | "stabnet = StabNet(trainable_layers=15).to(device).train()\n",
108 | "optimizer = torch.optim.Adam(stabnet.parameters(),lr = 2e-5)\n"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 4,
114 | "metadata": {},
115 | "outputs": [
116 | {
117 | "data": {
118 | "text/plain": [
119 | "torch.Size([1, 16, 16, 2])"
120 | ]
121 | },
122 | "execution_count": 4,
123 | "metadata": {},
124 | "output_type": "execute_result"
125 | }
126 | ],
127 | "source": [
128 | "out = stabnet(torch.randn(1,9,256,256).float().to(device))\n",
129 | "out.shape"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": 5,
135 | "metadata": {},
136 | "outputs": [
137 | {
138 | "name": "stdout",
139 | "output_type": "stream",
140 | "text": [
141 | "Loaded weights from E:/ModelCkpts/stabnet_multigrid_3/vgg_16x16\\stabnet_2023-10-26_06-34-42.pth\n"
142 | ]
143 | }
144 | ],
145 | "source": [
146 | "ckpts = os.listdir(ckpt_dir)\n",
147 | "if ckpts:\n",
148 | " ckpts = sorted(ckpts, key=lambda x: datetime.datetime.strptime(x.split('_')[2].split('.')[0], \"%H-%M-%S\"), reverse=True)\n",
149 | " # Get the filename of the latest checkpoint\n",
150 | " latest = os.path.join(ckpt_dir, ckpts[0])\n",
151 | " # Load the latest checkpoint\n",
152 | " state = torch.load(latest)\n",
153 | " stabnet.load_state_dict(state['model'])\n",
154 | " optimizer.load_state_dict(state['optimizer'])\n",
155 | " starting_epoch = state['epoch'] + 1\n",
156 | " print('Loaded weights from', latest)"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 6,
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "def get_warp(net_out,img):\n",
166 | " '''\n",
167 | " Inputs:\n",
168 | " net_out: torch.Size([batch_size,grid_h +1 ,grid_w +1,2])\n",
169 | " img: image to warp\n",
170 | " '''\n",
171 | " grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1, grid_h + 1),\n",
172 |     "                            torch.linspace(-1,1, grid_w + 1),\n",
173 | " indexing='ij')\n",
174 | " src_grid = torch.stack([grid_x,grid_y],dim = -1).unsqueeze(0).repeat(batch_size,1,1,1).to(device)\n",
175 | " new_grid = src_grid + net_out\n",
176 | " grid_upscaled = F.interpolate(new_grid.permute(0,-1,1,2),size = (height,width), mode = 'bilinear',align_corners= True)\n",
177 | " warped = F.grid_sample(img, grid_upscaled.permute(0,2,3,1),align_corners=True)\n",
178 | " return warped ,grid_upscaled.permute(0,2,3,1)"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 7,
184 | "metadata": {},
185 | "outputs": [],
186 | "source": [
187 | "def temp_loss(warped0,warped1, flow):\n",
188 | " #prev warped1\n",
189 | " #curr warped0\n",
190 | " temp = dense_warp(warped1, flow)\n",
191 | " return F.l1_loss(warped0,temp)\n",
192 | "\n",
193 | "def shape_loss(net_out):\n",
194 | " grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1, grid_h + 1),\n",
195 |     "                            torch.linspace(-1,1, grid_w + 1),\n",
196 | " indexing='ij')\n",
197 | " grid_tensor = torch.stack([grid_x,grid_y],dim = -1).unsqueeze(0).repeat(batch_size,1,1,1).to(device)\n",
198 | " new_grid = grid_tensor + net_out\n",
199 | "\n",
200 | " #Lintra\n",
201 | " vt0 = new_grid[:, :-1, 1:, :]\n",
202 | " vt0_original = grid_tensor[:, :-1, 1:, :]\n",
203 | " vt1 = new_grid[:, :-1, :-1, :]\n",
204 | " vt1_original = grid_tensor[:, :-1, :-1, :]\n",
205 | " vt = new_grid[:, 1:, :-1, :]\n",
206 | " vt_original = grid_tensor[:, 1:, :-1, :]\n",
207 | " alpha = vt - vt1\n",
208 | " s = torch.norm(vt_original - vt1_original, dim=-1) / torch.norm(vt0_original - vt1_original, dim=-1)\n",
209 | " vt01 = vt0 - vt1\n",
210 | " beta = s[..., None] * torch.stack([vt01[..., 1], -vt01[..., 0]], dim=-1)\n",
211 | " norm = torch.norm(alpha - beta, dim=-1, keepdim=True)\n",
212 | " Lintra = torch.sum(norm) / (((grid_h + 1) * (grid_w + 1)) * batch_size)\n",
213 | "\n",
214 | " # Extract the vertices for computation\n",
215 | " vt1_vertical = new_grid[:, :-2, :, :]\n",
216 | " vt_vertical = new_grid[:, 1:-1, :, :]\n",
217 | " vt0_vertical = new_grid[:, 2:, :, :]\n",
218 | "\n",
219 | " vt1_horizontal = new_grid[:, :, :-2, :]\n",
220 | " vt_horizontal = new_grid[:, :, 1:-1, :]\n",
221 | " vt0_horizontal = new_grid[:, :, 2:, :]\n",
222 | "\n",
223 | " # Compute the differences\n",
224 | " vt_diff_vertical = vt1_vertical - vt_vertical\n",
225 | " vt_diff_horizontal = vt1_horizontal - vt_horizontal\n",
226 | "\n",
227 | " # Compute Linter for vertical direction\n",
228 | " Linter_vertical = torch.mean(torch.norm(vt_diff_vertical - (vt_vertical - vt0_vertical), dim=-1))\n",
229 | "\n",
230 | " # Compute Linter for horizontal direction\n",
231 | " Linter_horizontal = torch.mean(torch.norm(vt_diff_horizontal - (vt_horizontal - vt0_horizontal), dim=-1))\n",
232 | "\n",
233 | " # Combine Linter for both directions\n",
234 | " Linter = Linter_vertical + Linter_horizontal\n",
235 | "\n",
236 | " # Compute the shape loss\n",
237 | " shape_loss = Lintra + 20 * Linter\n",
238 | "\n",
239 | " return shape_loss\n",
240 | "\n",
241 | "def feature_loss(features, warp_field):\n",
242 | " stable_features = ((features[:, :, 0, :] + 1) / 2) * torch.tensor([255, 255], dtype=torch.float).to(device)\n",
243 | " unstable_features = ((features[:, :, 1, :] + 1) / 2) * torch.tensor([255, 255], dtype=torch.float).to(device)\n",
244 | " \n",
245 | " # Clip the features to the range [0, 255]\n",
246 | " stable_features = torch.clamp(stable_features, min=0, max=255)\n",
247 | " unstable_features = torch.clamp(unstable_features, min=0, max=255)\n",
248 | " \n",
249 | " warped_unstable_features = unstable_features + warp_field[:, unstable_features[:, :, 1].long(), unstable_features[:, :, 0].long(), :]\n",
250 | " loss = torch.mean(torch.sqrt(torch.sum((stable_features - warped_unstable_features) ** 2,dim = -1)))\n",
251 | " \n",
252 | " return loss\n",
253 | "\n",
254 | "\n"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": 11,
260 | "metadata": {},
261 | "outputs": [],
262 | "source": [
263 | "def draw_features(image,features,color):\n",
264 | " drawn = image.copy()\n",
265 | " for point in features:\n",
266 | " x,y = point\n",
267 |     "        cv2.circle(drawn, (int(x),int(y)), 5, color, -1)\n",
268 | " return drawn"
269 | ]
270 | },
271 | {
272 | "attachments": {
273 | "image.png": {
274 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZoAAABRCAYAAADrXmCgAAAgAElEQVR4nO3dZ2BUxdrA8f8m2d0ku5tGCiSBhBJCIPQmSBEEBPReBUV9RQHFchWwt2tFEUQREeWKBa96EQsioCAoSIdESuiBhPReN2V7Pef9QBEUSEiBBOf3cXN2zuzm7HnOzDwzo5BlWUYQhLpzF7H6jR8JvqcDyf/Opve8jmy/5xVcz/yHPuU/4bhuHFUfPELRuKVM7ReBn+pKV1gQLi+PK10BQWj2yvayy+VN2s9p9PtkAvGeZeR7dcGveC0nejzCwBATeYf70trPG6Xnla6sIFx+ItAIQj05i9M5ziHofRsdfXRUHtxJUicHB5yDmNDeB+/Uzfw2sDdxkT6oRaAR/oZEoBGEenFyNGEj3SPHMKiLP94qPft27MPjUDR3/rMrOh0c2bWBXr0708rHW/zghL8lcd0LVyVZcmAx6NGb7LikRhyGdB8ncW0PevbpTrhWhaczlxOJKm6aOYleIX6o3MdI+Kk3fVr743RLNGZVBKGpEoFGuLpITqzmakrTdrHi/Sd5bVUSOdWuxjtdcQYF1/ajaxsN3p4gpSTys3Is3WN1+KiAshwyu7YjxL6TNftKMNulRquLIDRVItAIVxdLPnu3f8en77zBf77ZRVGFtVFPZ8hJx69/PEEaHzwBg8FM6JSBxAX4ogIwGijKXMXLP7gZ1CMQjbf4yQl/PwqR3ixclVL+xx2PfYg0djZzJw2hfaDyStdIEP62xOOVcHXyABRXuhKCIIAINIIgCEIjE4FGEARBaFQi0AiCIAiNSgQaQRAEoVF5XekKCELT5cJSqcdgdeG+HLmZPgGEBvii9BBZDMLVRQQaQbigUrYvnMH8Hw+RU+U+Z1a/UhdEgEaNV11jguzG6XJit5kxmx0nyx49kzVv30GsX1NeqkbGbihEb9IRHO6HWIi6kcluHFYTRpMTlX8QOnU9rgzZhd1sorrajBNAqcZXF0CgT+OHARFoBOGCwhn91CuUuV7i0zXHKTK4OT3rLGrco7w86Qa6hGrwqksLxFxCZm4aB3atYcWq3WQVVGCvsmCVZJryxDZZkkhd8X88/uX/sWjbw3S+0hW62jn1ZG77hs+XZ9HpsVnc20NXx4IknMY89q9exqcf/0KG5MTSMpY+dz3JnJviGz3YiAmbwtXpxP+449EPkcY0wITN8kQWvT6LrzedoMzsPhUIfOnz6Dxev+c62gf74lmP3i5z+gpmPvcW6w6NZOGuFxgapqVpTC+VcZorqXR6ExTgK55KmzHZXUHqoaOkpGsZcXsvtNVpbPhuIc+tq2bYgzN5Z2z7Rm1FN90WuiDUg9vtRpZl3G43Un2fpYIHMP2V55gwsAMtfDxPzQO1sO/9Oby7cjc5VfZ6LZap6XAbT08bRYjGRIWxKS28aeLgl08x5c115MuAo5qCggIKC/VY3Fe6blc7GbfTQnVZAQUlZeit9Vuvz1luQ6fU0Of2XmgB/GPoN2w0U/vbsZeWYWqQOl+YCDQNRkZy2bFYrdiczX/hRFlyI0lSk+7GOS/JhqGslOPJaZjNVgoyUkjPyaHc7KjfgH7wUJ5483kmDGxHCx+PU8GmgN/efJn31h0h1+CsV4AI6zOC64MVuG0yTaaPwX6Y1Z9tIy+7kAqrjDvlf9w5/h+MHjWTzforXbmrnYuK7F188eI/uWXqI8zcWlCv0lRh4UR07U3kmVdkVKoAAgI74OOtbvRAUP/W8KnBqupKE46zX/dUotb6E6Rt/A9xxckSTlsVhVkHSCqQaNmuLwPbB5z5s+QwYzAaMNtqEYA81Gj8/PDTqBrue5MdmKsMGC12zn4QVWr88ddqUJ9nRNtt0lNud6PwDiDQ1xtVffqGLqfq/Sz79/v8fDwfE354H1zO3OnLuf75j3hoRDxhPvUoO/g6Hp3lRH5xLst/z6bSKiFTwLp5z6D0mM8TN3UjQudVt5Vv/Hoy9ol8zL5KPC4WaBwGisqd+IUE4atUnHUuCUtlCVWeQbTUqWmIxDXHgV9YXa1l5JQJ9PJVQLfpbPrWjxmLAxkYWv/yG4Qs4bRUUlplq8XBHnh6+RIY6o+6yV/OSkJirmfSc/OJ2bIcZ5eovx4iubCaTZgcnvgF6vgjT0DG7bBiNFuQvAMIOt/4i2ylpDQXh8uL0YN6nmzlNKL6BxqnnsxtX/POq0tJMlZhdHig1vjh37YbA+56jJl39KCuw1fNg4TTXMDBLWvZW+FP9+tvYWCk7zlHVKWt479LPmb5xkKMZjdKnQ6dVs3Zmy3Ksgub2YjJvz/3PPkYD4/rQWBDVdGSwa+LFrN09RbSDGZsCm90Oh1d73iC6XePo1+4+i9v8dKq0CeuZHtFAN36DqZXVAt8lM3gkSFwIA8vGcjDjVS8Mnwkj71YReXri1j3ez4mh4Rclc2Pb71BaOBspgyJJdTXow7BJoDeN91T82Gpn3HT/yXz6A8L+b8YDaoz/5IKNrwxkhdbf8yeR65Bo6rvVp4ukpM2Q/Q93Dm01cmXZJmyA7/hNWwhQfUsvaHIDiO5m+cx+Y2NmCorsbh98A8JwOecS/Vkb4PZ6EITdQ/zVz5Nf98LldiEyEb0+Tby9w5n1P3n+bsxi13ffsH7u1ry/MIZDDxzw3BQdGwNn377IxVDZvPB2LZ/Lhi3sYRiRwDKrg8wIpJGV/9Aowql05jpvNOpG98ueI7397dkyAP/5uXbBxJen6fHZsLtKOPwtq/4LcuT7mPu5trIv17BQV0m8PTcnvRo+R4ffF1Gtycf59HJAwg5uxxTLjtXfMBn6f6ER0c0XJAB0MQx/uV36TvkXeYsWEai92genvEv7uzXFr8LjTp7BNDlhrvQbP2Az9YWUH7deEZ1DcenubRsGpGywwRmvuLE/uRifksuxOyWkKsO8OnzL6KYN5d7B3Ug2KcuwaYJcZ8gaaeBa2ZM4JpTl7Qs28lJOUrHm5tKmAGF2p/2/5jF+r438Omj0/mubDyzV8xiRIuzDpIcmHIS+ea/b7LKtzuxzSHIALLNhEWhx9C3PS1qPrzWJJeJ3GN5GPO9GHZH6wYs+cIa6BH15NOTwsNFm6hejB/59wgySFby965hxa5ynFFjGBhzkbabAiRAERtJWMfIvzwRemrbcO11Y7hj1ECiW/g3QmU9USCjULRg8IChDI2/SJA5w5vogRMYzFG2rPuJxMwqXE1l/OAKU3a4i9fn/YthnVqh8TwVVKoO8Mnbb7F8dw5V9mY4vnUWKX8viaX38MDI6FOvyMj2JNZ/P5kRXa5o1c5PBtnDG3X/fsT9+a7soUIb2ZN/3jWVAZ1jCDhvAU2NjL1aT1V5AZqeXRqsV0h2Oyg/cZAcfR4hY0ZxGRozQEMFGtmMsSKf7IIWBHUZQJfwBim1yXOXHmT7uk1k6APpEdf+ohewVJZHSlkBztCWtA8JPtNtJrvsOBwOnBIovEJp2yqC1sGNMA1OrqYgu4gKVxRt20USWttYpoqm35CeyPt/46cN+8g0iHSj03xjJ/Lv1yczNCb
kj5Ze2mYWzF7A+qN5GJ3NNdhIFO3diTzj9jOtGWSwH1zPmoeG08VhwOC4aAGXl8uOJfcIe1JUBHXrSMTp12UJyWnG7AQUCjS+flzTOfoiBdVAcuMwV1JaXoWtsX8Gsp2K0gry09TER3ljs9qpb4qR7HZgzN1NSnEOju4T6ekHyC6cTjtWR+MmMDVMoLFVo89I45jen6Bu7fl7xBkH+Sf2cfh4HrqwtkSHXaw9LlFamE5RkYXIllG0CfU587o5bQ+Hjhwg0wKekfF0jomlfWM07c05pB8qQx/cjvCoUC6lzRTQoQ89Wyo4tHs3h7LKcDbPu2ejCO39AC/OvIOe7VugOj0Cn7aGV56fz8/JhZicTWUCpoSlsoiiwkIKCwtPJjJITiyGckqKCyks0WM68481k5oVy4Mj/rgpy7g4nPgd3X2hYP9/+Tr9ynyK83E7bZSkJnFCpWFUXLszr8uOavSHl7M2DfDSoY26gZEx9TiPuYgDK1/nrkfeZkfJpb5bRnLZMJSf/P5LysoxOdy4nVYMlaUUlxRTWmHAerrLwF5BeVoqhw4oMRUkk7D9EOa6Vx3JZaXixG/8un0be539iPc4WY/CrCPs25/AthOWepReswaZg+UwVpCbk0KFLpDx0bVvjEkuKxZjJQYr4KnEW+NHQHPJUnMUkHkgkxNp4cSOir54C0GuoCgrh+LKFkRGRhPiByDhthdyOLmMKm1rons3bnXt+ekc0ZcT1DaKdiEtLu079gunTTt/dEv3kHx0JAPjWhLx1/yBs5xM0a3bXGAFHs1sra/Qa55g3qsOHpv5PQczKnBKMqSt4fV3Q9A+9yDXdwrF94qPbVWxbeFdvPVzGQXlbsZ/vIkXuley84dFfLZuD2nGaO6d9SHTBoehVqjpdPu9hJ+diiRXUaoPI2vldO5+YDFbJl+xD/InMg57BccO70WluZUubU/2B8uSC4O+hMQd+bScUo/SJRcOuwNUPngho3DasFfZMTo4mfHmsOFEhY+6pmxDF5V5u1k+71k+32nEK6IHk2bO40bNEdb98CUrE3JRRAzn/iee4ub4AJQuJ1ZXLom5BylZombarKn16j4z5iWy9uvZfLYJYD0/nvlLNP1umMLjrzZu3lkDBBo3hqp8MlJzCAicTGxk7QZnZLeD0hMbWf75bL7fLmHVRtLj1mm8+eAIQupSK8mJ0+XC5emDT30TbmpBrigktTSf5JY6+kT44Xexg616CtL1FHjpaBcIjqIiilxGCg/+yPJUJzFDejKiUWvrpjD3IBUVnrQd2IZQ/0udd+5HUCs1vn5HSM3NpqyqNxFhF/knOUwUlegxW+vSv6IjLKIF2oZM774MQq99joUvy8yYuZxD2VW4JBnn9v/yrELNwpn3MjgqGJ8rGkCDGDNzM4OufYr+D5cSq97H6v2xjHvgQ268ayMv3fIwi594l97b5jJUqyI86tyBDoVHCP+YvYt/XKHaX5DsxGnOIuWICs9ObdFYiyiySDhNhST99h3LT/ThzboOysgOjGUp7Ni0D3eXmxge7YU2KAo/L4lArRNjSRYH9yaQ4tmfu8fGcfE7n5IWbYfy4BtfEfvVu8z51QdvWypHiWfS84u4PnEpb774LUv/E03se1Ppqo2i/x2z2X5HHev+J/5thzN51nAmz2qY8i5VA8yjMVJdnEN6RhCB47rSrlb/VAlrQRKbdx8id9T37HjJwZbVn7AkaRsHikYw6pITIWRcZQc5mJVHWfh4xra59I9xqSyGCkzVZaANQ+Wn5WIP+I6SXFJKCijMzOfnN4+wQQEg4bAa0Q26kyHBQX9anFDGaXcgKzxRquo4L+NsUhFZx8ooJZoubcMI/HPXnNuK1e2Fl5eS82cwawkK9sVX40FCfiH51Qa6hwVduF4lO1n81hdsTcqrQ2WvY8aChxh7TVSj5/Y3tNAhT/HWU0aefmsNR/NNJ4PNtsU8OzeQT+bcRa8Wmiu8jIuZ3xN+RvLSsTmzDZ9O7ngy+AUM5vpR8N2nezieB0PjGu6Mcl0m/SoUKBSK2l33LgfmzCMkpJuxVizhifFLTp3XidVTRch9d/4xZnNWnZw2KwpfzcWX+nGYqUrdwrIvllPikUjajClcrw6hm5cNZfUuVn61gI9+ddLhOi3X3xBHuxofcN1UG/JJTt5EhaM/hfa2TBwaiRKJoKhu9Or2CatKUsgtg66XJxnssqn3dS9bjJTnZZOp9KdvfHtqN4/LQF7yAXI2FHHtnDbgD8Mmv82wujbHJTOZx/exe1827W4fX8dCLo3JWICxOgeUvfFUnjsn5k+Vo6zwOKUlDobeM4tHp99IrAZw6Dm0fhGJ1WGEB57b7yY7Kzi8JwVJG0p8z5ganpRqJlcVkJ1VgTm8L60jw/40PuPCmvobm81tiO7QnS4XyKv20YWjVPvhqjJgttpwwYV/pK3HMGvRmHrWumZ6vR6Ho36j0n5+fvj6+qJQNERrQ0n0jTN502ri8fkbSCu2IXn54uusxmqVTs74v5KNGmcyieslVH0e4sV7uuF9poVVTWUR4NGGsAZMeJRlF1X5mZTXZi7lGQo8vHwJiYrArxY9E26nhcLjeygPjmX60rVM7QjgxqpP5bdv5pLQts2516ksYSnPZ++GLUTePZkOFytcHUibwY/w369HkLjuO775fg4HIzvjKxXz/WcbqfAdxasfj2dI51b41qYXRTZjKMni+LEIYsdOZNKotqfq5sRlN2Gs9EalCiNAU4uympl6Bxq7UU9+5nHsflH0bP/nZ4e/kt0ObGY9ZZUGzA4rVWVFlPl5o/UPxMfTjcNiwmCw4MQDpVqDX4D25KQ02Y3dYsRosJ5Z4tpH64e/0o0lbz/bt27jp/RQplxXRoVWjVLpwmSyo/T2Revnh5fLjMloxCqp8NHq8PdVITstmKxWJA81XpIdo8MTjU6Hn9oDl82EyWjC6gJPtS9anR++yj/uEgqFJwpFLTp3pFJy04vIc0XSJS6KsNMXkaoF7WNHItm90J3+cctOrAYjpYdW8XuWTFDHKGLNDtT17EYyF6ZzvEJPeGwEbYP/aHK6bVVUm9JZvzgbqasvHTuasbt9UZ9nPEGt0uHlqQaLDZfTWe8MmIbw+qsz2Z+UVK8yHn/qKUaPGY1G01C/biXtbnqS+/fvY87aSqSuE5nz2oMMiNChvMLDNK7jCWzV+zDyjVvp6PnHPB/ZkcnB38GzdXdiWjVUJWXcjmK2fDiFBdsu5X1KfAOG8czSmYwIrulYN3ZbKUcPpKDWPUiPM/MSPfHRhjP42uFU6U4/+p6ctGnIzyTlyCreqB7C16XV2IL98b7Yj0vhhTqoIwPGP06HPofZsHwBX1h13DpiDs/1jCRQo8G7ll31srma8rRj5Gjb8M+xQ/9ILZasGMqLyMzUoRvajsjGmqYkn31/BbxU+Oj8CfBu8tsEODBU5pKeWkFA4P8Re544I7tduF1OJKUPKg9w6lPZ9M07zP7iIBaLE+X0W9nYfziTn3mNYdpMdq9axpIlG0h3aYnuNZb7XniIoVFq5Ko0Er79nMWfbafI04UlIp4hk5
/m1Z4WNn73EZ+sPILVDu8+UsmRYf2Jikrlv4s3ET5gHP96+RXiin7ki4XzWFkexw33PcULt3SiYvfnfPbLLqythtHOkMCnR3VMePg5nu/nzfHty1j25XJ2pTvw6zaCCQ88zp29W565qNS+LVBrwsAiIUsnV/Q9309UriokI6OAQp9IhkeGnpMC7fIKwl/lScDpJos5na2fLGbxF1vJ91ThrTlA0bT7mXR3L2r8zV2QndysI5TrA4hu3Y7QwD9+VSVJH7Fg0VI2HvNCtduXTZWPM/2eW+hznpUCZFkCZPD0wKPGp//Lkwzw3gfv1/Ec5/LwaLjRILetmpSf3+X9DSbcXW7n9denMSDCD9WVzgVAIvPIZorUE3jlGv+zlqiRqN6xmlWVvnScMZYuf6qnw2gAXV32nVHgpY5k/NwEGq2PQXJg1x9m33YVPpO60P6spossK1Eo2hDb6vQrDqoLdvDhhOf4vtKOssUWHtp+A48vfoahF7qxyxIuWxXFWQdJ2LiWXw5k4N1zMNeE5VGy5XkeXtOLcWNuYVivGMJbaGp8kLCaK8jOzsUj5J9c2/n0j17GVVVC1oHf2e8fyY2j+3H2YjOS24XL4cDDp76rZ0s4TbkkrfqKjxevJ0NyYg3vTL+JTzLnpi6NHmzqV7rLQEV2GsfTdQRMjaL1n0uT3ViK00hPz0DqcRM9/UEV2pWbHltITMevWfm/o3Sa8yHj2oJkymDv6k3syRjAG4kz8cv4hW8/+ICv54fT5t0R2Lf+wqYtJsZ+tpEpUbms+fI9PvvxC3b1nMeEGbNpFfQhP+z2YOQrb50co7Fk0tlDxer9AGqi+k1k2vMy6lU/U4Gd/IMrWfHxZ6wt8SPkGhdx7Xpxf6syZC8LKQk/kZArcdP8BOZ67eerTxey+ptPadnyBUa3ORlpfHyD8NMG43GiCou+EjOtzzOmIGMsziKnOJ/QsD7EhZ47wBrQodO5c2+0cYx58D4M5jiMkb24+f7+56wecLrMk/dWBbXq7XEXknlYT5m6Df3ahnB2HkD4tc8xzVxJq/JYrhk8kYGtLzzSZLNV4XTZ8GoVhJ9Wc9ExqcuVDKA41ZffVEgOI2nrXmXGm1up9h/IE8/PYHi0P/XZq6rByJWkHTqB500P0Mn7j/EPtzWTld/+hFf7+3n+1thzH5YcBrZ/8RmtZjxBk5yj6bRjTE1ih9qbQd3OnYip8Nbg32M4/c68oiYwajAPfT0X5QO76PXDa+euHnA+TgNFSd/ywotfUhJ8HZNfeo+BFVv5zyYdtz36EMXLZjP31af5acD9zHnrTjpddMDHidVYQn5lFX6DY2hz6lhZslKUe5idv2fToec07hx0Vtau5KIyO5Xk5BNE/XMc51ntrNZkdxUZJ/IoUY9lUeLLaKtP8Ot37/P80nd4XTOTd8acvU2AjOR2I0ng4eXVIOvm1T3QyBKOymIyjh8kJTCA8V1iz0q/k5HdLuxVmRxITiKpqgsPXrTvV6aq8BhF9n0Ej7sFZVExVs8Q2rSP4fCeXSTnjiTOR0vLiDAiAlRYPFTotJH0wFjHynsT3X8S054D1QcrKPaO5cb7RhAOyCW7+XynityQgdwsF1PsDKRdTAxBlaUcyC7khjatUQCqwDAiQsJpaTdhNzuwwV8Cjew2kZNxnJw8HyJGdCI8pOY2tttQSl4EqNsF/TVwyU6sVXr0VVYUmkCCAv0uvv6Y7MaRl0JyTiFS9CjaR4ScW6ZURVGBGx+fAPy1FwsdMmZDEQ6riZiQYIK1NUz0KdnJh3O/YGtS7sWPO69hPPpe80sGkJ1mcrYv4rW3tlOtG8Dj82dxa3yL2vXdXw6W/ST8KBPzUgeCFApAxmWt5uiK2Xy0vzv3fT6dAZrTdxQnxiI9puQPeMR4HduLKrGGBl6WbM7ak3FYqjiStBtPZTTXxtZiWoXTgb00n0L/toyrzZouSl8COw5l6rNt8e41gmv89RxZl8tBp8wNdOaGBz+g47W/c9QVf05r6rzcRqrzU8k44Eu7sZHoANntpDr/CDvW/MgBj2FMfeRm4tQAEi67FVPuUXYc2skPlsEsqDBhD9DW+aHl9DYBve/ofWqbgI70HzaaqeX/43hJGSban8mclRzVFGWnk12upGWnONoFqeo9tFiHQCPjdtow63M4nriOlT/uxEs7AJ3bTlFR0clD3A6sZen8/ttKNpxwc91jt9YwoO3AbDJxaGMCaz4ez5KzLujg7tdhNSvpOHISkQOqqcjbzopln/Ph8sO07DWavpf+Ac4VqUPTMejMl2wzG7DsXcPaDUvYrj6rIlH96WQxYgV8AXRhtO8YRSy7qM4so9IBZyb0y06sRhMV6dv5bf0e9hu0DPUHh9GI1VOLzwX3/5UpyjmC1seL9tGRf/3ODMdZtehdPlq2E2efu3js8YcY3zviPN0yEg6LAWNZNolrVpCYfAx1vxG4qywYzN5/tBaqU0k+1h7/AR0Jv+jiatWU5NswmsPoGNmGYL8a0hNaj+GN/zR+MkBTITvNlO77iJde+p5U3y5Mmf0y47o0oSADOI8lsk4B3U3FZOQHEOBlJWf7B7z8oYNx785nWlfvs24mqfxv/L18UmREF7aV238eyPOr5jG2SazYLCO77VQXFpKT/ANfrqpAoRqGj6uEUnMQoZoL3/FddjOFhQfJuWsGtZq3qVChDe3CdTeebM+5jS6sphLcngpMZvAKCSKm79halSVbDJRmH2Wv7GR0VRFZ+TKK6nQS1n7H6tT23PX8NG6OPl13C4UH1vDhI2/zi8mFKnAzT+8dz7RZU+lbx1RtVVg4EWFnT6WXUakD8P/LNgES+rRNfLPgRT4+0ZWb/vUcs+7qU++HvjoEmpP7JHz77st8dwCgJdizWPbMeJb95dhQ4gbeSZ/2NeVNKfBQeBDZdxjPvrGQ2zucrpaE2y3hdkkYMnfx3eeL+FEfxz0z3uTT2BWs/bno0qtfU00UHig6Dee262/hhZHRZ7LJJJcLSZbPyi4LokP3PnS/JoGtFUfJLh5FzKluNSzZ7PryMz5bto08wBvYveRVCjMnct/9UxnZ7gLfh2ygrDiYQHUYUS3Pc4zNTLXLC/wDcew/xtE9aQzoFkHUXzrQjZzYsJBPlv7C/gKAlpC4nNmJ2Ux87gHuHBdPIOCuLCK7T0t6xYVffBFPdzGFR41YNYPoFh9Ny6t7Oe5L47ahT/ov/37ta1Kcbbht9mwm9wlHd6VH/s/hJuXgVmytopD3zeXBpSbkwCC6DLmTOWvuok/wnxcAjWfazv9Q0eEzuq75mFuaRIA5zYW5fA+fTHiOnwC8/fHmN96ZdpDwB99h+YTYC7xPxmk3YixM5x+j67Y8gIfSh+B23Rk8SEH4RSfO/fXcZqOerJxjOFXeHF/xHJPec6KJiuHa0ROZ8/BwYs4pT0ub/mN5YL6alp/m0+/DRxnY0Au0yVZKSnJxujy44ZxtAjxQa8IIjYwj2h5JeEQw3g1wujoEGiUhMSOYsXgEM+p6VvnkALpLlpEkCVCh9Q9EUpnYeegYN7bpikalQHZXkJdfTEmWmeKD6
9mgH8HLbz7MNQFlHEh2Yj9f2ZILl0tCQnEyaMgybkkGyYXdZsNsdmA1GTFYnef98N5+QQSW2EjLOUT6oAg6+ihRyC6Ks9IwOiVad+7C6Y4jnzY9GDToGo5szudASiaDImNOLk+uiWHEjLmMuMQvSK7O4MQxC8ZWOjReDhwuJaqzWz9hA3jwhV5MfuIEvy7cSIWHzPnTv/yJv+VV3r/l1YuczU328b20cHYkROWJ0+nEU6k877iIKy+No2Y7rUcOoGeH1lyF2Zd147ZRfWI1b876iiMlkYyf+w6PDIrE/8qP/J9LziN5TyGqMYv5z1ND0XrW3P8i5R/nhHYbT7cAAAvKSURBVDaW8U0qyAAo0YYN4dnfE3n2Ut4m2zDrM0n4uT8DJjoxW0Dje2kTlxXeLWg35CFeH3JJbwPsmKvzyC7REn/f2yy5o2uNW3XLNhPGygoq/SOJaPBVQE9vE+CPsttDjPxTr6Nf9CAmzRzEpAY842UfppTdDqzlBeTlZXPCaCQzPQt9hQmfVvH0C49F8cOHLNmaTnFRMQUpGVQU5KNqF4BSZcRhPUx6ZjHFmSkczUlhn92KsboKi0uBp9ILhdtKVU4yx1OPkuLyJzA4HF+7ieyMDIozfufXTbv4ae1Rtq5cxqc/76Ok2oDZYsdiNmK0Ok9OLGsRx4DBIWiMP7F0xXayioopzj1IWnU1ZV5tOWd0wjOEroPHckOElsqdiRwqttdrF0dLURapLR1UhelJPnCUI/l/nYCg8AJnWSEl7cIJ6hNDeJ0fNwpI3e6HsjgQ46EdHM1Ip/p8h7lNHDuwg+LWPRg+qi8xoWLneADcdgzpPzLrmflsLQxj5Euv8fCQaAKaxMj/ueTKZBK3ezKoUye8a5OSD5QmbeDI9FFNMgmgThwWDAWH+K1XNNrKJFb8kn35zu00U52XTV6aHz2jo2oMMgAOQzn5jhzK/tGrXkkA5yO5TOQez8eY58Wwoc1qm4Dac+pT2bRsPq99kUCWPoMfX32c5+es4qgjnO43TuHeW9ty6JVJ3DZuHNMXriJFN4iebaLpPnAEA5UH+Wj6OG57/SM2VAQSm5HE8sXzWZOnI7JrD4LZy4eff8W6qjC6+QUR038AcZ0q+OaFe7jtzZ/JCO3NQw+M4M6HHmRC6zRWLVnGtm0FJK1bwZebUjm5rJw3bQdOZOq4a/Hc8BKTxo3jtme+Ym+Zjl4d/zoI7tmiJzffN56+UXoSN+4g12iv81L6Ci8f/NJ3snr57+TY/egc/afuM8lOVd7vrN1bRqt2vbm+R2StLtrz88TbX8+eFXP5rVTGOyzur91nkgND5lZ+TQmj76BbGd7pPONGf0eyE3Pmeua+NJ+thV70f/oVnhgdf/6dDC+5aDMGqwN3Q+3nLDupTtrFb1JX+ncOrmUGkZOs40lMHBDfMHVoEjzwcrtwbf2Gd17JYOD4eqyueUlkXMYK8lIzOKHpTNfo2vS5SVhNFXgYixkeV4/Vps9XG7cDfdohcstzCBlzw2XbJgBZaBCuqhPyrjX/k79Zt0vONkuNc47yLDk7L1cuNrsapfyzSU6rbClOlL/78gt59a4Tst7e6KdsHiSHbCrYIi+Yeo3ct2c/+V9f7JFLTc6GKlw2JL4vP/tNglxgccjnvYqOfyKPGf60/H26WXacc4BeXv/acHnYp3tks8Mty7Isu60VcnHOdvn9ezrKsW2nyV8dzZPLzM7zl3s2xy75xdiH5R9LHLLB6GiQT/Z3JLmdsrU6Tz685RP56bG95RH/mCf/mlMuG2zui7/RXSmf2PK9/PakhfKOcodssVzg+jJkyjuWzpWnPPZfeW/V2X+wy4VHf5LfmTNNfmlzzh/1cdnl6qwd8vZN/5N/zT3zouxw2GSLvXHuWacpZLmhHp0EcFBtsiPLagJ0jbCnzGVky93J1hxo2b4bXcL96tFyuorILixle/nkxadYts9F+0mzefu+IbT2VzfIyjKys4RVr97Igsj5rJ5yLS1869dCqtj8bybM3IHd8cfmKT1f+o4FY1vjdbGmzYE5dFgUzI9vxLPxFz8ev/dqatlcPq6qLBJXvMm/lxw585ouOp4bn5zL9H4Xya82ZLJ980oWpYQz7cYOVBaFcsuo+rVsJJeNqvStbN6TRE7L25nY9VRGj72Y7JJKqnz7Mrpr400oEIFGaMYk7GYjFpsNpwtUukB0PioUTjNGswW7rEan0+Ct9EQhu7CZDZhsEiqNP1pv5SVORHNjLz/I0rlPs2RrJa3ueIsFD1xHVEBDBBkZl92EPmER9z/+DTGv/8RrY6LQXamkgqMfcO39y1AwkY9/n3H1jNM0F6Ycdq/9mCffSyF+wBReWPDPeo/TVGdtYtWSF/l005//Ek3/G+7jydcad7dNEWiEZsxA4tJ3+PaX9Ww/6GD4y5/z6KjOKDN/4JNPPmZ9WR9mvPoUt8SHojRlsuWrt1m0Kp/4f73FY2O70UpT2yFKN/bqE6xZMIN3VuSi+8eLzHpkDHFBvhdZTLUWZDdOpwOn1UjGziW8NWsVx40deHr1Uu7pGlDrNbQEoakTKURCM+bHgHteo3NHf6bOqiYu3Bv9kS242oxi4oQUDs7eSHL63QxpUUFOnppet97LyF9fYeXRNO4e2pFWmtpsZerGYc5g7XuPMn9FJuWmKDq6jrP+2xw21rfBYS0hNXk/R/aXYnGcWk5f3ZXwUC8acPk1QbjiRKARmjkbJ5J/x957EIH6Qxi7XU/3iCCK9qbhIJ6oIDtpud507xJJoPE4WU4XUcFBaFS1GXWScVkz2Lz4Nd5dmUuZSQaySfg+m4TG+jiDY4jSenHBxSMEoRkSgUZo3txZHE2wEtGqgkKv6xncIoAArzx2HjFij+uBymSjbd94/HQ+VBzYT6ZdR9+2kWi9axNonGSs+ZCP1mfg8NQQGNT4U1UVXTvQSunVrHYXFYSaiEAjNGtSdjJJxZChCOC2ri3xC1Ahl6Sy/6AJbXsL2tadCVHrUMnVpB4+gDmsLz1jg9DVaqKrC138VGa+fTfumg9uEKrwOPzVYnBGuLqIQCM0YzIluckU2iV6jBhKXEgQ3oAh+wiHLBZKNJ2IaaVBqwKsKRzaZEbuEU9bf00t12/yJbxzV8JrPlAQhIsQLXSh+ZKNZB0/ijW0NwO7tiJI4wG4yc9IwuqM4LbBfWgTqEUBuLOOsMMq0zOuHQE+Yn0DQbicRKARmi97Kke3mNF06kGHFgEnl8dxZ3JkbzW2zkPpExuM/6n9PbJS92Bw9aJbOy2SS8LlurJVF4S/ExFohGbLnZ3MLqsv8Z1jCDm1R45UlMLeNDsde/YlKsj/VN9wGTkphdi6hxNoSifPaMIiZo8JwmUjAo3QbBUVZOIZ0p1uHVugO9UbZirJocq3Oz1iI/A/s/OYhCQHEu4upzKwJ11a+eMn1tQRhMtGrAwgCLUk2Q1UGiW8/fzwVf15s7BLLg2bsQKj5Eugzufia48JQjMnWjSCUEuGpI+Zfu/rfH+4DFN9xngkG4bydFa9OoZR//6KVIP1/PvXCcJVQqQ3C0ItBQx8hm/WNEBB
FTt4f8o8NuabMPcRLRnh6idaNIJQE8mOUV9OWWkpRpsbqb6dzcEjeWntBlbOu5Fg/+a9nYQg1IYINIJQk8oEPn74bm4eMJx5m7OodFzpCglC8yK6zgShJi2G8fTXQXjfmkDPjoFo1SdflhxmjDYHUq2aOJ54a3V4e9U3iUAQmh8RaAShFqTURDbeEs/oIA2n4gyGw18xe+VeSo21aeJEcvOMZxnTPgAfsZSZ8DcjAo0g1EiirKCAnqGDz9leIKDPQ8zr89AVrJcgNA8i0AhCTaQy9u71oPvNgWjP2vby0rrOvPDRalGLrjPhb0gEGkGogVSyhwQpnrsCleB2I3l64qEAw+GlJ7vODLXpOmvPbU88wci2/nifnYLjOJnFJmZNC1czEWgEoQYVxxJI0vTm5uLt/HCiD//sG0mQ1pOAPv9iXp9/1aFEF5aKaiqrbbi3pZFWWEKkbxsC1J6itSNclUR6syDUoLI6B92vb/P4F5V07xyEv7a+o/npfDd5Ig+9tQcP9SbmTbmDT/ZUYHE2SHUFockRa50JgiAIjUq0aARBEIRGJQKNIAiC0KhEoBEEQRAalQg0giAIQqMSgUYQBEFoVCLQCIIgCI1KBBpBEAShUYlAIwiCIDQqEWgEQRCERiUCjSAIgtCoRKARBEEQGpUINIIgCEKjEoFGEARBaFQi0AiCIAiNSgQaQRAEoVGJHTYbnITTYsHiApWPLz7KJhLLZRc2swUnStQ+3qg8xV6OgiBcHk3kLng1KefgN+/z9NPzWXq47EpX5g/GZNbMfZpXFqwgqch+pWsjCMLfyP8DJPZxUCjVta8AAAAASUVORK5CYII="
275 | }
276 | },
277 | "cell_type": "markdown",
278 | "metadata": {},
279 | "source": [
280 | ""
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": 8,
286 | "metadata": {},
287 | "outputs": [],
288 | "source": [
289 | "class IterDataset(data.IterableDataset):\n",
290 | " def __init__(self, data_generator):\n",
291 | " super(IterDataset, self).__init__()\n",
292 | " self.data_generator = data_generator\n",
293 | "\n",
294 | " def __iter__(self):\n",
295 | " return iter(self.data_generator())\n",
296 | "generator = Datagen(shape = (H,W),txt_path = './trainlist.txt')\n",
297 | "iter_dataset = IterDataset(generator)\n",
298 | "data_loader = data.DataLoader(iter_dataset, batch_size=batch_size)\n"
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": 9,
304 | "metadata": {},
305 | "outputs": [],
306 | "source": [
307 | "from torch.utils.tensorboard import SummaryWriter\n",
308 | "# default `log_dir` is \"runs\" - we'll be more specific here\n",
309 | "writer = SummaryWriter('runs/vgg_16x16/')"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": 16,
315 | "metadata": {},
316 | "outputs": [
317 | {
318 | "name": "stdout",
319 | "output_type": "stream",
320 | "text": [
321 | "Epoch: 1, Batch:21, loss:5.301943264224312 pixel_loss:0.03195870295166969, feature_loss:7.603750228881836 ,temp:0.19699248671531677, shape:0.25114700198173523593"
322 | ]
323 | },
324 | {
325 | "ename": "KeyboardInterrupt",
326 | "evalue": "",
327 | "output_type": "error",
328 | "traceback": [
329 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
330 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
331 | "\u001b[1;32mc:\\Users\\VINY\\Desktop\\Stabnet multigrid2-DeepStab Modded\\Stabnet_vgg19_16x16.ipynb Cell 12\u001b[0m line \u001b[0;36m8\n\u001b[0;32m 6\u001b[0m \u001b[39m# Generate the data for each iteration\u001b[39;00m\n\u001b[0;32m 7\u001b[0m running_loss \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m----> 8\u001b[0m \u001b[39mfor\u001b[39;00m idx,batch \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(data_loader):\n\u001b[0;32m 9\u001b[0m start \u001b[39m=\u001b[39m time()\n\u001b[0;32m 10\u001b[0m St, St_1, It, Igt, It_1,flow,features \u001b[39m=\u001b[39m batch\n",
332 | "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:633\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 630\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 631\u001b[0m \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m 632\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset() \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[1;32m--> 633\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[0;32m 634\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[0;32m 635\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[0;32m 636\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[0;32m 637\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n",
333 | "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:677\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 675\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[0;32m 676\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m--> 677\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m 678\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[0;32m 679\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n",
334 | "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:32\u001b[0m, in \u001b[0;36m_IterableDatasetFetcher.fetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m 30\u001b[0m \u001b[39mfor\u001b[39;00m _ \u001b[39min\u001b[39;00m possibly_batched_index:\n\u001b[0;32m 31\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m---> 32\u001b[0m data\u001b[39m.\u001b[39mappend(\u001b[39mnext\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset_iter))\n\u001b[0;32m 33\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mStopIteration\u001b[39;00m:\n\u001b[0;32m 34\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mended \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
335 | "File \u001b[1;32mc:\\Users\\VINY\\Desktop\\Stabnet multigrid2-DeepStab Modded\\datagen.py:34\u001b[0m, in \u001b[0;36mDatagen.__call__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 32\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m5\u001b[39m,\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m,\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m):\n\u001b[0;32m 33\u001b[0m pos \u001b[39m=\u001b[39m \u001b[39m2\u001b[39m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39m i \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m---> 34\u001b[0m s_cap\u001b[39m.\u001b[39;49mset(cv2\u001b[39m.\u001b[39;49mCAP_PROP_POS_FRAMES, idx \u001b[39m-\u001b[39;49m pos)\n\u001b[0;32m 35\u001b[0m _,temp1 \u001b[39m=\u001b[39m s_cap\u001b[39m.\u001b[39mread()\n\u001b[0;32m 36\u001b[0m temp1 \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpreprocess(temp1) \u001b[39m# -33\u001b[39;00m\n",
336 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
337 | ]
338 | }
339 | ],
340 | "source": [
341 | "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
342 | "# Training loop\n",
343 | "for epoch in range(starting_epoch,EPOCHS): \n",
344 | " for param_group in optimizer.param_groups:\n",
345 | " param_group['lr'] *= 0.5 \n",
346 | " # Generate the data for each iteration\n",
347 | " running_loss = 0\n",
348 | " for idx,batch in enumerate(data_loader):\n",
349 | " start = time()\n",
350 | " St, St_1, It, Igt, It_1,flow,features = batch\n",
351 | " # Move the data to GPU if available\n",
352 | " St = St.to(device)\n",
353 | " St_1 = St_1.to(device)\n",
354 | " It = It.to(device)\n",
355 | " Igt = Igt.to(device)\n",
356 | " It_1 = It_1.to(device)\n",
357 | " flow = flow.to(device)\n",
358 | " features = features.to(device)\n",
359 | " # Forward pass through the Siamese Network\n",
360 | " \n",
361 | " transform0 = stabnet(torch.cat([It, St], dim=1))\n",
362 | " transform1 = stabnet(torch.cat([It_1, St_1], dim=1))\n",
363 | "\n",
364 | " warped0,warp_field = get_warp(transform0,It)\n",
365 | " warped1,_ = get_warp(transform1,It_1)\n",
366 | " # Compute the losses\n",
367 | " #stability loss\n",
368 | " pixel_loss = F.mse_loss(warped0, Igt)\n",
369 | " feat_loss = feature_loss(features,warp_field)\n",
370 | " stability_loss = 50 * pixel_loss + feat_loss\n",
371 | " #shape_loss\n",
372 | " sh_loss = shape_loss(transform0)\n",
373 | " #temporal loss\n",
374 | " warped2 = dense_warp(warped1, flow)\n",
375 | " temp_loss = 10 * F.mse_loss(warped0,warped2)\n",
376 | " # Perform backpropagation and update the model parameters\n",
377 | " optimizer.zero_grad()\n",
378 | " total_loss = stability_loss + sh_loss + temp_loss\n",
379 | " total_loss.backward()\n",
380 | " optimizer.step()\n",
381 | "\n",
382 | " \n",
383 | " img = warped0[0,...].detach().cpu().permute(1,2,0).numpy()\n",
384 | " img = (img* 255).astype(np.uint8)\n",
385 | " img1 = Igt[0,...].cpu().permute(1,2,0).numpy()\n",
386 | " img1 = (img1* 255).astype(np.uint8)\n",
387 | " #draw features\n",
388 | " stable_features = ((features[:, :, 0, :] + 1) / 2) * torch.tensor([255, 255], dtype=torch.float).to(device)\n",
389 | " unstable_features = ((features[:, :, 1, :] + 1) / 2) * torch.tensor([255, 255], dtype=torch.float).to(device)\n",
390 | " # Clip the features to the range [0, 255]\n",
391 | " stable_features = torch.clamp(stable_features, min=0, max=255).cpu().numpy()\n",
392 | " unstable_features = torch.clamp(unstable_features, min=0, max=255).cpu().numpy()\n",
393 | " img = draw_features(img,unstable_features[0,...],color = (0,255,0))\n",
394 | " img1 = draw_features(img1,stable_features[0,...],color = (0,0,255))\n",
395 | " conc = cv2.hconcat([img,img1])\n",
396 | " cv2.imshow('window',conc)\n",
397 | " if cv2.waitKey(1) & 0xFF == ord('9'):\n",
398 | " break\n",
399 | " \n",
400 | " running_loss += total_loss.item()\n",
401 | " print(f\"\\rEpoch: {epoch}, Batch:{idx}, loss:{running_loss / (idx % 100 + 1)}\\\n",
402 | " pixel_loss:{pixel_loss.item()}, feature_loss:{feat_loss.item()} ,temp:{temp_loss.item()}, shape:{sh_loss.item()}\",end = \"\")\n",
403 | " if idx % 100 == 99:\n",
404 | " writer.add_scalar('training_loss',\n",
405 | " running_loss / 100, \n",
406 | " epoch * 41328 // batch_size + idx)\n",
407 | " running_loss = 0.0\n",
408 | " # Get current date and time as a formatted string\n",
409 | " current_datetime = datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n",
410 | "\n",
411 | " # Append the formatted date-time string to the model filename\n",
412 | " model_path = os.path.join(ckpt_dir, f'stabnet_{current_datetime}.pth')\n",
413 | " \n",
414 | " torch.save({'model': stabnet.state_dict(),\n",
415 | " 'optimizer' : optimizer.state_dict(),\n",
416 | " 'epoch' : epoch}\n",
417 | " ,model_path)\n",
418 | " del St, St_1, It, Igt, It_1,flow,features\n",
419 | "cv2.destroyAllWindows()"
420 | ]
421 | },
422 | {
423 | "cell_type": "code",
424 | "execution_count": null,
425 | "metadata": {},
426 | "outputs": [],
427 | "source": []
428 | }
429 | ],
430 | "metadata": {
431 | "kernelspec": {
432 | "display_name": "DUTCode",
433 | "language": "python",
434 | "name": "python3"
435 | },
436 | "language_info": {
437 | "codemirror_mode": {
438 | "name": "ipython",
439 | "version": 3
440 | },
441 | "file_extension": ".py",
442 | "mimetype": "text/x-python",
443 | "name": "python",
444 | "nbconvert_exporter": "python",
445 | "pygments_lexer": "ipython3",
446 | "version": "3.9.16"
447 | },
448 | "orig_nbformat": 4
449 | },
450 | "nbformat": 4,
451 | "nbformat_minor": 2
452 | }
453 |
--------------------------------------------------------------------------------