├── .gitignore
├── .gitmodules
├── Dockerfile
├── LICENSE
├── README.md
├── flow4D.py
├── lightning_logs
│   └── version_0
│       ├── hparams.yaml
│       └── metrics.csv
└── network_4D.py
/.gitignore:
--------------------------------------------------------------------------------
*.json
logs/*
*.pyc

tmp_sbatch.sh
*.log

*.hydra

*.h5

*.zip

cancel.sh

# slurm log files
*.err
*.out

# figure
*.png
*.pdf
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "OpenSceneFlow"]
	path = OpenSceneFlow
	url = https://github.com/KTH-RPL/OpenSceneFlow
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# check more: https://hub.docker.com/r/nvidia/cuda
FROM nvidia/cuda:11.7.1-devel-ubuntu20.04
ENV DEBIAN_FRONTEND=noninteractive

RUN apt update && apt install -y --no-install-recommends \
    git curl vim rsync htop

RUN curl -L -o ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    chmod +x ~/miniconda.sh && \
    ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda clean -ya && /opt/conda/bin/conda init bash

RUN curl -L -o ~/mamba.sh https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh && \
    chmod +x ~/mamba.sh && \
    ~/mamba.sh -b -p /opt/mambaforge && \
    rm ~/mamba.sh && /opt/mambaforge/bin/mamba init bash

# install zsh and oh-my-zsh
RUN apt install -y wget git zsh tmux vim g++
RUN sh -c "$(wget -O- https://github.com/deluan/zsh-in-docker/releases/download/v1.1.5/zsh-in-docker.sh)" -- \
    -t robbyrussell -p git \
    -p https://github.com/agkozak/zsh-z \
    -p https://github.com/zsh-users/zsh-autosuggestions \
    -p https://github.com/zsh-users/zsh-completions \
    -p https://github.com/zsh-users/zsh-syntax-highlighting

RUN printf "y\ny\ny\n\n" | bash -c "$(curl -fsSL https://raw.githubusercontent.com/Kin-Zhang/Kin-Zhang/main/scripts/setup_ohmyzsh.sh)"
RUN /opt/conda/bin/conda init zsh && /opt/mambaforge/bin/mamba init zsh

# change to conda env
ENV PATH=/opt/conda/bin:$PATH
ENV PATH=/opt/mambaforge/bin:$PATH

RUN mkdir -p /home/kin/workspace && cd /home/kin/workspace && git clone --recursive https://github.com/KTH-RPL/DeFlow.git
WORKDIR /home/kin/workspace/DeFlow
RUN apt-get update && apt-get install libgl1 -y
# needs to read the GPU device info to compile the CUDA extension
RUN cd /home/kin/workspace/DeFlow && /opt/mambaforge/bin/mamba env create -f environment.yaml
RUN cd /home/kin/workspace/DeFlow/mmcv && export MMCV_WITH_OPS=1 && export FORCE_CUDA=1 && /opt/mambaforge/envs/deflow/bin/pip install -e .
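
# Example build command (hypothetical tag name):
#   docker build -t deflow .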
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2024, Robotics, Perception and Learning @KTH

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Flow4D: Leveraging 4D Voxel Network for LiDAR Scene Flow Estimation (RA-L 2025)

This repository contains the code for the [Flow4D paper (RA-L 2025)](https://ieeexplore.ieee.org/document/10887254).


## Notice

**Flow4D has been integrated into [OpenSceneFlow](https://github.com/KTH-RPL/OpenSceneFlow).**
Please visit the [OpenSceneFlow](https://github.com/KTH-RPL/OpenSceneFlow) repository for the latest updates and developments.

This repository keeps the README and the core Flow4D files for quick reference.
The old source code branch is also [available here](https://github.com/dgist-cvlab/Flow4D/tree/source).

## Requirements

This code is based on DeFlow.
Please follow the installation instructions from the [DeFlow repository](https://github.com/KTH-RPL/DeFlow).

Additionally, you need to install `spconv 2.3.6`.
You can find the installation instructions here: [spconv](https://github.com/traveller59/spconv).
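For reference, with a CUDA 11.7 toolchain (as in the Dockerfile in this repository), installing the matching pre-built wheel would look roughly like this; adjust the `cu` suffix to your CUDA version and check the spconv README for the exact package names:

```bash
pip install spconv-cu117==2.3.6
```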

## Training

To train the model, use the following command:

```bash
python train.py model=flow4d lr=1e-3 epochs=15 batch_size=8 num_frames=5 loss_fn=deflowLoss "voxel_size=[0.2, 0.2, 0.2]" "point_cloud_range=[-51.2, -51.2, -3.2, 51.2, 51.2, 3.2]"
```
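
Note that the bracketed list overrides (`voxel_size`, `point_cloud_range`) are quoted so the shell passes each one to Hydra as a single argument.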


## Inference

To perform inference, use the following command:

```bash
python eval.py checkpoint=path_to_checkpoint av2_mode=val
```

Replace `path_to_checkpoint` with the actual path to your checkpoint file and set `av2_mode` to either `val` or `test`.
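For example, using the checkpoint path recorded in `lightning_logs/version_0/hparams.yaml`:

```bash
python eval.py checkpoint=logs/wandb/0709_223159/checkpoints/04__0.194_0.011_0.025_0.055_0.006_0.014_0.77.ckpt av2_mode=val
```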


## Gratitude
This code is based on the [DeFlow code](https://github.com/KTH-RPL/DeFlow) by Qingwen Zhang.
We extend our deepest gratitude to her.
Additionally, we would like to express our sincere thanks to Kyle Vedder et al. for hosting and providing extensive support for the [Argoverse 2 2024 Scene Flow Challenge](https://www.argoverse.org/sceneflow.html).

--------------------------------------------------------------------------------
/flow4D.py:
--------------------------------------------------------------------------------
"""
# Created: 2023-07-18 15:08
# Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology
# Author: Kin ZHANG (https://kin-zhang.github.io/)
#
# This work is licensed under the terms of the MIT license.
# For a copy, see <https://opensource.org/licenses/MIT>.
"""

import torch
import torch.nn as nn
import dztimer

from .basic.embedder_model_flow4D import DynamicEmbedder_4D
from .basic import cal_pose0to1

from .basic.network_4D import Network_4D, Seperate_to_3D, Point_head

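# Compatibility helper: spconv 2.x SparseConvTensor provides replace_feature(),
# while older spconv tensors only support assigning .features directly.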
def replace_feature(out, new_features):
    if "replace_feature" in out.__dir__():
        # spconv 2.x behaviour
        return out.replace_feature(new_features)
    else:
        out.features = new_features
        return out


class Flow4D(nn.Module):
    def __init__(self, voxel_size=[0.2, 0.2, 0.2],
                 point_cloud_range=[-51.2, -51.2, -2.2, 51.2, 51.2, 4.2],
                 grid_feature_size=[512, 512, 32],
                 num_frames=5):
        super().__init__()

        point_output_ch = 16
        voxel_output_ch = 16

        self.num_frames = num_frames
        print('voxel_size = {}, pseudo_dims = {}, input_num_frames = {}'.format(voxel_size, grid_feature_size, self.num_frames))

        self.embedder_4D = DynamicEmbedder_4D(voxel_size=voxel_size,
                                              pseudo_image_dims=[grid_feature_size[0], grid_feature_size[1], grid_feature_size[2], num_frames],
                                              point_cloud_range=point_cloud_range,
                                              feat_channels=point_output_ch)

        self.network_4D = Network_4D(in_channel=point_output_ch, out_channel=voxel_output_ch)

        self.seperate_feat = Seperate_to_3D(num_frames)

        self.pointhead_3D = Point_head(voxel_feat_dim=voxel_output_ch, point_feat_dim=point_output_ch)

        self.timer = dztimer.Timing()
        self.timer.start("Total")

    def load_from_checkpoint(self, ckpt_path):
        ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        state_dict = {
            k[len("model."):]: v for k, v in ckpt.items() if k.startswith("model.")
        }
        print("\nLoading... model weight from: ", ckpt_path, "\n")
        return self.load_state_dict(state_dict=state_dict, strict=False)
    def forward(self, batch):
        """
        input: using the batch from dataloader, which is a dict
               Detail: [pc0, pc1, pose0, pose1]
        output: the predicted flow, pose_flow, and the valid point index of pc0
        """

        self.timer[0].start("Data Preprocess")
        batch_sizes = len(batch["pose0"])

        pose_flows = []
        transform_pc0s = []
        transform_pc_m_frames = [[] for _ in range(self.num_frames - 2)]

        for batch_id in range(batch_sizes):
            selected_pc0 = batch["pc0"][batch_id]
            self.timer[0][0].start("pose")
            with torch.no_grad():
                if 'ego_motion' in batch:
                    pose_0to1 = batch['ego_motion'][batch_id]
                else:
                    pose_0to1 = cal_pose0to1(batch["pose0"][batch_id], batch["pose1"][batch_id])

                if self.num_frames > 2:
                    past_poses = []
                    for i in range(1, self.num_frames - 1):
                        past_pose = cal_pose0to1(batch[f"pose_m{i}"][batch_id], batch["pose1"][batch_id])
                        past_poses.append(past_pose)
            self.timer[0][0].stop()

            self.timer[0][1].start("transform")
            transform_pc0 = selected_pc0 @ pose_0to1[:3, :3].T + pose_0to1[:3, 3]  # t -> t+1 warping
            self.timer[0][1].stop()
            pose_flows.append(transform_pc0 - selected_pc0)
            transform_pc0s.append(transform_pc0)

            for i in range(1, self.num_frames - 1):
                selected_pc_m = batch[f"pc_m{i}"][batch_id]
                transform_pc_m = selected_pc_m @ past_poses[i-1][:3, :3].T + past_poses[i-1][:3, 3]
                transform_pc_m_frames[i-1].append(transform_pc_m)

        pc_m_frames = [torch.stack(transform_pc_m_frames[i], dim=0) for i in range(self.num_frames - 2)]

        pc0s = torch.stack(transform_pc0s, dim=0)
        pc1s = batch["pc1"]
        self.timer[0].stop()

        pcs_dict = {
            'pc0s': pc0s,
            'pc1s': pc1s,
        }
        for i in range(1, self.num_frames - 1):
            pcs_dict[f'pc_m{i}s'] = pc_m_frames[i-1]

        self.timer[1].start("4D_voxelization")
        dict_4d = self.embedder_4D(pcs_dict)
        pc01_tensor_4d = dict_4d['4d_tensor']
        pc0_3dvoxel_infos_lst = dict_4d['pc0_3dvoxel_infos_lst']
        pc0_point_feats_lst = dict_4d['pc0_point_feats_lst']
        pc0_num_voxels = dict_4d['pc0_mum_voxels']  # note: 'mum' is the key name emitted by the embedder
        self.timer[1].stop()

        self.timer[2].start("4D_backbone")
        pc_all_output_4d = self.network_4D(pc01_tensor_4d)  # 'all' = past, current, and next frames combined
        self.timer[2].stop()

        self.timer[3].start("4D pc01 to 3D pc0")
        pc0_last = self.seperate_feat(pc_all_output_4d)
        assert pc0_last.features.shape[0] == pc0_num_voxels, 'voxel number mismatch'
        self.timer[3].stop()

        self.timer[4].start("3D_sparsetensor_to_point and head")
        flows = self.pointhead_3D(pc0_last, pc0_3dvoxel_infos_lst, pc0_point_feats_lst)
        self.timer[4].stop()

        pc0_points_lst = [e["points"] for e in pc0_3dvoxel_infos_lst]
        pc0_valid_point_idxes = [e["point_idxes"] for e in pc0_3dvoxel_infos_lst]

        model_res = {
            "flow": flows,
            'pose_flow': pose_flows,
            "pc0_valid_point_idxes": pc0_valid_point_idxes,
            "pc0_points_lst": pc0_points_lst,
        }
        return model_res
--------------------------------------------------------------------------------
/lightning_logs/version_0/hparams.yaml:
--------------------------------------------------------------------------------
cfg:
  model:
    name: flow4D
    target:
      _target_: scripts.network.models.flow4D.Flow4D
      voxel_size:
        - 0.2
        - 0.2
        - 0.2
      point_cloud_range:
        - -51.2
        - -51.2
        - -2.2
        - 51.2
        - 51.2
        - 4.2
      num_frames: 5
      grid_feature_size:
        - 512
        - 512
        - 32
    val_monitor: val/Dynamic/Mean
  slurm_id: 0
  wandb_mode: offline
  dataset_path: /data/datasets/argoverse2/argoverse2/preprocess/sensor
  output: 0709_223159_04__0709_223427
  checkpoint: logs/wandb/0709_223159/checkpoints/04__0.194_0.011_0.025_0.055_0.006_0.014_0.77.ckpt
  av2_mode: val
  save_res: false
  submit_version: '2024'
  num_frames: 5
  loss_fn: deflowLoss
  gpus: 1
  seed: 42069
  eval: true
--------------------------------------------------------------------------------
/lightning_logs/version_0/metrics.csv:
--------------------------------------------------------------------------------
epoch,step,val/Dynamic/BACKGROUND,val/Dynamic/CAR,val/Dynamic/Mean,val/Dynamic/OTHER_VEHICLES,val/Dynamic/PEDESTRIAN,val/Dynamic/WHEELED_VRU,val/EPE_BS,val/EPE_FD,val/EPE_FS,val/IoU,val/Static/BACKGROUND,val/Static/CAR,val/Static/Mean,val/Static/OTHER_VEHICLES,val/Static/PEDESTRIAN,val/Static/WHEELED_VRU,val/Three-way
0,0,nan,0.12572486698627472,0.18557094037532806,0.2090398073196411,0.2297278642654419,0.17779125273227692,0.00477928202599287,0.07636797428131104,0.01351962797343731,0.7429938912391663,0.004683776292949915,0.012337284162640572,0.011209582909941673,0.017054351046681404,0.011305715888738632,0.010666786693036556,0.031555626541376114
--------------------------------------------------------------------------------
/network_4D.py:
--------------------------------------------------------------------------------
import numpy as np
import pdb, json
import torch
import torch.nn as nn
import torch.nn.functional as F
import spconv as spconv_core
import yaml

spconv_core.constants.SPCONV_ALLOW_TF32 = True

import spconv.pytorch as spconv
import time
from spconv.utils import Point2VoxelCPU3d as VoxelGenerator

tv = None
try:
    import cumm.tensorview as tv
except ImportError:
    pass

from torch.autograd import Function
from torch.autograd.function import once_differentiable
import torch.cuda.amp as amp

_TORCH_CUSTOM_FWD = amp.custom_fwd(cast_inputs=torch.float16)
_TORCH_CUSTOM_BWD = amp.custom_bwd

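# Factorized 4D submanifold convolutions: (3,3,3,1) kernels act only on the
# spatial dims (x, y, z), (1,1,1,3) kernels only on the temporal dim, and
# (1,1,1,1) kernels fuse the concatenated branch features channel-wise.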
def conv1x1x1x3(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv4d(in_planes, out_planes, kernel_size=(1,1,1,3), stride=stride,
                             padding=(0,0,0,1), bias=False, indice_key=indice_key)

def conv3x3x3x1(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv4d(in_planes, out_planes, kernel_size=(3,3,3,1), stride=stride,
                             padding=(1,1,1,0), bias=False, indice_key=indice_key)

def conv1x1x1x1(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv4d(in_planes, out_planes, kernel_size=(1,1,1,1), stride=stride,
                             padding=0, bias=False, indice_key=indice_key)

def conv3x3x3x3(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv4d(in_planes, out_planes, kernel_size=(3,3,3,3), stride=stride,
                             padding=(1,1,1,1), bias=False, indice_key=indice_key)

norm_cfg = {
    # format: layer_type: (abbreviation, module)
    "BN": ("bn", nn.BatchNorm2d),
    "BN1d": ("bn1d", nn.BatchNorm1d),
    "GN": ("gn", nn.GroupNorm),
}


class Seperate_to_3D(nn.Module):
    def __init__(self, num_frames):
        super(Seperate_to_3D, self).__init__()
        self.num_frames = num_frames
        #self.return_pc1 = return_pc1

    def forward(self, sparse_4D_tensor):
        indices_4d = sparse_4D_tensor.indices
        features_4d = sparse_4D_tensor.features

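        # With frames ordered along the time axis as [oldest past, ..., pc0, pc1],
        # pc0 (the current frame) presumably sits at time index num_frames - 2.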
        pc0_time_value = self.num_frames - 2

        mask_pc0 = (indices_4d[:, -1] == pc0_time_value)

        pc0_indices = indices_4d[mask_pc0][:, :-1]
        pc0_features = features_4d[mask_pc0]

        pc0_sparse_3D = sparse_4D_tensor.replace_feature(pc0_features)
        pc0_sparse_3D.spatial_shape = sparse_4D_tensor.spatial_shape[:-1]
        pc0_sparse_3D.indices = pc0_indices

        return pc0_sparse_3D


class SpatioTemporal_Decomposition_Block(nn.Module):
    def __init__(self, in_filters, mid_filters, out_filters, indice_key=None, down_key=None, pooling=False, z_pooling=True, interact=False):
        super(SpatioTemporal_Decomposition_Block, self).__init__()

        self.pooling = pooling

        self.act = nn.LeakyReLU()

        self.spatial_conv_1 = conv3x3x3x1(in_filters, mid_filters, indice_key=indice_key + "bef")
        self.bn_s_1 = nn.BatchNorm1d(mid_filters)

        self.temporal_conv_1 = conv1x1x1x3(in_filters, mid_filters)
        self.bn_t_1 = nn.BatchNorm1d(mid_filters)

        self.fusion_conv_1 = conv1x1x1x1(mid_filters*2 + in_filters, mid_filters, indice_key=indice_key + "1D")
        self.bn_fusion_1 = nn.BatchNorm1d(mid_filters)

        self.spatial_conv_2 = conv3x3x3x1(mid_filters, mid_filters, indice_key=indice_key + "bef")
        self.bn_s_2 = nn.BatchNorm1d(mid_filters)

        self.temporal_conv_2 = conv1x1x1x3(mid_filters, mid_filters)
        self.bn_t_2 = nn.BatchNorm1d(mid_filters)

        self.fusion_conv_2 = conv1x1x1x1(mid_filters*3, out_filters, indice_key=indice_key + "1D")
        self.bn_fusion_2 = nn.BatchNorm1d(out_filters)

        if self.pooling:
            if z_pooling:
                self.pool = spconv.SparseConv4d(out_filters, out_filters, kernel_size=(2,2,2,1), stride=(2,2,2,1), indice_key=down_key, bias=False)
            else:
                self.pool = spconv.SparseConv4d(out_filters, out_filters, kernel_size=(2,2,1,1), stride=(2,2,1,1), indice_key=down_key, bias=False)

        self.weight_initialization()

    def weight_initialization(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # ST block
        S_feat_1 = self.spatial_conv_1(x)
        S_feat_1 = S_feat_1.replace_feature(self.bn_s_1(S_feat_1.features))
        S_feat_1 = S_feat_1.replace_feature(self.act(S_feat_1.features))

        T_feat_1 = self.temporal_conv_1(x)
        T_feat_1 = T_feat_1.replace_feature(self.bn_t_1(T_feat_1.features))
        T_feat_1 = T_feat_1.replace_feature(self.act(T_feat_1.features))

        ST_feat_1 = x.replace_feature(torch.cat([S_feat_1.features, T_feat_1.features, x.features], 1))  # concatenate, including the residual

        ST_feat_1 = self.fusion_conv_1(ST_feat_1)
        ST_feat_1 = ST_feat_1.replace_feature(self.bn_fusion_1(ST_feat_1.features))
        ST_feat_1 = ST_feat_1.replace_feature(self.act(ST_feat_1.features))

        # TS block
        S_feat_2 = self.spatial_conv_2(ST_feat_1)
        S_feat_2 = S_feat_2.replace_feature(self.bn_s_2(S_feat_2.features))
        S_feat_2 = S_feat_2.replace_feature(self.act(S_feat_2.features))

        T_feat_2 = self.temporal_conv_2(ST_feat_1)
        T_feat_2 = T_feat_2.replace_feature(self.bn_t_2(T_feat_2.features))
        T_feat_2 = T_feat_2.replace_feature(self.act(T_feat_2.features))

        ST_feat_2 = x.replace_feature(torch.cat([S_feat_2.features, T_feat_2.features, ST_feat_1.features], 1))  # concatenate, including the residual

        ST_feat_2 = self.fusion_conv_2(ST_feat_2)
        ST_feat_2 = ST_feat_2.replace_feature(self.bn_fusion_2(ST_feat_2.features))
        ST_feat_2 = ST_feat_2.replace_feature(self.act(ST_feat_2.features))

        if self.pooling:
            pooled = self.pool(ST_feat_2)
            return pooled, ST_feat_2
        else:
            return ST_feat_2


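# 4D sparse U-Net: four pooled encoder stages with skip connections, mirrored
# by SparseInverseConv4d upsampling stages; the per-stage pseudo-image
# resolutions are noted in the comments below.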
class Network_4D(nn.Module):
    def __init__(self, in_channel=16, out_channel=16, model_size=16):
        super().__init__()

        SpatioTemporal_Block = SpatioTemporal_Decomposition_Block

        self.model_size = model_size

        self.STDB_1_1_1 = SpatioTemporal_Block(in_channel, model_size, model_size, indice_key="st1_1", down_key='floor1')
        self.STDB_1_1_2 = SpatioTemporal_Block(model_size, model_size, model_size*2, indice_key="st1_1", down_key='floor1', pooling=True)  # 512 512 32 -> 256 256 16

        self.STDB_2_1_1 = SpatioTemporal_Block(model_size*2, model_size*2, model_size*2, indice_key="st2_1", down_key='floor2')
        self.STDB_2_1_2 = SpatioTemporal_Block(model_size*2, model_size*2, model_size*4, indice_key="st2_1", down_key='floor2', pooling=True)  # 256 256 16 -> 128 128 8

        self.STDB_3_1_1 = SpatioTemporal_Block(model_size*4, model_size*4, model_size*4, indice_key="st3_1", down_key='floor3')
        self.STDB_3_1_2 = SpatioTemporal_Block(model_size*4, model_size*4, model_size*4, indice_key="st3_1", down_key='floor3', pooling=True)  # 128 128 8 -> 64 64 4

        self.STDB_4_1_1 = SpatioTemporal_Block(model_size*4, model_size*4, model_size*4, indice_key="st4_1", down_key='floor4')
        self.STDB_4_1_2 = SpatioTemporal_Block(model_size*4, model_size*4, model_size*4, indice_key="st4_1", down_key='floor4', pooling=True, z_pooling=False)  # 64 64 4 -> 32 32 4

        self.STDB_5_1_1 = SpatioTemporal_Block(model_size*4, model_size*4, model_size*4, indice_key="st5_1")
        self.STDB_5_1_2 = SpatioTemporal_Block(model_size*4, model_size*4, model_size*4, indice_key="st5_1")
        self.up_subm_5 = spconv.SparseInverseConv4d(model_size*4, model_size*4, kernel_size=(2,2,1,1), indice_key='floor4', bias=False)  # z_pooling=False

        self.STDB_4_2_1 = SpatioTemporal_Block(model_size*8, model_size*8, model_size*4, indice_key="st4_2")
        self.up_subm_4 = spconv.SparseInverseConv4d(model_size*4, model_size*4, kernel_size=(2,2,2,1), indice_key='floor3', bias=False)

        self.STDB_3_2_1 = SpatioTemporal_Block(model_size*8, model_size*8, model_size*4, indice_key="st3_2")
        self.up_subm_3 = spconv.SparseInverseConv4d(model_size*4, model_size*4, kernel_size=(2,2,2,1), indice_key='floor2', bias=False)

        self.STDB_2_2_1 = SpatioTemporal_Block(model_size*8, model_size*4, model_size*4, indice_key="st_2_2")
        self.up_subm_2 = spconv.SparseInverseConv4d(model_size*4, model_size*2, kernel_size=(2,2,2,1), indice_key='floor1', bias=False)

        self.STDB_1_2_1 = SpatioTemporal_Block(model_size*4, model_size*2, out_channel, indice_key="st_1_2")

    def forward(self, sp_tensor):
        sp_tensor = self.STDB_1_1_1(sp_tensor)
        down_2, skip_1 = self.STDB_1_1_2(sp_tensor)

        down_2 = self.STDB_2_1_1(down_2)
        down_3, skip_2 = self.STDB_2_1_2(down_2)

        down_3 = self.STDB_3_1_1(down_3)
        down_4, skip_3 = self.STDB_3_1_2(down_3)

        down_4 = self.STDB_4_1_1(down_4)
        down_5, skip_4 = self.STDB_4_1_2(down_4)

        down_5 = self.STDB_5_1_1(down_5)
        down_5 = self.STDB_5_1_2(down_5)

        up_4 = self.up_subm_5(down_5)
        up_4 = up_4.replace_feature(torch.cat((up_4.features, skip_4.features), 1))
        up_4 = self.STDB_4_2_1(up_4)

        up_3 = self.up_subm_4(up_4)
        up_3 = up_3.replace_feature(torch.cat((up_3.features, skip_3.features), 1))
        up_3 = self.STDB_3_2_1(up_3)

        up_2 = self.up_subm_3(up_3)
        up_2 = up_2.replace_feature(torch.cat((up_2.features, skip_2.features), 1))
        up_2 = self.STDB_2_2_1(up_2)

        up_1 = self.up_subm_2(up_2)
        up_1 = up_1.replace_feature(torch.cat((up_1.features, skip_1.features), 1))
        up_1 = self.STDB_1_2_1(up_1)

        return up_1



class Point_head(nn.Module):
    def __init__(self, voxel_feat_dim: int = 96, point_feat_dim: int = 32):
        super().__init__()

        self.input_dim = voxel_feat_dim + point_feat_dim

        self.PPmodel_flow = nn.Sequential(
            nn.Linear(self.input_dim, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 3)
        )

    def forward_single(self, voxel_feat, voxel_coords, point_feat):
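        # Index the dense grid (C, D0, D1, D2) with the voxel coordinates
        # (columns applied in reversed order to match the dense layout), then
        # transpose to get one (N, C) feature row per occupied voxel.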
        voxel_to_point_feat = voxel_feat[:, voxel_coords[:, 2], voxel_coords[:, 1], voxel_coords[:, 0]].T
        concated_point_feat = torch.cat([voxel_to_point_feat, point_feat], dim=-1)

        flow = self.PPmodel_flow(concated_point_feat)

        return flow

    def forward(self, sparse_tensor, voxelizer_infos, pc0_point_feats_lst):
        voxel_feats = sparse_tensor.dense()

        flow_outputs = []
        batch_idx = 0
        for voxelizer_info in voxelizer_infos:
            voxel_coords = voxelizer_info["voxel_coords"]
            point_feat = pc0_point_feats_lst[batch_idx]
            voxel_feat = voxel_feats[batch_idx, :]
            flow = self.forward_single(voxel_feat, voxel_coords, point_feat)
            batch_idx += 1
            flow_outputs.append(flow)

        return flow_outputs
--------------------------------------------------------------------------------