├── README.md
├── config.py
├── dataset.py
├── infer
│   ├── inference.py
│   └── iterator.py
├── logs
│   └── beef
│       ├── config.txt
│       └── model.pth
├── model
│   ├── _cdht
│   │   ├── .gitignore
│   │   ├── deep_hough_cuda.cpp
│   │   ├── deep_hough_cuda_kernel.cu
│   │   ├── dht_func.py
│   │   └── setup.py
│   ├── dht.py
│   ├── encoder.py
│   ├── fic.py
│   ├── magnet_epoch12_loss7.28e-02.pth
│   └── vibnet.py
├── predict.py
├── requirements.txt
├── test.py
├── train.py
├── utils.py
└── video.jpg

/README.md:
--------------------------------------------------------------------------------
  1 | # VibNet: Vibration-Boosted Needle Detection in Ultrasound Images
  2 | 
  3 | This repository contains the official code of the IEEE TMI paper "[VibNet: Vibration-Boosted Needle Detection in Ultrasound Images](https://arxiv.org/abs/2403.14523)".
  4 | 
  5 | ## Abstract
  6 | 
  7 | Precise percutaneous needle detection is crucial for ultrasound (US)-guided interventions. However, inherent limitations such as speckles, needle-like artifacts, and low resolution make it challenging to robustly detect needles, especially when their visibility is reduced or imperceptible. To address this challenge, we propose VibNet, a learning-based framework designed to enhance the robustness and accuracy of needle detection in US images by leveraging periodic vibration applied externally to the needle shafts. VibNet integrates neural Short-Time Fourier Transform and Hough Transform modules to achieve successive sub-goals, including motion feature extraction in the spatiotemporal space, frequency feature aggregation, and needle detection in the Hough space. Due to the periodic subtle vibration, the features are more robust in the frequency domain than in the image intensity domain, making VibNet more effective than traditional intensity-based methods. To demonstrate the effectiveness of VibNet, we conducted experiments on distinct *ex vivo* porcine and bovine tissue samples. The results obtained on porcine samples demonstrate that VibNet effectively detects needles even when their visibility is severely reduced, with a tip error of $1.61\pm1.56mm$ compared to $8.15\pm9.98mm$ for UNet and $6.63\pm7.58mm$ for WNet, and a needle direction error of $1.64\pm1.86^{\circ}$ compared to $9.29\pm15.30^{\circ}$ for UNet and $8.54\pm17.92^{\circ}$ for WNet.
  8 | 
  9 | [![Watch the video](./video.jpg)](https://www.youtube.com/watch?v=lXzHw0crPaM)
 10 | 
 11 | ## Installation
 12 | 
 13 | ```bash
 14 | # Create conda environment
 15 | conda create -n vibnet python=3.11
 16 | 
 17 | # Install torch (be careful of the CUDA version)
 18 | pip install torch --index-url https://download.pytorch.org/whl/cu118
 19 | 
 20 | # Install packages
 21 | pip install -r requirements.txt
 22 | 
 23 | # Install deep-hough-transform
 24 | cd model/_cdht
 25 | python setup.py install --user
 26 | # if errors occur, try this:
 27 | # pip install . --user
 28 | ```
 29 | 
 30 | ## Usage
 31 | 
 32 | - Prepare dataset
 33 |    1. Download `dataset.zip` from the [release](https://github.com/marslicy/VibNet/releases).
 34 |    2. Unzip the dataset using `unzip dataset.zip`.
 35 | - Training
 36 |    1. Set model parameters and training parameters in `config.py`.
 37 |    2. Generate the data splits in `{data_path}/train.txt` and `{data_path}/val.txt`.
 38 |    3. Run `python train.py`.
 39 | - Evaluation
 40 |    1. Set parameters at the end of `test.py`.
 41 |    2. Run `python test.py`.
 42 |    3. After the test, the results will be saved in `{save_path}`, and the errors will be printed to the console (theta in degrees, rho and tip in mm).
 43 | - Visualization
 44 |    1. 
Set parameters at the end of `predict.py`.
 45 |    2. Run `python predict.py`.
 46 |    3. Output videos will be saved in `{output_dir}`.
 47 | 
 48 | ## Code Structure
 49 | 
 50 | - `dataset/` contains data samples from our test dataset. The provided data splits are identical for training, validation, and testing; you may need to modify them.
 51 | - `config.py` will be loaded automatically when training the model.
 52 | - `train.py` can be executed directly to train the model after you set `config.py`.
 53 | - `predict.py` runs prediction and visualizes the results using a trained model.
 54 | - `utils.py` contains utility functions shared by multiple scripts.
 55 | - `/logs` contains the training logs, which can be read using `tensorboardX`. Each training log contains the trained model and the config file used to train it.
 56 | - `/model` contains the model code.
 57 | - `/infer` contains the inference code.
 58 | - Most Python scripts can be run using `PYTHONPATH='.' python '{file path}'`. The last part of each script shows an example of how to use the classes/functions defined in the file.
 59 | 
 60 | ## Acknowledgement
 61 | 
 62 | Many thanks for the open-source code from:
 63 | 
 64 | - [Learning-based Video Motion Magnification](https://github.com/ZhengPeng7/motion_magnification_learning-based), where we used the code for the motion encoder and the pretrained encoder model.
 65 | - [UniTS: Short-Time Fourier Inspired Neural Networks for Sensory Time Series Classification](https://github.com/Shuheng-Li/UniTS-Sensory-Time-Series-Classification), where we used the code for the FIC module.
 66 | - [Deep Hough Transform for Semantic Line Detection](https://github.com/Hanqer/deep-hough-transform), where we used the code for the deep Hough transform.
 67 | 
 68 | ## Citation
 69 | 
 70 | If you find our paper and code useful, please cite us:
 71 | 
 72 | ```
 73 | @article{huang2025vibnet,
 74 |   author = {Huang, Dianye and Li, Chenyang and Karlas, Angelos and Chu, Xiangyu and Au, K. W. 
Samuel and Navab, Nassir and Jiang, Zhongliang},
 75 |   title = {VibNet: Vibration-Boosted Needle Detection in Ultrasound Images},
 76 |   journal = {IEEE Transactions on Medical Imaging},
 77 |   year = {2025},
 78 | }
 79 | ```
 80 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
 1 | config_list = [
 2 |     {
 3 |         "expriment_name": "beef",
 4 |         "data": {
 5 |             "data_path": "dataset/Beef",
 6 |             "size": (657 // 2, 671 // 2),
 7 |         },
 8 |         "train": {
 9 |             "device": "cuda:0",
10 |             "val_every_n": 300,
11 |             "print_every_n": 25,
12 |             "lr": 1e-4,
13 |             "w_shaft": 0.95,
14 |             "w_tip": 0.05,
15 |             "batch_size_train": 6,
16 |             "batch_size_val": 8,
17 |             "epoch": 10,
18 |             "early_stop": 10,
19 |         },
20 |         "model": {
21 |             "seq_length": 30,
22 |             "num_angle": 180,
23 |             "num_rho": 100,
24 |             # Items after this line can be removed if you just want to use the default values
25 |             "win": 10,
26 |             "stride": 3,
27 |             "FocalLoss": True,
28 |             "enc_init": True,
29 |             "fic_init": True,
30 |         },
31 |     },
32 | ]
33 | 
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | from pathlib import Path
  3 | 
  4 | import albumentations as A
  5 | import cv2
  6 | import numpy as np
  7 | from natsort import natsorted
  8 | from torch.utils.data import Dataset
  9 | 
 10 | 
 11 | def gaussian(num_theta, num_rho, center, sig):
 12 |     """
 13 |     Gaussian blurring for Hough space
 14 |     """
 15 |     # create the (num_theta, num_rho) coordinate grid
 16 |     y = np.linspace(0, num_theta - 1, num_theta)
 17 |     x = np.linspace(0, num_rho - 1, num_rho)
 18 |     x, y = np.meshgrid(x, y)
 19 |     x0 = center[1]
 20 |     y0 = center[0]
 21 |     res = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sig**2))
 22 | 
 23 |     return res
 24 | 
 25 | 
 26 | class SeqDataset(Dataset):
 27 |     """
 28 |     Image sequence + label of the last image in the sequence.
 29 |     All frames of a video share the same label.
 30 |     The dataset directory should be organized as follows:
 31 |     -- data_path
 32 |         |-- imgs
 33 |             |-- seq_1: images of a video
 34 |             |-- seq_2: images of another video
 35 |             |-- ...
 36 |         |-- annos: labels of all videos (seq_1.png, seq_2.png, ...)
 37 |     The file names of the sequences and of their label images should be the same.
 38 | 
 39 |     Args:
 40 |         data_path (string or Path): Path to the dataset
 41 |         split (string): Split of the dataset; there should be a file named "{split}.txt" in the data_path
 42 |         size (tuple of int): The size of the output images. If the size is different from the original image size, the images will be resized.
            seq_length (int, optional): Number of frames per sequence. Defaults to 30.
 43 |         num_angle (int, optional): Number of angles in the prediction. Defaults to 180.
 44 |         num_rho (int, optional): Number of rhos in the prediction. Defaults to 100.
 45 |         augment (bool, optional): Whether to apply data augmentation. Defaults to True.
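            All frames of a sequence are transformed with identical parameters (replayed via A.ReplayCompose), so the temporal vibration signal stays consistent across the sequence.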
 46 |             Augmentation includes:
 47 |                 - horizontal flip of the images
 48 |                 - contrast and brightness adjustment
 49 |                 - Gaussian blur
 50 |     """
 51 | 
 52 |     def __init__(
 53 |         self,
 54 |         data_path,
 55 |         split,
 56 |         size,  # H, W
 57 |         seq_length=30,
 58 |         num_angle=180,
 59 |         num_rho=100,
 60 |         augment=True,
 61 |     ):
 62 |         super().__init__()
 63 | 
 64 |         self.num_angle = num_angle
 65 |         self.num_rho = num_rho
 66 |         self.size = size  # H, W
 67 |         self.resize = (size[1], size[0])  # W, H
 68 |         self.augment = augment
 69 | 
 70 |         self.data_path = Path(data_path)
 71 |         self.seq_length = seq_length
 72 | 
 73 |         self.img_path = self.data_path / "imgs"
 74 |         self.anno_path = self.data_path / "annos"
 75 | 
 76 |         self.seq_names = natsorted(
 77 |             open(Path(data_path) / f"{split}.txt").read().splitlines()
 78 |         )
 79 |         self.all_file_names = [
 80 |             natsorted(os.listdir(self.img_path / name)) for name in self.seq_names
 81 |         ]
 82 | 
 83 |         self.length_list = [
 84 |             len(os.listdir(self.img_path / name)) - self.seq_length + 1
 85 |             for name in self.seq_names
 86 |         ]
 87 | 
 88 |     def calc_coords(self, label):
 89 |         """
 90 |         Calculate the coordinates of the beginning and the end of the needle line region.
 91 | 
 92 |         Args:
 93 |             label: Segmentation mask of the needle
 94 | 
 95 |         Returns:
 96 |             x0, y0, x1, y1: location of points in the image space.
 97 |                 The origin is the midpoint of the image; the x axis points from left to right, the y axis from top to bottom.
 98 |         """
 99 |         H, W = self.size
100 |         coords = np.argwhere(label)
101 |         try:
102 |             x0 = coords[:, 1].min()
103 |             x1 = coords[:, 1].max()
104 |             y0 = coords[coords[:, 1] == x0][:, 0].min()
105 |             y1 = coords[coords[:, 1] == x1][:, 0].max()
106 | 
107 |             x0 -= W / 2
108 |             x1 -= W / 2
109 |             y0 -= H / 2
110 |             y1 -= H / 2
111 |         except ValueError:
112 |             x0, y0, x1, y1 = 0, 0, 0, 0
113 | 
114 |         return x0, y0, x1, y1
115 | 
116 |     def calc_rho_theta(self, x0, y0, x1, y1):
117 |         """
118 |         Calculate the rho and theta of the line.
119 |         """
120 |         # hough transform
121 |         theta = np.arctan2(y1 - y0, x1 - x0) + np.pi / 2
122 |         rho = x0 * np.cos(theta) + y0 * np.sin(theta)
123 |         return theta, rho
124 | 
125 |     def line_shaft(self, theta, rho):
126 |         """
127 |         Create the Hough space label for the shaft.
128 | 
129 |         Returns:
130 |             hough_space_shaft, theta, rho:
131 |                 - hough_space_shaft: the Hough space label for the shaft, which is a Gaussian distribution
132 |                 - theta: the index of the gt theta in the Hough space
133 |                 - rho: the index of the gt rho in the Hough space
134 |         """
135 |         # rho is the distance from the line to the middle point of the image
136 |         H, W = self.size
137 |         # calculate resolution of rho and theta
138 |         irho = np.sqrt(H * H + W * W) / self.num_rho
139 |         itheta = np.pi / self.num_angle
140 | 
141 |         # rho can be a negative value, so we need to shift the index
142 |         rho_idx = int(np.round(rho / irho)) + int((self.num_rho) / 2)
143 |         theta_idx = int(np.round(theta / itheta))
144 |         if theta_idx >= self.num_angle:
145 |             theta_idx = self.num_angle - 1
146 |         hough_space_shaft = gaussian(
147 |             self.num_angle, self.num_rho, (theta_idx, rho_idx), sig=2
148 |         )
149 | 
150 |         return hough_space_shaft, theta_idx, rho_idx
151 | 
152 |     def all_line_cross_tip(self, y, x):
153 |         """
154 |         Create the Hough space label for the tip. The tip is the intersection of all the lines.
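        For a fixed tip location (y, x), every line through it satisfies rho = x * cos(theta) + y * sin(theta); the loop below evaluates this sinusoid for each theta bin and blurs each entry with a 1-D Gaussian along the rho axis.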
155 |         """
156 |         H, W = self.size
157 |         irho = np.sqrt(H * H + W * W) / self.num_rho
158 | 
159 |         hough_space_tip = np.zeros((self.num_angle, self.num_rho))
160 |         for i in range(self.num_angle):
161 |             theta = i * np.pi / self.num_angle
162 |             rho = x * np.cos(theta) + y * np.sin(theta)
163 |             rho = int(np.round(rho / irho)) + int((self.num_rho) / 2)
164 |             hough_space_tip[i] = gaussian(1, self.num_rho, (0, rho), sig=3)
165 |         return hough_space_tip
166 | 
167 |     def process_label(self, label):
168 |         """
169 |         Process the label (segmentation mask) to get the Hough space label, the theta and rho indices of the line, and the tip location.
170 |         """
171 |         # find the coordinates of the line
172 |         x0, y0, x1, y1 = self.calc_coords(label)
173 |         # H, W = self.size
174 |         # cv2.line(img, (int(x0 + W / 2), int(y0 + H / 2)), (int(x1 + W / 2), int(y1 + H / 2)), 255, 2)
175 |         # cv2.imwrite('coordscheck.jpg',img)
176 | 
177 |         # no line in the image: return an all-zero label and a dummy tip so the caller can still unpack four values
178 |         if y0 == y1 and x0 == x1:
179 |             return np.zeros((2, self.num_angle, self.num_rho)), 0, 0, np.array([0, 0])
180 | 
181 |         # calculate the rho and theta
182 |         # rho is the distance from the line to the middle of the image
183 |         theta, rho = self.calc_rho_theta(x0, y0, x1, y1)
184 |         # cos = np.cos(theta)
185 |         # sin = np.sin(theta)
186 |         # x0 = cos * rho
187 |         # y0 = sin * rho
188 |         # x1 = int(x0 + 1000 * (-sin))
189 |         # y1 = int(y0 + 1000 * cos)
190 |         # x2 = int(x0 - 1000 * (-sin))
191 |         # y2 = int(y0 - 1000 * cos)
192 |         # cv2.line(img, (int(x1 + W / 2), int(y1 + H / 2)), (int(x2 + W / 2), int(y2 + H / 2)), 255, 2)
193 |         # cv2.imwrite("houghlinescheck.jpg", img)
194 | 
195 |         # create the hough space label
196 |         hough_space_label = np.zeros((2, self.num_angle, self.num_rho))
197 |         hough_space_label[0], theta, rho = self.line_shaft(theta, rho)
198 | 
199 |         # sort (y0, x0) and (y1, x1) based on y
200 |         if y0 > y1:
201 |             hough_space_label[1] = self.all_line_cross_tip(y0, x0)
202 |             tip = np.array([y0, x0])
203 |         else:
204 |             hough_space_label[1] = self.all_line_cross_tip(y1, x1)
205 |             tip = np.array([y1, x1])
206 | 
207 |         # y0, x0, y1, x1 were calculated with the midpoint of the image as the origin;
208 |         # convert the tip location to the tensor space
209 |         tip[0] += self.size[0] / 2
210 |         tip[1] += self.size[1] / 2
211 | 
212 |         return hough_space_label, theta, rho, tip
213 | 
214 |     def aug(self, img_seq, label):
215 |         """
216 |         Data augmentation.
217 |         """
218 |         img_seq = np.array(img_seq).astype(np.int32)
219 | 
220 |         augseq = A.ReplayCompose(
221 |             [
222 |                 A.HorizontalFlip(p=0.5),
223 |                 A.RandomBrightnessContrast(
224 |                     p=1, brightness_limit=(0, 0.2), contrast_limit=(0, 0.2)
225 |                 ),
226 |                 A.GaussianBlur(sigma_limit=(0, 1)),
227 |             ]
228 |         )
229 | 
230 |         img_seq = img_seq.astype(np.uint8)
231 |         img_seq_aug = []
232 |         data = augseq(image=img_seq[0])
233 |         img_seq_aug.append(data["image"])
234 | 
235 |         for i in range(1, len(img_seq) - 1):
236 |             img = A.ReplayCompose.replay(data["replay"], image=img_seq[i])["image"]
237 |             img_seq_aug.append(img)
238 | 
239 |         transformed = A.ReplayCompose.replay(
240 |             data["replay"], image=img_seq[-1], mask=label
241 |         )
242 |         img_seq_aug.append(transformed["image"])
243 |         label = transformed["mask"]
244 | 
245 |         # visualize the augmented label
246 |         # cv2.imwrite("augcheck_img.jpg", img_seq_aug[-1])
247 |         # cv2.imwrite("augcheck_label.jpg", label * 255)
248 | 
249 |         return img_seq_aug, label
250 | 
251 |     def __len__(self):
252 |         return sum(self.length_list)
253 | 
254 |     def __getitem__(self, index):
255 |         i = 0
256 |         while self.length_list[i] <= index:
257 |             index -= 
self.length_list[i] 258 | i += 1 259 | 260 | seq_file_names = self.all_file_names[i][index : index + self.seq_length] 261 | 262 | assert len(seq_file_names) == self.seq_length, "sequence length not match" 263 | 264 | img_seq = [] 265 | for file_name in seq_file_names: 266 | img = cv2.imread( 267 | str(self.img_path / self.seq_names[i] / file_name), cv2.IMREAD_GRAYSCALE 268 | ) 269 | img = cv2.resize(img, self.resize) 270 | img_seq.append(img) 271 | 272 | label = cv2.imread( 273 | str(self.anno_path / (self.seq_names[i] + ".png")), cv2.IMREAD_GRAYSCALE 274 | ) 275 | label = cv2.resize(label, self.resize) 276 | 277 | if self.augment: 278 | img_seq, label = self.aug(img_seq, label) 279 | 280 | hough_space_label, theta, rho, tip = self.process_label(label) 281 | 282 | return ( 283 | np.expand_dims(np.array(img_seq), 1).astype(np.float32) / 127.5 - 1.0, 284 | hough_space_label, 285 | label, 286 | theta, 287 | rho, 288 | tip, 289 | ) 290 | 291 | 292 | if __name__ == "__main__": 293 | # can used for test dataset and utils 294 | import torch 295 | from torch.utils.data import DataLoader 296 | 297 | from utils import (reverse_all_hough_space, reverse_max_hough_space, 298 | vis_result) 299 | 300 | seq_dataset = SeqDataset( 301 | data_path="dataset/Beef", 302 | split="test", 303 | size=(657 // 2, 671 // 2), 304 | seq_length=30, 305 | num_angle=180, 306 | num_rho=100, 307 | augment=True, 308 | ) 309 | # print(len(seq_dataset)) 310 | seq_dataloader = DataLoader(seq_dataset, batch_size=2, shuffle=True) 311 | for batch in seq_dataloader: 312 | img, hough_space_label, label, theta, rho, tip = batch 313 | img = img[0][-1] 314 | # print(img.shape) 315 | hough_space_label_shaft = hough_space_label[0][0] 316 | hough_space_label_tip = hough_space_label[0][1] 317 | label = label[0] 318 | theta = theta[0] 319 | rho = rho[0] 320 | tip = tip[0] 321 | 322 | lines = reverse_all_hough_space( 323 | torch.zeros(img.shape[-2:]), hough_space_label_tip, 180, 100 324 | ) 325 | seq_tip = vis_result(img, lines, label) 326 | W, H = img.shape[-2:] 327 | # find the max value's location in the lines, using pytorch 328 | tip_loc = torch.argmax(lines) 329 | x_pos = tip_loc / H 330 | y_pos = tip_loc % H 331 | # print(x_pos, y_pos) 332 | # print(tip) 333 | cv2.imwrite("seq_tip.png", seq_tip) 334 | line = reverse_max_hough_space( 335 | torch.zeros(img.shape[-2:]), hough_space_label_shaft, 180, 100 336 | ) 337 | seq_shaft = vis_result(img, line, label) 338 | cv2.imwrite("seq_shaft.png", seq_shaft) 339 | break 340 | -------------------------------------------------------------------------------- /infer/inference.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | 4 | from model.vibnet import VibNet 5 | from utils import reverse_all_hough_space 6 | 7 | 8 | class Inference: 9 | def __init__( 10 | self, pth_path, num_angle, num_rho, sequence_length=30, win=10, stride=5 11 | ): 12 | self.num_angle = num_angle 13 | self.num_rho = num_rho 14 | model = VibNet( 15 | num_angle, num_rho, seq_len=sequence_length, win=win, stride=stride 16 | ) 17 | model.load_state_dict(torch.load(pth_path)) 18 | 19 | self.model = model 20 | self.model.eval() 21 | self.model.to("cuda") 22 | 23 | def __call__(self, imgs): 24 | imgs = imgs.to("cuda") 25 | with torch.no_grad(): 26 | out = torch.sigmoid(self.model(imgs)) 27 | hough_shaft = out[:, 0, :, :] 28 | hough_tip = out[:, 1, :, :] 29 | H, W = imgs.shape[-2:] 30 | theta, rho, heatmaps_shaft = self.cal_theta_rho_idx(H, W, hough_shaft) 31 | 
tip_loc, heatmaps_tip = self.cal_tip_location(H, W, hough_tip) 32 | return theta, rho, tip_loc, heatmaps_shaft, heatmaps_tip, out 33 | 34 | def cal_heatmaps(self, H, W, hough_space, threshold=1e-3): 35 | heatmaps = [] 36 | for i in range(hough_space.size(0)): 37 | lines = reverse_all_hough_space( 38 | torch.zeros((H, W), device="cuda"), 39 | hough_space[i], 40 | self.num_angle, 41 | self.num_rho, 42 | threshold, 43 | ) 44 | heatmaps.append(lines) 45 | return torch.stack(heatmaps) 46 | 47 | def cal_theta_rho_idx(self, H, W, hough_space, percent=0.999): 48 | 49 | theta_rho_pred = [] 50 | heatmaps = [] 51 | for i in range(hough_space.size(0)): 52 | max_indices = torch.nonzero(hough_space[i] == torch.max(hough_space[i])) 53 | if max_indices.size(0) == 1: 54 | max_index = max_indices 55 | else: 56 | max_index = max_indices[0].reshape(1, -1) 57 | theta_rho_pred.append(max_index) 58 | threshold = torch.quantile(hough_space[i], percent, interpolation="lower") 59 | heatmaps.append(self.cal_heatmaps(H, W, hough_space[i].unsqueeze(0), threshold)) 60 | 61 | heatmaps = torch.stack(heatmaps, dim=0).squeeze(1) 62 | theta_rho_pred = torch.stack(theta_rho_pred, dim=0).squeeze(1) 63 | theta_pred = theta_rho_pred[:, 0] 64 | rho_pred = theta_rho_pred[:, 1] 65 | 66 | return theta_pred, rho_pred, heatmaps 67 | 68 | def cal_tip_location(self, H, W, hough_space, threshold=1e-3): 69 | res = [] 70 | # threshold = torch.quantile(hough_space, 1 - percent, interpolation="lower") 71 | heatmaps = self.cal_heatmaps(H, W, hough_space, threshold) 72 | for i in range(hough_space.size(0)): 73 | lines = heatmaps[i] 74 | # find the max value's location in the lines, using pytorch 75 | tip_loc = torch.argmax(lines) 76 | # x, y in the tensor space 77 | x_pos = tip_loc // W 78 | y_pos = tip_loc % W 79 | res.append(torch.tensor([x_pos, y_pos], device="cuda")) 80 | return torch.stack(res), heatmaps 81 | 82 | 83 | if __name__ == "__main__": 84 | from infer.iterator import ImageIterator 85 | 86 | image_dir = "./dataset/Beef/imgs/0" 87 | anno_path = "./dataset/Beef/annos/0.png" 88 | iterator = ImageIterator( 89 | image_dir, anno_path, (657 // 2, 671 // 2), 30, batch_size=1 90 | ) 91 | 92 | infer = Inference("./logs/beef/model.pth", 180, 100) 93 | for seqs in iterator: 94 | print(seqs.shape) 95 | theta, rho, tip_loc, heatmaps_shaft, heatmaps_tip, out = infer(seqs) 96 | cv2.imwrite( 97 | "test.png", 98 | heatmaps_shaft[0].cpu().numpy() 99 | * 255 100 | / heatmaps_shaft[0].cpu().numpy().max(), 101 | ) 102 | exit(0) 103 | -------------------------------------------------------------------------------- /infer/iterator.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from pathlib import Path 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from natsort import natsorted 8 | 9 | 10 | class ImageIterator: 11 | def __init__(self, image_dir, anno_path, size, seq_len, batch_size=8): 12 | self.image_dir = Path(image_dir) 13 | self._index = 0 14 | self.file_list = self.image_dir.glob("*.png") 15 | self.file_list = [str(self.image_dir / f.name) for f in self.file_list] 16 | self.file_list = natsorted(self.file_list) 17 | self.length = len(self.file_list) - seq_len + 1 18 | self.size = size 19 | self.resize = (size[1], size[0]) 20 | self.seq_len = seq_len 21 | if anno_path is not None: 22 | self.anno = cv2.imread(anno_path, cv2.IMREAD_GRAYSCALE) 23 | self.anno = cv2.resize(self.anno, self.resize) 24 | else: 25 | self.anno = None 26 | self.batch_size = batch_size 27 | 28 | 
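    # get(idx) builds one batch of sliding-window sequences: it first loads
    # frames [idx, idx + seq_len), then each further batch element advances
    # the window by one frame (appending the next frame and popping the
    # oldest) and stops early at the end of the video, so the final batch
    # may contain fewer than batch_size sequences.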
def __len__(self): 29 | return ceil(self.length / self.batch_size) 30 | 31 | def __iter__(self): 32 | self._index = 0 33 | return self 34 | 35 | def __next__(self): 36 | if self._index < self.length: 37 | imgs = self.get(self._index) 38 | self._index += self.batch_size 39 | return imgs 40 | else: 41 | raise StopIteration 42 | 43 | def get(self, idx): 44 | if idx >= self.length: 45 | raise IndexError("index out of range") 46 | seqs = [] 47 | imgs = [] 48 | for i in range(self.seq_len): 49 | img = cv2.imread(self.file_list[idx + i], cv2.IMREAD_GRAYSCALE) 50 | img = cv2.resize(img, self.resize) 51 | img = img.reshape(1, *img.shape) 52 | imgs.append(img / 127.5 - 1) 53 | for i in range(self.batch_size): 54 | seqs.append(np.stack(imgs, axis=0)) 55 | try: 56 | img = cv2.imread( 57 | self.file_list[idx + i + self.seq_len], cv2.IMREAD_GRAYSCALE 58 | ) 59 | img = cv2.resize(img, self.resize) 60 | img = img.reshape(1, *img.shape) 61 | imgs.append(img / 127.5 - 1) 62 | imgs.pop(0) 63 | except IndexError: 64 | break 65 | 66 | seqs = np.stack(seqs, axis=0) 67 | seqs = torch.tensor(seqs, dtype=torch.float32) 68 | return seqs 69 | 70 | 71 | if __name__ == "__main__": 72 | image_dir = "./dataset/Beef/imgs/0" 73 | anno_path = "./dataset/Beef/annos/0.png" 74 | iterator = ImageIterator(image_dir, anno_path, (657 // 2, 671 // 2), 30) 75 | print(iterator.length) 76 | print(iterator.size) 77 | print(iterator.seq_len) 78 | print(iterator.anno.shape) 79 | for i, seqs in enumerate(iterator): 80 | print(i, seqs.shape) 81 | -------------------------------------------------------------------------------- /logs/beef/config.txt: -------------------------------------------------------------------------------- 1 | {'expriment_name': 'beef', 'data': {'data_path': 'dataset/Beef', 'size': (328, 335)}, 'train': {'device': 'cuda:0', 'val_every_n': 300, 'print_every_n': 25, 'lr': 0.0001, 'w_shaft': 0.95, 'w_tip': 0.05, 'batch_size_train': 6, 'batch_size_val': 8, 'epoch': 10, 'early_stop': 10}, 'model': {'seq_length': 30, 'num_angle': 180, 'num_rho': 100, 'FocalLoss': True, 'win': 10, 'stride': 5, 'enc_init': True, 'fic_init': True}} -------------------------------------------------------------------------------- /logs/beef/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marslicy/VibNet/8830a2f5fde163776e4abb92ba06c723d26795fa/logs/beef/model.pth -------------------------------------------------------------------------------- /model/_cdht/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ -------------------------------------------------------------------------------- /model/_cdht/deep_hough_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | // CUDA forward declarations 9 | std::vector line_accum_cuda_forward( 10 | const torch::Tensor feat, 11 | const float* tabCos, 12 | const float* tabSin, 13 | torch::Tensor output, 14 | const int numangle, 15 | const int numrho); 16 | 17 | std::vector line_accum_cuda_backward( 18 | torch::Tensor grad_outputs, 19 | torch::Tensor grad_in, 20 | torch::Tensor feat, 21 | const float* tabCos, 22 | const float* tabSin, 23 | const int numangle, 24 | const int numrho); 25 | 26 | // C++ interface 27 | 28 | #define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda()) //, #x " must be a CUDA tensor") 29 | #define CHECK_CONTIGUOUS(x) 
AT_ASSERT(x.is_contiguous()) //, #x " must be contiguous") 30 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 31 | #define PI 3.14159265358979323846 32 | 33 | void initTab(float* tabSin, float* tabCos, const int numangle, const int numrho, const int H, const int W) 34 | { 35 | float irho = int(std::sqrt(H*H + W*W) + 1) / float((numrho - 1)); 36 | float itheta = PI / numangle; 37 | float angle = 0; 38 | for(int i = 0; i < numangle; ++i) 39 | { 40 | tabCos[i] = std::cos(angle) / irho; 41 | tabSin[i] = std::sin(angle) / irho; 42 | angle += itheta; 43 | } 44 | } 45 | 46 | std::vector line_accum_forward( 47 | const at::Tensor feat, 48 | at::Tensor output, 49 | const int numangle, 50 | const int numrho) { 51 | 52 | CHECK_INPUT(feat); 53 | CHECK_INPUT(output); 54 | float tabSin[numangle], tabCos[numangle]; 55 | const int H = feat.size(2); 56 | const int W = feat.size(3); 57 | initTab(tabSin, tabCos, numangle, numrho, H, W); 58 | const int batch_size = feat.size(0); 59 | const int channels_size = feat.size(1); 60 | 61 | // torch::set_requires_grad(output, true); 62 | auto out = line_accum_cuda_forward(feat, tabCos, tabSin, output, numangle, numrho); 63 | // std::cout << out[0].sum() << std::endl; 64 | CHECK_CONTIGUOUS(out[0]); 65 | return out; 66 | } 67 | 68 | std::vector line_accum_backward( 69 | torch::Tensor grad_outputs, 70 | torch::Tensor grad_inputs, 71 | torch::Tensor feat, 72 | const int numangle, 73 | const int numrho) { 74 | 75 | CHECK_INPUT(grad_outputs); 76 | CHECK_INPUT(grad_inputs); 77 | CHECK_INPUT(feat); 78 | 79 | float tabSin[numangle], tabCos[numangle]; 80 | const int H = feat.size(2); 81 | const int W = feat.size(3); 82 | initTab(tabSin, tabCos, numangle, numrho, H, W); 83 | 84 | const int batch_size = feat.size(0); 85 | const int channels_size = feat.size(1); 86 | const int imH = feat.size(2); 87 | const int imW = feat.size(3); 88 | 89 | return line_accum_cuda_backward( 90 | grad_outputs, 91 | grad_inputs, 92 | feat, 93 | tabCos, 94 | tabSin, 95 | numangle, 96 | numrho); 97 | } 98 | 99 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 100 | m.def("forward", &line_accum_forward, "line features accumulating forward (CUDA)"); 101 | m.def("backward", &line_accum_backward, "line features accumulating backward (CUDA)"); 102 | } -------------------------------------------------------------------------------- /model/_cdht/deep_hough_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | 12 | // ------- 13 | // KERNELS 14 | // ------- 15 | __global__ void helloCUDA(const float *f) 16 | { 17 | for(int i = 0; i < 10; ++i) 18 | { 19 | printf("%d ", f[i]); 20 | } 21 | printf("\n"); 22 | // printf("Hello thread %d, %d, %d, f=%f\n", threadIdx.x, threadIdx.y, threadIdx.z, f); 23 | } 24 | 25 | 26 | __global__ 27 | void line_accum_forward_kernel( 28 | const float* __restrict__ feat, 29 | const float* tabCos, 30 | const float* tabSin, 31 | float* output, 32 | const int imWidth, 33 | const int imHeight, 34 | const int threadW, 35 | const int threadH, 36 | const int threadK, 37 | const int channelSize, 38 | const int batchSize, 39 | const int numangle, 40 | const int numrho) 41 | { 42 | int batch = blockIdx.y; 43 | int channel = blockIdx.x; 44 | int x = threadIdx.x*threadW; 45 | int y = threadIdx.y*threadH; 46 | int k = threadIdx.z*threadK; 47 | 48 | int imgStartIdx = batch*channelSize*imWidth*imHeight+ 49 | 
channel*imWidth*imHeight+ 50 | y*imWidth+ 51 | x; 52 | 53 | int angleStartIdx = k; 54 | 55 | if (x < imWidth && y < imHeight && channel < channelSize && batch < batchSize && k < numangle) 56 | { 57 | int imgIndex = imgStartIdx; 58 | int angleIndex; 59 | int outIndex; 60 | int r; 61 | for (int idY=0; idY < threadH; idY++) 62 | { 63 | imgIndex = imgStartIdx + idY*imWidth; 64 | // labelIndex = labelStartIdx + idY*imWidth; 65 | if (y+idY < imHeight) 66 | { 67 | for (int idX=0; idX line_accum_cuda_forward( 180 | const torch::Tensor feat, 181 | const float* tabCos, 182 | const float* tabSin, 183 | torch::Tensor output, 184 | const int numangle, 185 | const int numrho){ 186 | // -feat: [N, C, H, W] 187 | // -tabCos: [numangle] 188 | // -tabSin: [numangle] 189 | const int batch_size = feat.size(0); 190 | const int channels_size = feat.size(1); 191 | const int imH = feat.size(2); 192 | const int imW = feat.size(3); 193 | 194 | int blockSizeX = std::min(8, imW); 195 | const int threadW = ceil(imW/(float)blockSizeX); 196 | 197 | int blockSizeY = std::min(8, imH); 198 | const int threadH = ceil(imH/(float)blockSizeY); 199 | 200 | int blockSizeZ = std::min(8, numangle); 201 | const int threadK = ceil(numangle/(float)blockSizeZ); 202 | 203 | const dim3 blocks(channels_size, batch_size); 204 | const dim3 threads(blockSizeX, blockSizeY, blockSizeZ); 205 | 206 | float *d_tabCos, *d_tabSin; 207 | 208 | cudaMalloc((void **)&d_tabCos, sizeof(float)*numangle); 209 | cudaMalloc((void **)&d_tabSin, sizeof(float)*numangle); 210 | 211 | cudaMemcpy(d_tabCos, tabCos, sizeof(float)*numangle, cudaMemcpyHostToDevice); 212 | cudaMemcpy(d_tabSin, tabSin, sizeof(float)*numangle, cudaMemcpyHostToDevice); 213 | 214 | // std::cout << imW << " " << imH << " " << channels_size << " " << batch_size << " " << numangle << " " << numrho << std::endl; 215 | line_accum_forward_kernel<<>>( 216 | feat.data(), 217 | d_tabCos, 218 | d_tabSin, 219 | output.data(), 220 | imW, 221 | imH, 222 | threadW, 223 | threadH, 224 | threadK, 225 | channels_size, 226 | batch_size, 227 | numangle, 228 | numrho 229 | ); 230 | // helloCUDA<<>>(tabCos); 231 | // cudaDeviceSynchronize(); 232 | // std::cout << output << std::endl; 233 | // std::cout << output.sum() << std::endl; 234 | cudaFree(d_tabCos); 235 | cudaFree(d_tabSin); 236 | return {output}; 237 | } 238 | 239 | std::vector line_accum_cuda_backward( 240 | torch::Tensor grad_outputs, 241 | torch::Tensor grad_in, 242 | torch::Tensor feat, 243 | const float* tabCos, 244 | const float* tabSin, 245 | const int numangle, 246 | const int numrho) 247 | { 248 | const int batch_size = feat.size(0); 249 | const int channels_size = feat.size(1); 250 | const int imH = feat.size(2); 251 | const int imW = feat.size(3); 252 | 253 | int blockSizeX = std::min(8, imW); 254 | const int threadW = ceil(imW/(float)blockSizeX); 255 | 256 | int blockSizeY = std::min(8, imH); 257 | const int threadH = ceil(imH/(float)blockSizeY); 258 | 259 | int blockSizeZ = std::min(8, numangle); 260 | const int threadK = ceil(numangle/(float)blockSizeZ); 261 | 262 | const dim3 blocks(channels_size, batch_size); 263 | const dim3 threads(blockSizeX, blockSizeY, blockSizeZ); 264 | 265 | float *d_tabCos, *d_tabSin; 266 | 267 | cudaMalloc((void **)&d_tabCos, sizeof(float)*numangle); 268 | cudaMalloc((void **)&d_tabSin, sizeof(float)*numangle); 269 | 270 | cudaMemcpy(d_tabCos, tabCos, sizeof(float)*numangle, cudaMemcpyHostToDevice); 271 | cudaMemcpy(d_tabSin, tabSin, sizeof(float)*numangle, cudaMemcpyHostToDevice); 272 | // std::cout << imW 
<< " " << imH << " " << channels_size << " " << batch_size << " " << numangle << " " << numrho << std::endl; 273 | 274 | 275 | // printf("p = %p\n", grad_outputs.data()); 276 | // printf("p = %p\n", grad_in.data()); 277 | 278 | line_accum_backward_kernel<<>>( 279 | grad_in.data(), 280 | grad_outputs.data(), 281 | d_tabCos, 282 | d_tabSin, 283 | imW, 284 | imH, 285 | threadW, 286 | threadH, 287 | threadK, 288 | channels_size, 289 | batch_size, 290 | numangle, 291 | numrho 292 | ); 293 | // printf("p = %p\n", grad_outputs.data()); 294 | // printf("p = %p\n", grad_in.data()); 295 | // std::cout << grad_outputs << std::endl; 296 | // cudaDeviceSynchronize(); 297 | cudaFree(d_tabCos); 298 | cudaFree(d_tabSin); 299 | return {grad_in}; 300 | } -------------------------------------------------------------------------------- /model/_cdht/dht_func.py: -------------------------------------------------------------------------------- 1 | import deep_hough as dh 2 | import torch 3 | 4 | 5 | class C_dht_Function(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, feat, numangle, numrho): 8 | N, C, _, _ = feat.size() 9 | out = torch.zeros(N, C, numangle, numrho).type_as(feat).cuda() 10 | out = dh.forward(feat, out, numangle, numrho) 11 | outputs = out[0] 12 | ctx.save_for_backward(feat) 13 | ctx.numangle = numangle 14 | ctx.numrho = numrho 15 | return outputs 16 | 17 | @staticmethod 18 | def backward(ctx, grad_output): 19 | feat = ctx.saved_tensors[0] 20 | numangle = ctx.numangle 21 | numrho = ctx.numrho 22 | out = torch.zeros_like(feat).type_as(feat).cuda() 23 | out = dh.backward(grad_output.contiguous(), out, feat, numangle, numrho) 24 | grad_in = out[0] 25 | return grad_in, None, None 26 | 27 | 28 | class C_dht(torch.nn.Module): 29 | def __init__(self, numAngle, numRho): 30 | super(C_dht, self).__init__() 31 | self.numAngle = numAngle 32 | self.numRho = numRho 33 | 34 | def forward(self, feat): 35 | return C_dht_Function.apply(feat, self.numAngle, self.numRho) 36 | -------------------------------------------------------------------------------- /model/_cdht/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import (BuildExtension, CppExtension, 3 | CUDAExtension) 4 | 5 | setup( 6 | name="deep_hough", 7 | ext_modules=[ 8 | CUDAExtension( 9 | "deep_hough", 10 | [ 11 | "deep_hough_cuda.cpp", 12 | "deep_hough_cuda_kernel.cu", 13 | ], 14 | extra_compile_args={"cxx": ["-g"], "nvcc": ["-arch=sm_60"]}, 15 | ) 16 | ], 17 | cmdclass={"build_ext": BuildExtension}, 18 | ) 19 | -------------------------------------------------------------------------------- /model/dht.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from model._cdht.dht_func import C_dht 4 | 5 | 6 | class ConvAct(nn.Module): 7 | def __init__(self, dim_in, dim_out, dim_intermediate, ks=3, s=1): 8 | super(ConvAct, self).__init__() 9 | p = (ks - 1) // 2 10 | self.conv1 = nn.Sequential( 11 | nn.Conv2d(dim_in, dim_intermediate, ks, s, p, bias=None), 12 | nn.BatchNorm2d(dim_intermediate), 13 | nn.ReLU(), 14 | ) 15 | self.conv2 = nn.Sequential( 16 | nn.Conv2d(dim_intermediate, dim_out, ks, s, p, bias=None), 17 | nn.BatchNorm2d(dim_out), 18 | nn.ReLU(), 19 | ) 20 | 21 | def forward(self, x): 22 | x = self.conv1(x) 23 | x = self.conv2(x) 24 | return x 25 | 26 | 27 | class DHT_Layer(nn.Module): 28 | def __init__(self, input_dim, dim, numAngle, numRho): 29 | 
super(DHT_Layer, self).__init__() 30 | self.fist_conv = nn.Sequential( 31 | nn.Conv2d(input_dim, dim, 1), nn.BatchNorm2d(dim), nn.ReLU() 32 | ) 33 | self.dht = DHT(numAngle=numAngle, numRho=numRho) 34 | self.convs = nn.Sequential( 35 | ConvAct(dim, dim, dim), 36 | ConvAct(dim, dim, dim), 37 | ) 38 | 39 | def forward(self, x): 40 | x = self.fist_conv(x) 41 | x = self.dht(x) 42 | x = self.convs(x) 43 | return x 44 | 45 | 46 | # import time 47 | class DHT(nn.Module): 48 | def __init__(self, numAngle, numRho): 49 | super(DHT, self).__init__() 50 | self.line_agg = C_dht(numAngle, numRho) 51 | 52 | def forward(self, x): 53 | # start_time = time.perf_counter() 54 | accum = self.line_agg(x) # Most time consuming part 55 | # end_time = time.perf_counter() 56 | # elapsed_time = end_time - start_time 57 | # print(f"DHT Elapsed time: {elapsed_time:.4f} seconds") 58 | return accum 59 | -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Conv2D_activa(nn.Module): 6 | def __init__( 7 | self, 8 | in_channels, 9 | out_channels, 10 | kernel_size, 11 | stride, 12 | padding=0, 13 | dilation=1, 14 | activation="relu", 15 | ): 16 | super(Conv2D_activa, self).__init__() 17 | self.padding = padding 18 | if self.padding: 19 | self.pad = nn.ReflectionPad2d(padding) 20 | self.conv2d = nn.Conv2d( 21 | in_channels, out_channels, kernel_size, stride, dilation=dilation, bias=None 22 | ) 23 | self.activation = activation 24 | if activation == "relu": 25 | self.activation = nn.ReLU() 26 | 27 | def forward(self, x): 28 | if self.padding: 29 | x = self.pad(x) 30 | x = self.conv2d(x) 31 | if self.activation: 32 | x = self.activation(x) 33 | return x 34 | 35 | 36 | class ResBlk(nn.Module): 37 | def __init__(self, dim_in, dim_out, dim_intermediate=32, ks=3, s=1): 38 | super(ResBlk, self).__init__() 39 | p = (ks - 1) // 2 40 | self.cba_1 = Conv2D_activa( 41 | dim_in, dim_intermediate, ks, s, p, activation="relu" 42 | ) 43 | self.cba_2 = Conv2D_activa(dim_intermediate, dim_out, ks, s, p, activation=None) 44 | 45 | def forward(self, x): 46 | y = self.cba_1(x) 47 | y = self.cba_2(y) 48 | return y + x 49 | 50 | 51 | def _repeat_blocks(block, dim_in, dim_out, num_blocks, dim_intermediate=32, ks=3, s=1): 52 | blocks = [] 53 | for idx_block in range(num_blocks): 54 | if idx_block == 0: 55 | blocks.append( 56 | block(dim_in, dim_out, dim_intermediate=dim_intermediate, ks=ks, s=s) 57 | ) 58 | else: 59 | blocks.append( 60 | block(dim_out, dim_out, dim_intermediate=dim_intermediate, ks=ks, s=s) 61 | ) 62 | return nn.Sequential(*blocks) 63 | 64 | 65 | class Encoder(nn.Module): 66 | def __init__( 67 | self, 68 | dim_in=3, 69 | dim_out=32, 70 | num_resblk=3, 71 | use_texture_conv=True, 72 | use_motion_conv=True, 73 | texture_downsample=True, 74 | num_resblk_texture=2, 75 | num_resblk_motion=2, 76 | pretained_dict_path=None, 77 | ): 78 | super(Encoder, self).__init__() 79 | self.use_texture_conv, self.use_motion_conv = use_texture_conv, use_motion_conv 80 | 81 | self.cba_1 = Conv2D_activa(dim_in, 16, 7, 1, 3, activation="relu") 82 | self.cba_2 = Conv2D_activa(16, 32, 3, 2, 1, activation="relu") 83 | 84 | self.resblks = _repeat_blocks(ResBlk, 32, 32, num_resblk) 85 | 86 | # texture representation 87 | # if self.use_texture_conv: 88 | # self.texture_cba = Conv2D_activa( 89 | # 32, 32, 3, (2 if texture_downsample else 1), 1, activation="relu" 90 | # ) 91 | # 
self.texture_resblks = _repeat_blocks(ResBlk, 32, dim_out, num_resblk_texture) 92 | 93 | # motion representation 94 | if self.use_motion_conv: 95 | self.motion_cba = Conv2D_activa(32, 32, 3, 1, 1, activation="relu") 96 | self.motion_resblks = _repeat_blocks(ResBlk, 32, dim_out, num_resblk_motion) 97 | 98 | if pretained_dict_path: 99 | self.load_pretrained_state_dict(pretained_dict_path) 100 | 101 | def forward(self, x): 102 | x = self.cba_1(x) 103 | x = self.cba_2(x) 104 | x = self.resblks(x) 105 | 106 | # if self.use_texture_conv: 107 | # texture = self.texture_cba(x) 108 | # texture = self.texture_resblks(texture) 109 | # else: 110 | # texture = self.texture_resblks(x) 111 | 112 | texture = None 113 | 114 | if self.use_motion_conv: 115 | motion = self.motion_cba(x) 116 | motion = self.motion_resblks(motion) 117 | else: 118 | motion = self.motion_resblks(x) 119 | 120 | return texture, motion 121 | 122 | def load_pretrained_state_dict(self, pretained_dict_path): 123 | pretrained_dict = torch.load(pretained_dict_path) 124 | model_dict = self.state_dict() 125 | # 1. filter out unnecessary keys 126 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 127 | # 2. overwrite entries in the existing state dict 128 | model_dict.update(pretrained_dict) 129 | # 3. load the new state dict 130 | self.load_state_dict(model_dict) 131 | -------------------------------------------------------------------------------- /model/fic.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class FIC(nn.Module): 8 | """ 9 | Fourier Initialized Convolution 10 | For every 1D convolution, do fourier transform or short-time fourier transform 11 | Set window_size = sequence_length to do fourier transform 12 | (From UniTS: Short-Time Fourier Inspired Neural Networks for Sensory Time Series Classification, https://doi.org/10.1145/3485730.3485942) 13 | """ 14 | 15 | def __init__(self, window_size, stride, init=True): 16 | super(FIC, self).__init__() 17 | self.window_size = window_size 18 | self.k = int(window_size / 2) 19 | 20 | self.conv = nn.Conv1d( 21 | in_channels=1, 22 | out_channels=2 * int(window_size / 2), 23 | kernel_size=window_size, 24 | stride=stride, 25 | padding=0, 26 | bias=False, 27 | ) 28 | if init: 29 | self.init() 30 | 31 | def forward(self, x): 32 | # x: (batch, channel, in_length) 33 | B, C = x.shape[:2] 34 | 35 | # x: (batch, channel, in_length) 36 | x = x.reshape(B * C, 1, -1) 37 | # x: (batch*channel, 1, in_length) 38 | x = self.conv(x) 39 | # x: (batch*channel, fc, out_length) 40 | x = x.reshape(B, C, -1, x.shape[-1]) 41 | # x: (batch, channel, fc, out_length) 42 | 43 | return x 44 | 45 | def init(self): 46 | """ 47 | Fourier weights initialization 48 | """ 49 | basis = torch.tensor( 50 | [math.pi * 2 * j / self.window_size for j in range(self.window_size)] 51 | ) 52 | 53 | # print('basis size: ', basis.size()) 54 | # print('basis: ', basis) 55 | 56 | weight = torch.zeros((self.k * 2, self.window_size)) 57 | 58 | # print('weight size: ', weight.size()) 59 | 60 | for i in range(self.k * 2): 61 | f = int(i / 2) + 1 62 | if i % 2 == 0: 63 | weight[i] = torch.cos(f * basis) 64 | else: 65 | weight[i] = torch.sin(-f * basis) 66 | 67 | self.conv.weight = torch.nn.Parameter(weight.unsqueeze(1), requires_grad=True) 68 | 69 | 70 | if __name__ == "__main__": 71 | import matplotlib.pyplot as plt 72 | import numpy as np 73 | 74 | ft = FIC(60, 1) 75 | # cos in frequency 5hz 76 | # 
x = ( 77 | # torch.cos(torch.tensor([math.pi * 2 * 3.5 * j / 30 for j in range(30)])) 78 | # .unsqueeze(0) 79 | # .unsqueeze(0) 80 | # ) 81 | 82 | x = ( 83 | torch.cos(torch.tensor([math.pi * 2 * 7 * j / 59 for j in range(128)])) 84 | .unsqueeze(0) 85 | .unsqueeze(0) 86 | ) 87 | 88 | print("x shape: ", x.shape) 89 | 90 | # # visualize the fourier transform 91 | # # plt.plot(x.squeeze()) 92 | # # plt.savefig("cos.png") 93 | ft_res = ft(x).detach().numpy().squeeze() 94 | 95 | print(ft_res) 96 | 97 | print("result shape: ", ft_res.shape) 98 | 99 | # plt.plot(np.arange(0, 15, 0.5), ft_res) 100 | plt.imshow(ft_res[::2, :]) 101 | # plt.imshow(ft_res) 102 | # print(ft_res.shape) 103 | plt.savefig("ft.png") 104 | -------------------------------------------------------------------------------- /model/magnet_epoch12_loss7.28e-02.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marslicy/VibNet/8830a2f5fde163776e4abb92ba06c723d26795fa/model/magnet_epoch12_loss7.28e-02.pth -------------------------------------------------------------------------------- /model/vibnet.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | 7 | from model.dht import DHT_Layer 8 | from model.encoder import Encoder 9 | from model.fic import FIC 10 | 11 | 12 | class ResBlk(nn.Module): 13 | def __init__(self, dim_in, dim_out, dim_intermediate, ks=3, s=1): 14 | super(ResBlk, self).__init__() 15 | p = (ks - 1) // 2 16 | self.conv1 = nn.Sequential( 17 | nn.Conv2d(dim_in, dim_intermediate, ks, s, p, bias=None), 18 | nn.BatchNorm2d(dim_intermediate), 19 | nn.ReLU(), 20 | ) 21 | self.conv2 = nn.Sequential( 22 | nn.Conv2d(dim_intermediate, dim_out, ks, s, p, bias=None), 23 | nn.BatchNorm2d(dim_out), 24 | nn.ReLU(), 25 | ) 26 | 27 | def forward(self, x): 28 | y = self.conv1(x) 29 | y = self.conv2(y) 30 | return y + x 31 | 32 | 33 | class VibNet(nn.Module): 34 | def __init__( 35 | self, 36 | num_angle, 37 | num_rho, 38 | enc_init=True, 39 | fic_init=True, 40 | seq_len=30, 41 | win=10, 42 | stride=5, 43 | ): 44 | super(VibNet, self).__init__() 45 | if enc_init: 46 | pretained_dict_path = "model/magnet_epoch12_loss7.28e-02.pth" 47 | else: 48 | pretained_dict_path = None 49 | self.encoder = Encoder(pretained_dict_path=pretained_dict_path) 50 | self.batch_norm = nn.BatchNorm2d(32) 51 | self.fusion1d_1 = nn.Conv1d(32, 24, 1) 52 | 53 | self.fic = FIC(win, stride, init=fic_init) 54 | 55 | t = (seq_len - win) // stride + 1 56 | 57 | self.fusion2d_1 = nn.Sequential( 58 | ResBlk(24, 24, 24), 59 | ResBlk(24, 24, 24), 60 | ) 61 | 62 | self.fusion_stft = nn.Sequential( 63 | nn.Conv1d(t, 1, 1, 1), 64 | nn.BatchNorm1d(1), 65 | nn.ReLU(), 66 | ) 67 | 68 | self.fusion1d_2 = nn.Sequential( 69 | nn.Conv1d(24, 16, 7, 1, 3), 70 | nn.BatchNorm1d(16), 71 | nn.ReLU(), 72 | nn.Conv1d(16, 12, 7, 1, 3), 73 | nn.BatchNorm1d(12), 74 | nn.ReLU(), 75 | ) 76 | 77 | self.fusion2d_2 = nn.Sequential( 78 | ResBlk(12, 12, 12), 79 | ResBlk(12, 12, 12), 80 | ) 81 | 82 | out_length = 2 * int(win / 2) * 12 83 | 84 | self.fm_conv = nn.Sequential( 85 | ResBlk(out_length, out_length, out_length), 86 | ResBlk(out_length, out_length, out_length), 87 | ) 88 | 89 | self.dht_detector = DHT_Layer( 90 | out_length, out_length, numAngle=num_angle, numRho=num_rho 91 | ) 92 | self.last_conv = nn.Sequential(nn.Conv2d(out_length, 2, 1)) 93 | self.num_angle = num_angle 94 | self.num_rho = num_rho 95 | 96 | if 
enc_init:
 97 |             for param in self.encoder.parameters():
 98 |                 param.requires_grad = False
 99 | 
100 |     def forward(self, x):
101 |         # ================== Encoder ==================
102 |         (N, T, C, H, W) = x.shape
103 |         x = x.reshape(N * T, C, H, W)
104 |         # repeat each image 3 times to form 3 channels
105 |         x = torch.repeat_interleave(x, 3, dim=1)
106 |         _, x = self.encoder(x)  # 32 channels (N*T, C, H, W)
107 |         x = self.batch_norm(x)
108 |         x = x.reshape(N, T, -1, x.shape[-2], x.shape[-1])  # (N, T, C, H, W)
109 |         (N, T, C, H, W) = x.shape
110 | 
111 |         # =============== Channel Fusion ===============
112 |         x = x.permute(0, 3, 4, 2, 1)  # (N, H, W, C, T)
113 |         x = x.reshape(N * H * W, -1, T)
114 |         x = self.fusion1d_1(x)
115 | 
116 |         # =============== STFT Module ===============
117 |         x = self.fic(x)  # (N*H*W, 24, 10, 5)
118 |         (_, _, F, t) = x.shape
119 |         x = x.reshape(-1, F, t)
120 |         x = x.permute(0, 2, 1)  # (N*H*W*24, 5, 10)
121 | 
122 |         # =============== STFT Fusion ===============
123 |         x = self.fusion_stft(x)  # (N*H*W*24, 1, 10)
124 |         x = x.reshape(N, H, W, -1, F)  # (N, H, W, C, F)
125 | 
126 |         # ========= Spatial & Channel Conv =========
127 |         x = self.permute_conv2d(self.fusion2d_1, x)
128 |         x = self.permute_conv1d(self.fusion1d_2, x)
129 |         x = self.permute_conv2d(self.fusion2d_2, x)
130 | 
131 |         # =============== Concatenate & 2D Conv ===============
132 |         (N, _, _, H, W) = x.shape
133 |         x = x.permute(0, 2, 1, 3, 4)  # (N, C, F, H, W)
134 |         x = x.reshape(N, -1, H, W)  # (N, 120, H, W)
135 |         x = self.fm_conv(x)
136 | 
137 |         # ========= Deep Hough & Classification =========
138 |         x = x.contiguous()
139 |         x = self.dht_detector(x)
140 |         x = self.last_conv(x)
141 |         return x
142 | 
143 |     def permute_conv2d(self, conv2d_layer, x):
144 |         # input in shape (N, H, W, C, F)
145 |         (N, H, W, _, F) = x.shape
146 |         x = x.permute(0, 4, 3, 1, 2)  # (N, F, C, H, W)
147 |         x = x.reshape(N * F, -1, H, W)
148 |         x = conv2d_layer(x)
149 |         x = x.reshape(N, F, -1, x.shape[-2], x.shape[-1])  # (N, F, C, H, W)
150 |         # output in shape (N, F, C, H, W)
151 |         return x
152 | 
153 |     def permute_conv1d(self, conv1d_layer, x):
154 |         # input in shape (N, F, C, H, W)
155 |         (N, F, _, H, W) = x.shape
156 |         x = x.permute(0, 3, 4, 2, 1)  # (N, H, W, C, F)
157 |         x = x.reshape(N * H * W, -1, F)
158 |         x = conv1d_layer(x)
159 |         x = x.reshape(N, H, W, -1, F)  # (N, H, W, C, F)
160 |         # output in shape (N, H, W, C, F)
161 |         return x
162 | 
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | from pathlib import Path
  3 | 
  4 | import cv2
  5 | import numpy as np
  6 | from tqdm import tqdm
  7 | 
  8 | from infer.inference import Inference
  9 | from infer.iterator import ImageIterator
 10 | from utils import theta_rho_to_xy
 11 | 
 12 | 
 13 | def vis_res_video(
 14 |     image_dir,
 15 |     anno_path,
 16 |     model_path,
 17 |     output_dir,
 18 |     batch_size,
 19 |     config,
 20 | ):
 21 |     os.makedirs(output_dir, exist_ok=True)
 22 | 
 23 |     size = config["data"]["size"]
 24 |     seq_length = config["model"]["seq_length"]
 25 |     num_angle = config["model"]["num_angle"]
 26 |     num_rho = config["model"]["num_rho"]
 27 |     win = config["model"]["win"]
 28 |     stride = config["model"]["stride"]
 29 | 
 30 |     iterator = ImageIterator(image_dir, anno_path, size, seq_length, batch_size)
 31 |     inference = Inference(model_path, num_angle, num_rho, seq_length, win, stride)
 32 |     label = iterator.anno
 33 | 
 34 |     H, W = size
 35 |     # each frame contains: img, heatmaps_shaft + label, heatmaps_tip 
+ label + tip_loc
 36 |     frame = np.zeros((H, W * 3), dtype=np.uint8)
 37 |     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
 38 |     vid_name = str(Path(output_dir) / Path(image_dir).name) + ".mp4"
 39 |     out = cv2.VideoWriter(vid_name, fourcc, 30.0, (W * 3, H), False)
 40 | 
 41 |     cnt = 0
 42 |     for seqs in tqdm(iterator, desc=f"Video {vid_name}"):
 43 |         theta, rho, tip_loc, heatmaps_shaft, heatmaps_tip, _ = inference(seqs)
 44 | 
 45 |         H, W = seqs.shape[-2:]
 46 |         p1, p2 = theta_rho_to_xy((H, W), theta, rho, num_angle, num_rho)
 47 |         p1 = p1.int().cpu().numpy()
 48 |         p2 = p2.int().cpu().numpy()
 49 |         tip_loc = tip_loc.int().cpu().numpy()
 50 |         heatmaps_shaft = heatmaps_shaft.cpu().numpy()
 51 |         heatmaps_tip = heatmaps_tip.cpu().numpy()
 52 | 
 53 |         for i in range(len(seqs)):
 54 |             img = (seqs[i, -1, :, :].cpu().numpy() + 1) * 127.5
 55 | 
 56 |             zeros = np.zeros((H, W), dtype=np.float32)
 57 |             shaft = cv2.addWeighted(zeros, 1, label.astype(np.float32) * 255, 0.5, 0)
 58 |             shaft = cv2.addWeighted(shaft, 1, heatmaps_shaft[i], 0.8, 0)
 59 |             shaft = cv2.line(shaft, (p1[i, 1], p1[i, 0]), (p2[i, 1], p2[i, 0]), 255, 1)
 60 | 
 61 |             zeros = np.zeros((H, W), dtype=np.float32)
 62 |             tip = cv2.addWeighted(zeros, 1, label.astype(np.float32) * 255, 0.5, 0)
 63 |             tip = cv2.addWeighted(tip, 1, heatmaps_tip[i], 0.8, 0)
 64 |             tip = cv2.circle(tip, (tip_loc[i, 1], tip_loc[i, 0]), 3, 255, -1)
 65 | 
 66 |             frame[:, :W] = img
 67 |             frame[:, W : W * 2] = shaft
 68 |             frame[:, W * 2 : W * 3] = tip
 69 | 
 70 |             out.write(frame)
 71 |             cnt += 1
 72 |     print(f"Video {vid_name} Done! Total {cnt} frames")
 73 |     out.release()
 74 |     cv2.destroyAllWindows()
 75 | 
 76 | 
 77 | if __name__ == "__main__":
 78 |     batch_size = 4
 79 |     dataset_dir = Path("./dataset/Beef")
 80 |     output_dir = "./output_videos_beef"
 81 |     model_path = "./logs/beef/model.pth"
 82 |     config_path = "./logs/beef/config.txt"
 83 | 
 84 |     seq_name = open("dataset/Beef/test.txt").read().split("\n")
 85 |     # seq_name = ["37", "41"]
 86 | 
 87 |     imgs_dir = dataset_dir / "imgs"
 88 |     annos_dir = dataset_dir / "annos"
 89 | 
 90 |     with open(config_path, "r") as f:
 91 |         config = eval(f.read())
 92 | 
 93 |     total = len(seq_name)
 94 |     for i, d in enumerate(seq_name):
 95 |         image_dir = str(imgs_dir / d)
 96 |         anno_path = str(annos_dir / f"{d}.png")
 97 | 
 98 |         vis_res_video(
 99 |             image_dir,
100 |             anno_path,
101 |             model_path,
102 |             output_dir,
103 |             batch_size,
104 |             config,
105 |         )
106 | 
107 |         print(f"Video {i+1}/{total} Done!")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | natsort==8.4.0
2 | albumentations==2.0.4
3 | opencv-python==4.11.0.86
4 | tensorboardX==2.6.2.2
5 | tqdm==4.67.1
6 | matplotlib==3.10.0
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import time
  3 | from math import sqrt
  4 | from pathlib import Path
  5 | 
  6 | import numpy as np
  7 | import torch
  8 | from torch.utils.data import DataLoader
  9 | from tqdm import tqdm
 10 | 
 11 | from dataset import SeqDataset
 12 | from infer.inference import Inference
 13 | 
 14 | 
 15 | def test(
 16 |     model_path,
 17 |     data_path,
 18 |     split,
 19 |     config,
 20 |     batch_size,
 21 |     save_path,
 22 | ):
 23 |     num_angle = config["model"]["num_angle"]
 24 |     num_rho = config["model"]["num_rho"]
 25 |     seq_length = config["model"]["seq_length"]
 26 |     size = config["data"]["size"]
 27 |     win = config["model"]["win"]
 28 |     stride = 
config["model"]["stride"] 29 | 30 | inference = Inference(model_path, num_angle, num_rho, seq_length, win, stride) 31 | dataset = SeqDataset( 32 | data_path=data_path, 33 | split=split, 34 | size=size, 35 | seq_length=seq_length, 36 | num_angle=num_angle, 37 | num_rho=num_rho, 38 | augment=False, 39 | ) 40 | 41 | print(f"total: {len(dataset)}") 42 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 43 | 44 | theta_diffs = torch.tensor([], device="cuda") 45 | rho_diffs = torch.tensor([], device="cuda") 46 | tip_diffs = torch.tensor([], device="cuda") 47 | time_list = [] 48 | 49 | for batch in tqdm(dataloader): 50 | img, _, _, theta, rho, tip = batch 51 | theta = theta.to("cuda") 52 | rho = rho.to("cuda") 53 | tip = tip.to("cuda") 54 | with torch.no_grad(): 55 | time_start = time.time() 56 | theta_pred, rho_pred, tip_pred, _, _, _ = inference(img) 57 | time_end = time.time() 58 | time_list.append(time_end - time_start) 59 | theta_diffs_curr = torch.abs(theta_pred - theta) 60 | rho_diffs_curr = ( 61 | torch.abs(rho_pred - rho) 62 | * sqrt(size[0] ** 2 + size[1] ** 2) 63 | / num_rho 64 | * 50 65 | / size[0] 66 | ) 67 | tip_diffs_curr = torch.norm(tip_pred - tip, dim=1) * 50 / size[0] 68 | theta_diffs = torch.cat([theta_diffs, theta_diffs_curr], dim=0) 69 | rho_diffs = torch.cat([rho_diffs, rho_diffs_curr], dim=0) 70 | tip_diffs = torch.cat([tip_diffs, tip_diffs_curr], dim=0) 71 | 72 | print(f"theta_diffs: mean {theta_diffs.mean()}, std {theta_diffs.std()}") 73 | print(f"rho_diffs: mean {rho_diffs.mean()}, std {rho_diffs.std()}") 74 | print(f"tip_diffs: mean {tip_diffs.mean()}, std {tip_diffs.std()}") 75 | print( 76 | f"time: mean {torch.tensor(time_list).mean()}, std {torch.tensor(time_list).std()}" 77 | ) 78 | 79 | print("test done!!!") 80 | 81 | time_list = torch.tensor(time_list) 82 | os.makedirs(save_path, exist_ok=True) 83 | theta_diffs = theta_diffs.cpu().numpy() 84 | rho_diffs = rho_diffs.cpu().numpy() 85 | tip_diffs = tip_diffs.cpu().numpy() 86 | time_list = time_list.cpu().numpy() 87 | 88 | np.save(f"{save_path}/theta_diffs.npy", theta_diffs) 89 | np.save(f"{save_path}/rho_diffs.npy", rho_diffs) 90 | np.save(f"{save_path}/tip_diffs.npy", tip_diffs) 91 | 92 | 93 | if __name__ == "__main__": 94 | model_path = "./logs/beef/model.pth" 95 | config_path = "./logs/beef/config.txt" 96 | 97 | batch_size = 4 98 | 99 | save_path = Path("./results") 100 | 101 | data_path_prefix = Path("./dataset") 102 | tissues = ["Beef", "Pork"] 103 | splits = ["challenging", "normal"] 104 | 105 | with open(config_path, "r") as f: 106 | config = eval(f.read()) 107 | 108 | save_path_prefix = save_path / config["expriment_name"] 109 | 110 | exp_name = config["expriment_name"] 111 | for t in tissues: 112 | for split in splits: 113 | data_path = data_path_prefix / t 114 | save_path = str(save_path_prefix / (t.lower() + "_" + split)) 115 | os.makedirs(save_path, exist_ok=True) 116 | 117 | print(f"Testing {t} {split}...") 118 | test( 119 | model_path, 120 | data_path, 121 | split, 122 | config, 123 | batch_size, 124 | save_path, 125 | ) 126 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from tensorboardX import SummaryWriter 9 | from torch.utils.data import DataLoader 10 | 11 | from dataset import SeqDataset 12 | from model.vibnet import VibNet 13 | from utils 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import time
4 | 
5 | import cv2
6 | import numpy as np
7 | import torch
8 | from tensorboardX import SummaryWriter
9 | from torch.utils.data import DataLoader
10 | 
11 | from dataset import SeqDataset
12 | from model.vibnet import VibNet
13 | from utils import reverse_all_hough_space, reverse_max_hough_space, vis_result
14 | 
15 | 
16 | def setup_seed(seed):
17 |     # random package
18 |     random.seed(seed)
19 | 
20 |     # torch package
21 |     torch.manual_seed(seed)
22 |     torch.cuda.manual_seed(seed)
23 |     torch.cuda.manual_seed_all(seed)
24 | 
25 |     # numpy package
26 |     np.random.seed(seed)
27 | 
28 |     # os
29 |     os.environ["PYTHONHASHSEED"] = str(seed)
30 | 
31 | 
32 | def modified_focal_loss(pred, gt):
33 |     pos_inds = gt.eq(1).float()
34 |     neg_inds = gt.lt(1).float()
35 | 
36 |     neg_weights = torch.pow(1 - gt, 4)
37 | 
38 |     pred = torch.clamp(torch.sigmoid(pred), min=1e-4, max=1 - 1e-4)
39 |     pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
40 |     neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
41 | 
42 |     loss = -(pos_loss + neg_loss).mean()
43 |     return loss
44 | 
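`modified_focal_loss` is the penalty-reduced pixel-wise focal loss popularized by CenterNet-style keypoint heads. Restated in math form to match the clamp and weighting above (N is the total number of Hough-space cells, since the code averages over all cells rather than over positives):

```latex
L = -\frac{1}{N} \sum_{\text{cells}}
\begin{cases}
(1 - \hat{p})^{2} \log \hat{p} & \text{if } y = 1 \\
(1 - y)^{4} \, \hat{p}^{2} \log(1 - \hat{p}) & \text{otherwise,}
\end{cases}
\qquad \hat{p} = \mathrm{clamp}\big(\sigma(\text{pred}),\, 10^{-4},\, 1 - 10^{-4}\big)
```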
45 | 
46 | def get_model_dataset_expname(config):
47 |     win = config.get("model").get("win")
48 |     stride = config.get("model").get("stride")
49 |     enc_init = config.get("model").get("enc_init")
50 |     fic_init = config.get("model").get("fic_init")
51 | 
52 |     model = VibNet(
53 |         num_angle=config["model"]["num_angle"],
54 |         num_rho=config["model"]["num_rho"],
55 |         seq_len=config["model"]["seq_length"],
56 |         win=win if win is not None else 10,
57 |         stride=stride if stride is not None else 5,
58 |         enc_init=enc_init if enc_init is not None else True,
59 |         fic_init=fic_init if fic_init is not None else True,
60 |     )
61 | 
62 |     if config.get("expriment_name") is None:
63 |         expriment_name = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
64 |     else:
65 |         expriment_name = config["expriment_name"] + time.strftime(
66 |             "_(%Y-%m-%d-%H-%M-%S)", time.localtime()
67 |         )
68 | 
69 |     dataset_train = SeqDataset(
70 |         data_path=config["data"]["data_path"],
71 |         split="train",
72 |         size=config["data"]["size"],
73 |         seq_length=config["model"]["seq_length"],
74 |         num_angle=config["model"]["num_angle"],
75 |         num_rho=config["model"]["num_rho"],
76 |         augment=True,
77 |     )
78 | 
79 |     dataset_val = SeqDataset(
80 |         data_path=config["data"]["data_path"],
81 |         split="val",
82 |         size=config["data"]["size"],
83 |         seq_length=config["model"]["seq_length"],
84 |         num_angle=config["model"]["num_angle"],
85 |         num_rho=config["model"]["num_rho"],
86 |         augment=False,
87 |     )
88 | 
89 |     return model, dataset_train, dataset_val, expriment_name
90 | 
91 | 
92 | def train(config):
93 |     model, dataset_train, dataset_val, expriment_name = get_model_dataset_expname(
94 |         config
95 |     )
96 | 
97 |     device = config["train"]["device"]
98 | 
99 |     log_path = f"logs/{expriment_name}"
100 |     writer = SummaryWriter(log_path)
101 |     figure_path = f"{log_path}/figures_val"
102 |     os.makedirs(figure_path, exist_ok=True)
103 | 
104 |     val_every_n = config["train"]["val_every_n"]
105 |     print_every_n = config["train"]["print_every_n"]
106 |     early_stop_thres = config["train"]["early_stop"]
107 | 
108 |     # save config
109 |     with open(f"{log_path}/config.txt", "w") as f:
110 |         f.write(str(config))
111 | 
112 |     best_val_loss = 100
113 |     early_stop_cnt = 0
114 |     optimizer = torch.optim.Adam(model.parameters(), lr=config["train"]["lr"])
115 |     if config["model"]["FocalLoss"]:
116 |         loss_fn = modified_focal_loss
117 |     else:
118 |         loss_fn = torch.nn.BCEWithLogitsLoss()
119 |         loss_fn.to(device)  # only the module loss can be moved; modified_focal_loss is a plain function
120 | 
121 |     train_loader = DataLoader(
122 |         dataset_train,
123 |         batch_size=config["train"]["batch_size_train"],
124 |         shuffle=True,
125 |     )
126 |     val_loader = DataLoader(
127 |         dataset_val,
128 |         batch_size=config["train"]["batch_size_val"],
129 |         shuffle=True,
130 |     )
131 | 
132 |     model.train()
133 |     model.to(device)
134 | 
135 |     loss_shaft_curr, loss_tip_curr, loss_curr = 0, 0, 0
136 |     loss_shaft_print, loss_tip_print, loss_print = 0, 0, 0
137 |     for epoch in range(config["train"]["epoch"]):
138 |         for i, (img, hough_space_label, _, _, _, _) in enumerate(train_loader):
139 |             img = img.to(device)
140 |             hough_space_label = hough_space_label.to(device)
141 | 
142 |             optimizer.zero_grad()
143 | 
144 |             pred = model(img)
145 |             loss_shaft_curr = loss_fn(pred[:, 0, :, :], hough_space_label[:, 0, :, :])
146 |             loss_shaft_print += loss_shaft_curr.item()  # accumulate floats so the autograd graph is freed
147 | 
148 |             loss_tip_curr = loss_fn(pred[:, 1, :, :], hough_space_label[:, 1, :, :])
149 |             loss_tip_print += loss_tip_curr.item()
150 | 
151 |             loss_curr = (
152 |                 config["train"]["w_shaft"] * loss_shaft_curr
153 |                 + config["train"]["w_tip"] * loss_tip_curr
154 |             )
155 |             loss_print += loss_curr.item()
156 | 
157 |             loss_curr.backward()
158 |             optimizer.step()
159 | 
160 |             if (i + 1) % print_every_n == 0 or i == len(train_loader) - 1:
161 |                 loss_shaft_print /= print_every_n
162 |                 loss_tip_print /= print_every_n
163 |                 loss_print /= print_every_n
164 |                 print(
165 |                     f"Epoch {epoch} | Iter {i} | Loss {loss_print} (Shaft {loss_shaft_print}, Tip {loss_tip_print})"
166 |                 )
167 |                 writer.add_scalar(
168 |                     "loss/train", loss_print, epoch * len(train_loader) + i
169 |                 )
170 |                 writer.add_scalar(
171 |                     "loss_shaft/train",
172 |                     loss_shaft_print,
173 |                     epoch * len(train_loader) + i,
174 |                 )
175 |                 writer.add_scalar(
176 |                     "loss_tip/train",
177 |                     loss_tip_print,
178 |                     epoch * len(train_loader) + i,
179 |                 )
180 | 
181 |                 loss_shaft_print, loss_tip_print, loss_print = 0, 0, 0
182 | 
183 |             # validation (skipped for the first 3000 iterations of the first epoch)
184 |             if (epoch > 0 or (i + 1) >= 3000) and (
185 |                 ((i + 1) % val_every_n == 0) or i == len(train_loader) - 1
186 |             ):
187 |                 val_loss, val_loss_shaft, val_loss_tip = validate(
188 |                     model, config, val_loader, loss_fn, epoch, i, figure_path
189 |                 )
190 | 
191 |                 print("======================================================")
192 |                 print(
193 |                     f"Epoch {epoch} | Iter {i} | Val Loss {val_loss} (Shaft {val_loss_shaft}, Tip {val_loss_tip})"
194 |                 )
195 |                 print("======================================================")
196 |                 writer.add_scalar("loss/val", val_loss, epoch * len(train_loader) + i)
197 |                 writer.add_scalar(
198 |                     "loss_shaft/val",
199 |                     val_loss_shaft,
200 |                     epoch * len(train_loader) + i,
201 |                 )
202 |                 writer.add_scalar(
203 |                     "loss_tip/val",
204 |                     val_loss_tip,
205 |                     epoch * len(train_loader) + i,
206 |                 )
207 | 
208 |                 if val_loss < best_val_loss:
209 |                     best_val_loss = val_loss
210 |                     torch.save(model.state_dict(), f"{log_path}/model.pth")
211 |                     early_stop_cnt = 0
212 |                 else:
213 |                     print("No improvement!!")
214 |                     early_stop_cnt += 1
215 |                     if early_stop_cnt >= early_stop_thres:
216 |                         print("Early stop!!!")
217 |                         return
218 | 
219 |                 model.train()
220 | 
221 | 
222 | def validate(model, config, val_loader, loss_fn, epoch, i, figure_path):
223 |     device = config["train"]["device"]
224 |     with torch.no_grad():
225 |         model.eval()
226 |         val_loss_shaft = 0
227 |         val_loss_tip = 0
228 |         k = 0
229 |         for img, hough_space_label, label, _, _, _ in val_loader:
230 |             img = img.to(device)
231 |             hough_space_label = hough_space_label.to(device)
232 |             pred = model(img)
233 |             val_loss_shaft += loss_fn(pred[:, 0, :, :], hough_space_label[:, 0, :, :])
234 |             val_loss_tip += loss_fn(pred[:, 1, :, :], hough_space_label[:, 1, :, :])
235 | 
236 |             # save visualizations for the first few validation samples
237 |             if k < 10:
238 |                 for j in range(5):
239 |                     try:
240 |                         # visualize the shaft prediction
241 |                         line = reverse_max_hough_space(
242 |                             torch.zeros(img.shape[-2:], device=device),
243 |                             pred[j][0],
244 |                             num_angle=config["model"]["num_angle"],
245 |                             num_rho=config["model"]["num_rho"],
246 |                         )
247 |                         img_shaft = vis_result(img[j][-1], line, label[j])
248 |                         cv2.imwrite(
249 |                             f"{figure_path}/{epoch}_{i}_shaft_{k}.jpg",
250 |                             img_shaft,
251 |                         )
252 | 
253 |                         # visualize the tip prediction
254 |                         line = reverse_all_hough_space(
255 |                             torch.zeros(img.shape[-2:], device=device),
256 |                             pred[j][1].sigmoid(),
257 |                             num_angle=config["model"]["num_angle"],
258 |                             num_rho=config["model"]["num_rho"],
259 |                         )
260 |                         img_tip = vis_result(img[j][-1], line, label[j])
261 |                         cv2.imwrite(
262 |                             f"{figure_path}/{epoch}_{i}_tip_{k}.jpg",
263 |                             img_tip,
264 |                         )
265 | 
266 |                         k += 1
267 |                     except IndexError:  # the last batch may contain fewer than 5 samples
268 |                         pass
269 |         val_loss = (
270 |             config["train"]["w_shaft"] * val_loss_shaft
271 |             + config["train"]["w_tip"] * val_loss_tip
272 |         )
273 |         val_loss /= len(val_loader)
274 |         val_loss_tip /= len(val_loader)
275 |         val_loss_shaft /= len(val_loader)
276 | 
277 |     return val_loss, val_loss_shaft, val_loss_tip
278 | 
279 | 
280 | if __name__ == "__main__":
281 |     from config import config_list
282 | 
283 |     setup_seed(42)
284 |     for config in config_list:
285 |         train(config)
286 | 
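`train.py` expects `config.py` to expose a `config_list` of nested dicts. A minimal sketch of one entry, inferred from the keys accessed above; all values here are illustrative placeholders, not the paper's settings:

```python
# Hypothetical minimal entry for config.py's config_list; values are placeholders.
config_list = [
    {
        "expriment_name": "beef",  # key spelling matches the code
        "model": {
            "num_angle": 180,
            "num_rho": 100,
            "seq_length": 30,
            "win": 10,
            "stride": 5,
            "enc_init": True,
            "fic_init": True,
            "FocalLoss": True,
        },
        "data": {"data_path": "./dataset/Beef", "size": (256, 256)},
        "train": {
            "device": "cuda",
            "lr": 1e-4,
            "epoch": 2,
            "batch_size_train": 4,
            "batch_size_val": 4,
            "print_every_n": 100,
            "val_every_n": 1000,
            "early_stop": 10,
            "w_shaft": 1.0,
            "w_tip": 1.0,
        },
    }
]
```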
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | 
5 | 
6 | def theta_rho_to_xy(img_shape, theta, rho, num_angle, num_rho):
7 |     """
8 |     Convert theta and rho to the coordinates of the line (start point and end point of each line).
9 | 
10 |     Args:
11 |         img_shape (tuple): The shape of the image (H, W), where the lines will be drawn.
12 |         theta (torch.tensor): Thetas of the lines in the shape of (num_line,).
13 |         rho (torch.tensor): Rhos of the lines in the shape of (num_line,).
14 |         num_angle (int): the number of angles in the hough space.
15 |         num_rho (int): the number of rhos in the hough space.
16 | 
17 |     Returns:
18 |         (torch.tensor, torch.tensor): two tensors p1 and p2, the line endpoints, each in the shape of (num_line, 2) and stored as (row, col).
19 |     """
20 |     # calculate resolution of rho and theta
21 |     H, W = img_shape
22 |     if theta.dim() == 0:
23 |         theta = theta.unsqueeze(0)
24 |     if rho.dim() == 0:
25 |         rho = rho.unsqueeze(0)
26 |     # theta_idx, rho_idx = theta, rho
27 |     l = torch.sqrt(torch.tensor(H * H + W * W))
28 |     irho = l / num_rho
29 |     itheta = torch.pi / num_angle
30 |     theta = theta * itheta
31 |     # shift the index back
32 |     rho = rho - int((num_rho) / 2)
33 |     rho = rho * irho
34 |     # calculate the coordinates of the line
35 |     cos = torch.cos(theta)
36 |     sin = torch.sin(theta)
37 | 
38 |     # x, y in image coordinate
39 |     x0 = cos * rho
40 |     y0 = sin * rho
41 | 
42 |     steps = l / 2
43 |     x1 = x0 + steps * (-sin) + W / 2
44 |     y1 = y0 + steps * cos + H / 2
45 |     x2 = x0 - steps * (-sin) + W / 2
46 |     y2 = y0 - steps * cos + H / 2
47 | 
48 |     # shift to tensor coordinate
49 |     p1 = torch.stack([y1, x1], dim=1)
50 |     p2 = torch.stack([y2, x2], dim=1)
51 | 
52 |     return p1, p2
53 | 
54 | 
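For orientation, a quick sketch of how a Hough-index pair decodes to image endpoints; the `num_angle` and `num_rho` values are illustrative, and `utils.py` is assumed to be importable (e.g. via `PYTHONPATH='.'`):

```python
import torch
from utils import theta_rho_to_xy

# Hough indices for a horizontal line through the centre of a 256x256 image:
# theta index 90 of 180 gives a vertical normal (pi/2), and rho index 50 of 100
# is the middle bin, i.e. rho = 0 after the index shift inside the function.
theta_idx = torch.tensor(90)
rho_idx = torch.tensor(50)

p1, p2 = theta_rho_to_xy((256, 256), theta_idx, rho_idx, num_angle=180, num_rho=100)
print(p1, p2)  # endpoints as (row, col); out-of-image parts are clipped later by draw_lines
```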
55 | def draw_lines(img, p1, p2, weight, width=5):
56 |     """
57 |     Draw lines into an image. The number of weights should be the same as the number of lines.
58 | 
59 |     Args:
60 |         img (torch.tensor): the input image, where to draw the lines, in the shape of (H, W).
61 |         p1 (torch.tensor): the start point (row, col) of each line, in the shape of (B, 2).
62 |         p2 (torch.tensor): the end point (row, col) of each line, in the shape of (B, 2).
63 |         weight (torch.tensor): the weight of each line, in the shape of (B,).
64 |         width (int): the width of the line. The default value of 5 works well for num_rho = 100;
65 |             a good value can be estimated as width = sqrt(H^2 + W^2) / num_rho.
66 |     Return:
67 |         (torch.tensor): the image with the lines drawn.
68 |     """
69 | 
70 |     H, W = img.shape
71 |     step = int(torch.sqrt(torch.tensor(H * H + W * W)) + 1)
72 |     # filter out all nan, line is not in the image
73 |     mask = torch.isnan(p1).any(dim=1) | torch.isnan(p2).any(dim=1)
74 |     p1 = p1[~mask]
75 |     p2 = p2[~mask]
76 |     weight = weight[~mask]
77 | 
78 |     dx = ((p2[:, 0] - p1[:, 0]) / step).unsqueeze(1)
79 |     dy = ((p2[:, 1] - p1[:, 1]) / step).unsqueeze(1)
80 |     new_x = torch.repeat_interleave(p1[:, 0], step).reshape(p1.shape[0], step)
81 |     new_y = torch.repeat_interleave(p1[:, 1], step).reshape(p1.shape[0], step)
82 |     weight = torch.repeat_interleave(weight, step).reshape(p1.shape[0], step)
83 |     i = torch.arange(step, device=dx.device).unsqueeze(0)
84 | 
85 |     new_x += dx * i
86 |     new_y += dy * i
87 |     new_x = new_x.long()
88 |     new_y = new_y.long()
89 | 
90 |     idx = torch.arange(p1.shape[0], device=img.device)
91 |     idx = torch.repeat_interleave(idx, step).reshape(p1.shape[0], step)
92 |     mask = (new_x >= 0) & (new_x < H) & (new_y >= 0) & (new_y < W)
93 |     new_x = new_x[mask]
94 |     new_y = new_y[mask]
95 |     idx = idx[mask]
96 |     weight = weight[mask]
97 | 
98 |     idx = torch.stack([idx, new_x, new_y], dim=0).long()
99 |     weight = weight.to(torch.float16)  # for saving space
100 | 
101 |     img_temp = torch.zeros((p1.shape[0], H, W), dtype=weight.dtype, device=img.device)
102 |     # print(img_temp.element_size() * img_temp.nelement() / 1024 / 1024)
103 |     bound = torch.tensor(H - 1)
104 |     for i in range(width):
105 |         # thicken the line along the row axis; not the best way to set the width, but it's ok
106 |         img_temp[idx[0], torch.min(bound, idx[1] + i), idx[2]] = weight
107 |     img += img_temp.sum(dim=0)
108 |     return img
109 | 
110 | 
111 | def reverse_max_hough_space(img, hough_space, num_angle, num_rho, width=5):
112 |     """
113 |     Reverse the line with the highest value in the hough space to the image.
114 | 
115 |     Args:
116 |         img (torch.tensor): the tensor image in the shape of (H, W), where the lines will be drawn.
117 |         hough_space (torch.tensor): the hough space in the shape of (1, num_angle, num_rho) or (num_angle, num_rho).
118 |         num_angle (int): the number of angles in the hough space.
119 |         num_rho (int): the number of rhos in the hough space.
120 |         width (int): the width of the line to be drawn.
121 | 
122 |     Returns:
123 |         (torch.tensor): an image with a line drawn.
124 |     """
125 |     hough_space = torch.squeeze(hough_space)
126 | 
127 |     # find the index of the max value of the hough space
128 |     max_loc = torch.argmax(hough_space)
129 |     theta = max_loc // num_rho
130 |     rho = max_loc % num_rho
131 | 
132 |     p1, p2 = theta_rho_to_xy(img.shape, theta, rho, num_angle, num_rho)
133 | 
134 |     img = draw_lines(img, p1, p2, torch.tensor([255], device=img.device), width=width)
135 | 
136 |     return img
137 | 
138 | 
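The argmax decode above relies on `torch.argmax` returning a flattened index, which divmod-style arithmetic splits back into `(theta, rho)` indices. A tiny self-contained check (the grid shape is illustrative):

```python
import torch

num_angle, num_rho = 180, 100
hough_space = torch.zeros(num_angle, num_rho)
hough_space[42, 7] = 1.0  # plant a single peak

max_loc = torch.argmax(hough_space)        # flat index into the (180, 100) grid
theta, rho = max_loc // num_rho, max_loc % num_rho
print(theta.item(), rho.item())            # -> 42 7
```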
139 | def reverse_all_hough_space(
140 |     img, hough_space, num_angle, num_rho, threshold=1e-3, width=5
141 | ):
142 |     """
143 |     Reverse the whole hough space (which may contain many lines) to the image.
144 | 
145 |     Args:
146 |         img (torch.tensor): the tensor image in the shape of (H, W), where the lines will be drawn.
147 |         hough_space (torch.tensor): the hough space in the shape of (1, num_angle, num_rho) or (num_angle, num_rho).
148 |         num_angle (int): the number of angles in the hough space.
149 |         num_rho (int): the number of rhos in the hough space.
150 |         threshold (float): the threshold to filter the hough space.
151 |         width (int): the width of the line to be drawn.
152 | 
153 |     Returns:
154 |         (torch.tensor): an image with lines drawn.
155 |     """
156 |     hough_space = torch.squeeze(hough_space)
157 |     hough_space[hough_space < threshold] = 0
158 | 
159 |     theta, rho = torch.nonzero(hough_space, as_tuple=True)
160 | 
161 |     if not theta.size(0):
162 |         return img
163 | 
164 |     value = hough_space[theta, rho]
165 | 
166 |     p1, p2 = theta_rho_to_xy(img.shape, theta, rho, num_angle, num_rho)
167 | 
168 |     img = draw_lines(img, p1, p2, value, width=width)
169 | 
170 |     img = (img - img.min()) / (img.max() - img.min())
171 | 
172 |     img = img * 255
173 | 
174 |     return img
175 | 
176 | 
177 | def vis_result(input_img, line, label=None):
178 |     input_img = input_img.squeeze().cpu().numpy()
179 |     input_img = (input_img + 1) * 127.5
180 |     input_img = input_img.astype(np.uint8)
181 |     line = line.cpu().numpy().astype(input_img.dtype)
182 | 
183 |     if label is not None:
184 |         label = label.cpu().numpy().astype(input_img.dtype)
185 |         res = cv2.addWeighted(input_img, 1, label * 255, 0.5, 0)
186 |         res = cv2.addWeighted(res, 1, line, 0.8, 0)
187 |     else:
188 |         res = cv2.addWeighted(input_img, 1, line, 0.8, 0)
189 | 
190 |     return res
191 | 
--------------------------------------------------------------------------------
/video.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marslicy/VibNet/8830a2f5fde163776e4abb92ba06c723d26795fa/video.jpg
--------------------------------------------------------------------------------