├── .gitignore
├── README-zh_CN.md
├── README.md
├── data.py
├── environment.yml
├── example
│   ├── cls.png
│   ├── disp.png
│   ├── model.png
│   ├── table_cls.png
│   └── table_disp.png
├── main.py
├── models
│   └── model.py
├── test.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# Python
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
dist/
build/
*.egg-info/

# Virtual environments
venv/
env/
.env/

# IDE
.idea/
.vscode/

# Jupyter Notebook
.ipynb_checkpoints

# Local development settings
.env
.env.local

# Model checkpoints and logs
*.pth
*.ckpt
logs/
checkpoints/

# Dataset
dataset/
data/

# Example files
example/

# System files
.DS_Store
Thumbs.db
--------------------------------------------------------------------------------

/README-zh_CN.md:
--------------------------------------------------------------------------------
# S3Net

[English](./README.md) | 简体中文

CVEO小组在IGARSS 2024学术研讨会上提交的论文"在卫星极线图像中使用创新的单分支语义立体网络(S3Net)进行立体匹配和语义分割"的开源代码

## 模型概述
### 框架
![model](./example/model.png)

### 实验结果
#### US3D测试集上的立体匹配结果
![disp](./example/table_disp.png)
![disp](./example/disp.png)

#### US3D测试集上的语义分割结果
![cls](./example/table_cls.png)
![cls](./example/cls.png)

## 使用说明
### 安装
```bash
git clone https://github.com/CVEO/S3Net.git
cd S3Net
conda env create -f environment.yml
conda activate s3net
```
### 数据集
本实验使用的数据集是[2019数据融合竞赛](https://ieee-dataport.org/open-access/data-fusion-contest-2019-dfc2019)中的US3D赛道2数据集。
### 预训练权重
[百度网盘](https://pan.baidu.com/s/1EHYTq4eBKVJXgeFTq8SYFQ?pwd=1111) : 1111

[谷歌云盘](https://drive.google.com/file/d/1QrbsIir5FmKkZ2xlNL57AQKeQ7-vMubh/view?usp=drive_link)

## 训练启动方法

### 1. 单节点单GPU训练
```bash
python main.py
```

### 2. 单节点多GPU训练
```bash
torchrun --nproc_per_node=N main.py
```

### 3. 多节点多GPU训练

#### 启动命令
在主节点上:
```bash
torchrun --nproc_per_node=4 --nnodes=N --node_rank=0 --master_addr=MASTER_IP --master_port=PORT main.py
```

在其他节点上:
```bash
torchrun --nproc_per_node=4 --nnodes=N --node_rank=R --master_addr=MASTER_IP --master_port=PORT main.py
```

## 推理启动方法

使用test.py进行模型推理:
```bash
python test.py
```
## 文件目录说明
```
S3Net
├── example
│   ├── cls.png
│   ├── disp.png
│   ├── model.png
│   ├── table_cls.png
│   └── table_disp.png
├── models
│   └── model.py
├── README-zh_CN.md
├── README.md
├── environment.yml
├── utils.py
├── train.py
├── test.py
├── main.py
└── data.py
```

## 最新工作
如果您对我们的最新工作感兴趣,欢迎查看我们的新项目 [TriGeoNet](https://github.com/CVEO/TriGeoNet)!

## 许可证
代码仅供非商业和研究目的使用。如需商业用途,请联系作者。

## 引用本工作
如果您觉得S3Net对您的研究有帮助,请考虑给个star ⭐ 并引用:
```
@inproceedings{yang2024s,
  title={S3Net: Innovating Stereo Matching and Semantic Segmentation with a Single-Branch Semantic Stereo Network in Satellite Epipolar Imagery},
  author={Yang, Qingyuan and Chen, Guanzhou and Tan, Xiaoliang and Wang, Tong and Wang, Jiaqi and Zhang, Xiaodong},
  booktitle={IGARSS 2024-2024 IEEE International Geoscience and Remote Sensing Symposium},
  pages={8737--8740},
  year={2024},
  organization={IEEE}
}
```

或引用旧版本S2Net:

```
@article{liao2023s,
  title={S2Net: A Multitask Learning Network for Semantic Stereo of Satellite Image Pairs},
  author={Liao, Puyun and Zhang, Xiaodong and Chen, Guanzhou and Wang, Tong and Li, Xianwei and Yang, Haobo and Zhou, Wenlin and He, Chanjuan and Wang, Qing},
  journal={IEEE Transactions on Geoscience and Remote Sensing},
  volume={62},
  pages={1--13},
  year={2023},
  publisher={IEEE}
}
```
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# S3Net

English | [简体中文](./README-zh_CN.md)

Open-source code of CVEO's recent work "S3Net: Innovating Stereo Matching and Semantic Segmentation with a Single-Branch Semantic Stereo Network in Satellite Epipolar Imagery", presented at the IGARSS 2024 Symposium.

## Model Overview
### Framework
![model](./example/model.png)

### Results
#### Results of Stereo Matching on the US3D Test Set
![disp](./example/table_disp.png)
![disp](./example/disp.png)

#### Results of Semantic Segmentation on the US3D Test Set
![cls](./example/table_cls.png)
![cls](./example/cls.png)

## Usage
### Installation
```bash
git clone https://github.com/CVEO/S3Net.git
cd S3Net
conda env create -f environment.yml
conda activate s3net
```
### Datasets
The dataset used in our experiments is the Track-2 dataset of US3D from the [2019 Data Fusion Contest](https://ieee-dataport.org/open-access/data-fusion-contest-2019-dfc2019).
### Pretrained Weights
[Baidu Disk](https://pan.baidu.com/s/1EHYTq4eBKVJXgeFTq8SYFQ?pwd=1111) : 1111

[Google Drive](https://drive.google.com/file/d/1QrbsIir5FmKkZ2xlNL57AQKeQ7-vMubh/view?usp=drive_link)

## Training Launch Methods

### 1. Single-Node Single-GPU Training
```bash
python main.py
```

### 2. Single-Node Multi-GPU Training
```bash
torchrun --nproc_per_node=N main.py
```

### 3. Multi-Node Multi-GPU Training

#### Launch Commands
On the master node:
```bash
torchrun --nproc_per_node=4 --nnodes=N --node_rank=0 --master_addr=MASTER_IP --master_port=PORT main.py
```

On other nodes:
```bash
torchrun --nproc_per_node=4 --nnodes=N --node_rank=R --master_addr=MASTER_IP --master_port=PORT main.py
```
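As a concrete example, a hypothetical two-node, eight-GPU run (four GPUs per node; the address and port are placeholders, replace them with your own rendezvous settings). Every node runs the same command and only `--node_rank` changes:

```bash
# Node 0 (master, assumed reachable at 192.168.1.10)
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=192.168.1.10 --master_port=29500 main.py

# Node 1
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=192.168.1.10 --master_port=29500 main.py
```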
## Inference Launch Methods

Use test.py for model inference:
```bash
python test.py
```
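For programmatic inference, here is a minimal sketch mirroring test.py (the checkpoint path is the parser default from utils.py; the random tensors are stand-ins for an ImageNet-normalized epipolar pair prepared as in test.py):

```python
import torch
from models.model import SSNet

# Checkpoints are saved by train.py as {'state_dict': ...}
model = SSNet(maxdisp=48, mindisp=-48, num_classes=6)
model.load_state_dict(torch.load('ckpt.tar')['state_dict'])
model.eval().cuda()

# Stand-ins for a normalized (1, 3, H, W) left/right pair
left = torch.randn(1, 3, 512, 512).cuda()
right = torch.randn(1, 3, 512, 512).cuda()

with torch.no_grad():
    _, _, pred_disp, pred_cls = model(left, right)
```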
## File Directory Description
```
S3Net
├── example
│   ├── cls.png
│   ├── disp.png
│   ├── model.png
│   ├── table_cls.png
│   └── table_disp.png
├── models
│   └── model.py
├── README-zh_CN.md
├── README.md
├── environment.yml
├── utils.py
├── train.py
├── test.py
├── main.py
└── data.py
```

## Latest Work
If you are interested in our latest work, please check out our new project [TriGeoNet](https://github.com/CVEO/TriGeoNet)!

## License
Code is released for non-commercial and research purposes only. For commercial purposes, please contact the authors.

## Cite this work
If you find S3Net useful in your research, please consider giving a star ⭐ and citing:
```
@inproceedings{yang2024s,
  title={S3Net: Innovating Stereo Matching and Semantic Segmentation with a Single-Branch Semantic Stereo Network in Satellite Epipolar Imagery},
  author={Yang, Qingyuan and Chen, Guanzhou and Tan, Xiaoliang and Wang, Tong and Wang, Jiaqi and Zhang, Xiaodong},
  booktitle={IGARSS 2024-2024 IEEE International Geoscience and Remote Sensing Symposium},
  pages={8737--8740},
  year={2024},
  organization={IEEE}
}
```

or cite the old version S2Net:

```
@article{liao2023s,
  title={S2Net: A Multitask Learning Network for Semantic Stereo of Satellite Image Pairs},
  author={Liao, Puyun and Zhang, Xiaodong and Chen, Guanzhou and Wang, Tong and Li, Xianwei and Yang, Haobo and Zhou, Wenlin and He, Chanjuan and Wang, Qing},
  journal={IEEE Transactions on Geoscience and Remote Sensing},
  volume={62},
  pages={1--13},
  year={2023},
  publisher={IEEE}
}
```
--------------------------------------------------------------------------------

/data.py:
--------------------------------------------------------------------------------
import os
import cv2
import torch
import random
import torchvision.transforms as transforms
import torch.distributed as dist
import numpy as np
from torch.utils.data import Dataset

class DataLoader:
    @staticmethod
    def load(datapath: str, dataset: str) -> tuple:
        def _get_image_paths(subdir):
            base_path = os.path.join(datapath, subdir)
            return [os.path.join(base_path, img) for img in os.listdir(base_path)]

        left_train = _get_image_paths("left")
        right_train = [path.replace("LEFT_RGB", "RIGHT_RGB").replace("left", "right") for path in left_train]
        disp_train = [path.replace("LEFT_RGB", "LEFT_DSP").replace("left", "disp") for path in left_train]
        cls_train = [path.replace("LEFT_RGB", "LEFT_CLS").replace("left", "cls") for path in left_train]

        left_valid = _get_image_paths("valid_left")
        right_valid = [path.replace("LEFT_RGB", "RIGHT_RGB").replace("left", "right") for path in left_valid]
        disp_valid = [path.replace("LEFT_RGB", "LEFT_DSP").replace("left", "disp") for path in left_valid]
        cls_valid = [path.replace("LEFT_RGB", "LEFT_CLS").replace("left", "cls") for path in left_valid]

        train_data = (left_train, right_train, disp_train, cls_train)
        valid_data = (left_valid, right_valid, disp_valid, cls_valid)

        return train_data, valid_data

class StereoDataset(Dataset):
    def __init__(self, left_images, right_images, disp_images, cls_images, training=True):
        self.left = left_images
        self.right = right_images
        self.disp = disp_images
        self.cls = cls_images
        self.training = training

    def __len__(self):
        return len(self.left)

    def get_transform(self, data):
        normal_mean_var = {'mean': [0.485, 0.456, 0.406],
                           'std': [0.229, 0.224, 0.225]}
        data = torch.from_numpy(data).float()
        transform = transforms.Compose([transforms.Normalize(**normal_mean_var)])
        return transform(data).float()

    def _augment(self, left, right, disp, cls):

        # Vertical flip: epipolar lines stay horizontal, so the disparity is unchanged
        if random.random() > 0.5:
            left = np.flip(left, axis=1).copy()
            right = np.flip(right, axis=1).copy()
            disp = np.flip(disp, axis=0).copy()
            cls = np.flip(cls, axis=0).copy()

        # Horizontal flip: mirroring both views reverses the correspondence offsets,
        # so the disparity changes sign
        if random.random() > 0.5:
            left = np.flip(left, axis=2).copy()
            right = np.flip(right, axis=2).copy()
            disp = -np.flip(disp, axis=1).copy()
            cls = np.flip(cls, axis=1).copy()

        # Random 512x512 crop shared by all four maps
        _, h, w = left.shape
        x = random.randint(0, w - 512)
        y = random.randint(0, h - 512)
        left = left[:, y:y+512, x:x+512].copy()
        right = right[:, y:y+512, x:x+512].copy()
        cls = cls[y:y+512, x:x+512].copy()
        disp = disp[y:y+512, x:x+512].copy()

        return left, right, disp, cls

    def __getitem__(self, index):
        left = self._read_image(self.left[index])
        right = self._read_image(self.right[index])
        disp = self._read_image(self.disp[index], is_disp_cls=True)
        cls = self._read_image(self.cls[index], is_disp_cls=True)

        if self.training:
            left, right, disp, cls = self._augment(left, right, disp, cls)

        left = self.get_transform(left)
        right = self.get_transform(right)

        return left, right, disp, cls

    def _read_image(self, path, is_disp_cls=False):
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype('float32')

        # RGB images: move channels first and scale to [0, 1]
        if len(img.shape) == 3:
            return np.moveaxis(img, -1, 0) / 255.0

        # Disparity / class maps are kept as single-channel 2D arrays
        if is_disp_cls:
            return img

        raise ValueError(f'Unexpected image format for {path}')


def generate(dataset, datapath):

    train_data, valid_data = DataLoader.load(datapath, dataset)

    train_dataset = StereoDataset(
        left_images=train_data[0],
        right_images=train_data[1],
        disp_images=train_data[2],
        cls_images=train_data[3],
        training=True
    )

    valid_dataset = StereoDataset(
        left_images=valid_data[0],
        right_images=valid_data[1],
        disp_images=valid_data[2],
        cls_images=valid_data[3],
        training=False
    )

    return train_dataset, valid_dataset


def initialize_dataloaders(args, train_dataset, valid_dataset):
    if args.is_distributed:
        train_sampler = torch.utils.data.DistributedSampler(train_dataset, num_replicas=dist.get_world_size(),
                                                            rank=dist.get_rank())
        valid_sampler = torch.utils.data.DistributedSampler(valid_dataset, num_replicas=dist.get_world_size(),
                                                            rank=dist.get_rank())
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, num_workers=args.num_workers,
            sampler=train_sampler, pin_memory=True)

        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=args.batch_size, num_workers=args.num_workers,
            sampler=valid_sampler, pin_memory=True)
    else:
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                                   shuffle=True, num_workers=args.num_workers, drop_last=False)
        valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size,
                                                   shuffle=False, num_workers=args.num_workers, drop_last=False)
    return train_loader, valid_loader
--------------------------------------------------------------------------------

/environment.yml:
--------------------------------------------------------------------------------
name: s3net
channels:
  - conda-forge
dependencies:
  - python=3.9.18
  - opencv=4.8.1
  - pip
  - pip:
    # +cu102 wheels are hosted on the PyTorch index, not on PyPI
    - -f https://download.pytorch.org/whl/torch_stable.html
    - torch==1.8.1+cu102
    - torchvision==0.9.1+cu102
    - torchaudio==0.8.1
    - tqdm
    - numpy
    - torchmetrics
--------------------------------------------------------------------------------

/example/cls.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVEO/S3Net/aaab6979f66ef132a1117afa3a7a8921e21ab1cf/example/cls.png
--------------------------------------------------------------------------------

/example/disp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVEO/S3Net/aaab6979f66ef132a1117afa3a7a8921e21ab1cf/example/disp.png
--------------------------------------------------------------------------------

/example/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVEO/S3Net/aaab6979f66ef132a1117afa3a7a8921e21ab1cf/example/model.png
--------------------------------------------------------------------------------

/example/table_cls.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVEO/S3Net/aaab6979f66ef132a1117afa3a7a8921e21ab1cf/example/table_cls.png
--------------------------------------------------------------------------------

/example/table_disp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVEO/S3Net/aaab6979f66ef132a1117afa3a7a8921e21ab1cf/example/table_disp.png
--------------------------------------------------------------------------------

/main.py:
--------------------------------------------------------------------------------
import os
import csv
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.optim as optim
from tqdm import tqdm
import torch.multiprocessing as mp

from utils import (
    parser,
    initialize_model,
    setup,
)
from data import generate, initialize_dataloaders
from train import train_one_epoch, validate_one_epoch, save_checkpoint, log_results

def train():
    args = parser()
    setup(args)

    args.save_ckpt_path = args.save_ckpt_path or f'{args.model}_{args.dataset}'
    args.save_csv_file_path = args.save_csv_file_path or f'{args.model}_{args.dataset}.csv'
    os.makedirs(args.save_ckpt_path, exist_ok=True)

    # Model initialization (setup() has already configured the distributed state)
    model = initialize_model(args)
    if args.rank == 0:
        print(args)

    # Dataset loading
    train_dataset, valid_dataset = generate(args.dataset, args.datapath)
    train_loader, valid_loader = initialize_dataloaders(args, train_dataset, valid_dataset)

    # Log file initialization: only the rank-0 process writes the CSV header
    fieldnames = ['epoch', 'train_loss', 'valid_loss']
    if args.rank == 0 and not os.path.exists(args.save_csv_file_path):
        with open(args.save_csv_file_path, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()

    # Loss function and optimizer configuration
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if args.rank == 0:
        print(f"Number of trainable parameters: {sum(p.numel() for p in trainable_params) / 1e6:.2f}M")

    optimizer = optim.Adam([
        {'params': trainable_params, 'name': 'model', 'lr': args.lr},
    ], betas=(0.9, 0.999), weight_decay=1e-7)

    # Training loop
    for epoch in range(args.epochs):
        # Training and validation
        train_loss = train_one_epoch(epoch, model, optimizer, train_loader, args.local_rank, args)
        save_checkpoint(model, epoch, args, args.local_rank)
        valid_loss = validate_one_epoch(epoch, model, valid_loader, args.local_rank, args)

        if args.rank == 0:
            log_results(epoch, train_loss, valid_loss, args, fieldnames)

if __name__ == '__main__':
    train()
--------------------------------------------------------------------------------

/models/model.py:
--------------------------------------------------------------------------------
import torch
import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

def convbn_3d(in_planes, out_planes, kernel_size, stride, pad):
    return nn.Sequential(
        nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride, bias=False),
        nn.BatchNorm3d(out_planes)
    )


class disparityregression(nn.Module):
    def __init__(self, maxdisp):
        super(disparityregression, self).__init__()
        # Disparity candidates centered on zero: [-maxdisp/2, maxdisp/2)
        self.disp = torch.Tensor(np.reshape(np.array(range(maxdisp)), [1, maxdisp, 1, 1])).cuda() - maxdisp/2

    def forward(self, x):
        # Soft-argmax: expectation of the disparity probability volume
        out = torch.sum(x*self.disp.data, 1, keepdim=False)
        return out


class CAModule(nn.Module):
    def __init__(self, channel, ratio=16):
        super(CAModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)

        self.shared = nn.Sequential(
            nn.Conv3d(channel, channel // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv3d(channel // ratio, channel, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.shared(self.avg_pool(x))
        maxout = self.shared(self.max_pool(x))
        return self.sigmoid(avgout + maxout)


class SAModule(nn.Module):
    def __init__(self):
        super(SAModule, self).__init__()
        self.conv3d = nn.Conv3d(in_channels=2, out_channels=1, kernel_size=3, stride=1, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv3d(out))
        return out


class mmcs(nn.Module):
    # Channel attention followed by spatial attention over the 3D cost volume
    def __init__(self, channel):
        super(mmcs, self).__init__()
        self.ca = CAModule(channel)
        self.sa = SAModule()

    def forward(self, x):
        out = self.ca(x) * x
        out = self.sa(out) * out
        return out

class Block(nn.Module):
    def __init__(self, in_places, places, stride=1, downsampling=False, expansion=4):
        super(Block, self).__init__()
        self.expansion = expansion
        self.downsampling = downsampling

        self.bottleneck = nn.Sequential(
            nn.Conv3d(in_channels=in_places, out_channels=places, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm3d(places),
            nn.ReLU(),
            nn.Conv3d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm3d(places),
            nn.ReLU(),
            nn.Conv3d(in_channels=places, out_channels=places*self.expansion, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm3d(places*self.expansion),
        )
        self.mmcs = mmcs(channel=places*self.expansion)

        if self.downsampling:
            self.downsample = nn.Sequential(
                nn.Conv3d(in_channels=in_places, out_channels=places*self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(places*self.expansion)
            )
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.bottleneck(x)
        out = self.mmcs(out)

        if self.downsampling:
            residual = self.downsample(x)

        out = out + residual
        out = self.relu(out)
        return out


class Mul_fuse_3D(nn.Module):
    def __init__(self, input_size=128, hidden_size=256):
        super(Mul_fuse_3D, self).__init__()
        self.conv1 = nn.Conv3d(input_size, hidden_size, kernel_size=3, padding=1, stride=2)
        self.conv2 = nn.Conv3d(hidden_size, hidden_size, kernel_size=3, padding=1, stride=1)
        self.conv3 = nn.Conv3d(hidden_size, hidden_size, kernel_size=3, padding=1, stride=2)
        self.conv4 = nn.Conv3d(hidden_size, hidden_size, kernel_size=3, padding=1, stride=1)
        self.conv5 = nn.Sequential(nn.ConvTranspose3d(hidden_size, hidden_size, kernel_size=3, padding=1, output_padding=1, stride=2),
                                   nn.BatchNorm3d(hidden_size))
        self.conv6 = nn.Sequential(nn.ConvTranspose3d(hidden_size, input_size, kernel_size=3, padding=1, output_padding=1, stride=2),
                                   nn.BatchNorm3d(input_size))
        self.fuse_3D = fuse_3D(64, 64, 128)
        self.fuse_3d1 = fuse_3D(128, 128, 64)

        self.conv7 = nn.Conv3d(128, 128, 3, 1, 1)
        self.conv8 = nn.Conv3d(64, 64, 3, 1, 1)
        self.pool = nn.MaxPool3d(3, 1, 1)
        self.bn = nn.BatchNorm3d(128)

        self.bn1 = nn.BatchNorm3d(64)
        self.pool1 = nn.MaxPool3d(3, 1, 1)

    def forward(self, x, presqu, postsqu):
        x = self.fuse_3D(x)  # 128 -> 256
        x = self.bn(x)
        x = torch.relu(self.conv7(x))
        x = self.pool(x)

        cls = x[:, :, :1]   # B C 1 H W -> B C H W -> conv2d -> B C 1 H W
        disp = x[:, :, 1:]  # 5D B C D H W

        pre = cls * disp
        disp = torch.relu(self.conv1(pre))  # 256 -> 512
        pre = self.conv2(disp)

        if postsqu is not None:
            pre = F.relu(pre + postsqu)
        else:
            pre = F.relu(pre)

        out = torch.relu(self.conv3(pre))
        out = torch.relu(self.conv4(out))

        if presqu is not None:
            post = F.relu(self.conv5(out) + presqu)  # 512 -> 512
        else:
            post = F.relu(self.conv5(out) + pre)

        out = self.conv6(post)  # 512 -> 256

        out = torch.cat((cls, out), dim=2)

        out = self.fuse_3d1(out)
        out = self.bn1(out)
        out = torch.relu(self.conv8(out))
        out = self.pool1(out)
        return out, pre, post


class fuse_3D(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(fuse_3D, self).__init__()
        self.conv1 = nn.Conv3d(input_size, hidden_size, kernel_size=3, padding=1, stride=1)
        self.fuse1 = nn.Conv3d(input_size, hidden_size, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv3d(hidden_size, output_size, kernel_size=3, padding=1, stride=1)
        self.fuse2 = nn.Conv3d(hidden_size, output_size, kernel_size=3, padding=1, stride=1)

    def forward(self, x):
        hidden1 = torch.relu(self.conv1(x))
        fuse1 = torch.sigmoid(self.fuse1(x))
        fuse_hidden1 = hidden1 * fuse1

        output = torch.relu(self.conv2(fuse_hidden1))
        fuse2 = torch.sigmoid(self.fuse2(fuse_hidden1))
        final_output = output * fuse2

        return final_output



class fuse_2D(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(fuse_2D, self).__init__()
        self.conv1 = nn.Conv2d(input_size, hidden_size, kernel_size=3, padding=1, stride=1)
        self.fuse1 = nn.Conv2d(input_size, hidden_size, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv2d(hidden_size, output_size, kernel_size=3, padding=1, stride=1)
        self.fuse2 = nn.Conv2d(hidden_size, output_size, kernel_size=3, padding=1, stride=1)

    def forward(self, x):
        hidden1 = torch.relu(self.conv1(x))
        fuse1 = torch.sigmoid(self.fuse1(x))
        fuse_hidden1 = hidden1 * fuse1

        output = torch.relu(self.conv2(fuse_hidden1))
        fuse2 = torch.sigmoid(self.fuse2(fuse_hidden1))
        final_output = output * fuse2

        return final_output



def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation):
    return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False),
                         nn.BatchNorm2d(out_planes))


class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride, downsample, pad, dilation):
        super(BasicBlock, self).__init__()

        self.conv1 = nn.Sequential(convbn(inplanes, planes, 3, stride, pad, dilation),
                                   nn.ReLU())

        self.conv2 = convbn(planes, planes, 3, 1, pad, dilation)

        self.downsample = downsample
        self.stride = stride

        self.weight = nn.Parameter(torch.Tensor([1]))

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)

        if self.downsample is not None:
            x = self.downsample(x)

        out = out+self.weight * out + x

        return out

class feature_extraction(nn.Module):
    def __init__(self):
        super(feature_extraction, self).__init__()
        self.inplanes = 32
        self.firstconv = nn.Sequential(convbn(3, 32, 3, 2, 1, 1),
                                       nn.ReLU(),
                                       convbn(32, 32, 3, 1, 1, 1),
                                       nn.ReLU(),
                                       convbn(32, 32, 3, 1, 1, 1),
                                       nn.BatchNorm2d(32),
                                       nn.ReLU(),
                                       nn.MaxPool2d(3, 1, 1))
        self.conv2 = nn.Sequential(nn.ConvTranspose2d(32, 64, 3, 2, 1, 1, bias=False),
                                   nn.ReLU(),
                                   nn.ConvTranspose2d(64, 64, 1, 1, 0, bias=False))

        self.conv3 = nn.Sequential(nn.ConvTranspose2d(64, 128, 3, 2, 1, 1, bias=False),
                                   nn.ReLU(),
                                   nn.ConvTranspose2d(128, 128, 1, 1, 0, bias=False))

        self.layer1 = self._make_layer(BasicBlock, 32, 3, 1, 1, 1)
        self.layer2 = self._make_layer(BasicBlock, 64, 4, 2, 1, 1)
        self.layer3 = self._make_layer(BasicBlock, 128, 6, 1, 1, 1)
        self.layer4 = self._make_layer(BasicBlock, 256, 3, 1, 1, 2)


        self.fuse1 = fuse_2D(64, 512, 512)
        self.fuse2 = fuse_2D(128, 512, 512)
        self.fuse3 = fuse_2D(256, 512, 512)

        self.lastconv = nn.Sequential(nn.Conv2d(512, 128, 3, 1, 1, 1),
                                      nn.BatchNorm2d(128),
                                      nn.ReLU(),
                                      nn.Conv2d(128, 32, kernel_size=1, padding=0, stride=1))

    def forward(self, x):
        conv_x = self.firstconv(x)
        layer1 = self.layer1(conv_x)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        output = self.fuse1(layer2) + self.fuse2(layer3) + self.fuse3(layer4)

        output = self.lastconv(output)

        return output

    def _make_layer(self, block, planes, blocks, stride, pad, dilation):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion))

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, 1, None, pad, dilation))

        return nn.Sequential(*layers)


class cls_extraction(nn.Module):
    def __init__(self):
        super(cls_extraction, self).__init__()

        self.conv = nn.Sequential(nn.Conv2d(3, 32, 3, 2, 1),
                                  nn.ReLU(),
                                  nn.Conv2d(32, 32, 3, 1, 1),
                                  nn.ReLU(),
                                  nn.Conv2d(32, 32, 3, 2, 1),
                                  nn.BatchNorm2d(32),
                                  nn.ReLU(),
                                  nn.MaxPool2d(3, 1, 1),)

        self.fuse = fuse_2D(32, 64, 32)


    def forward(self, x):

        out = self.conv(x)
        out = self.fuse(out)

        return out


class SSNet(nn.Module):
    def __init__(self, maxdisp=48, mindisp=-48, num_classes=6):
        super(SSNet, self).__init__()
        self.maxdisp = maxdisp
        self.num_classes = num_classes

        self.feature_extraction = feature_extraction()
        self.cls_extraction = cls_extraction()


        self.fuse = fuse_3D(64, 64, 64)
        self.bn = nn.BatchNorm3d(64)
        self.Mul_fuse_3D1 = Mul_fuse_3D()
        self.Mul_fuse_3D2 = Mul_fuse_3D()
        self.Mul_fuse_3D3 = Mul_fuse_3D()

        self.block1 = Block(in_places=64, places=16)
        self.block2 = Block(in_places=64, places=16)
        self.block3 = Block(in_places=64, places=16)


        self.classif1 = nn.Sequential(convbn_3d(64, 64, 3, 1, 1),
                                      nn.ReLU())

        self.classif1_1 = nn.Sequential(convbn_3d(64, 32, 3, 1, 1),
                                        nn.ReLU())

        self.classif1_last = nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)

        self.classif2 = nn.Sequential(convbn_3d(64, 64, 3, 1, 1),
                                      nn.ReLU())
        self.classif2_1 = nn.Sequential(convbn_3d(64, 32, 3, 1, 1),
                                        nn.ReLU())

        self.classif2_last = nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)

        self.classif3 = nn.Sequential(convbn_3d(64, 64, 3, 1, 1),
                                      nn.ReLU())
        self.classif3_1 = nn.Sequential(convbn_3d(64, 32, 3, 1, 1),
                                        nn.ReLU())

        self.classif3_last = nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)

        self.fuse1 = fuse_2D(256, 128, 128)
        self.fuse2 = fuse_2D(128, 64, 64)
        self.fuse3 = fuse_2D(64, 32, 32)
        self.weight1 = nn.Parameter(torch.Tensor([1]))
        self.weight2 = nn.Parameter(torch.Tensor([1]))
        self.weight3 = nn.Parameter(torch.Tensor([1]))

        self.conv1 = nn.Conv2d(32, self.num_classes, 3, 1, 1)

        self.conv2 = nn.Sequential(nn.ConvTranspose2d(256, 128, 3, 2, 1, 1, bias=False),
                                   nn.ReLU(),
                                   nn.ConvTranspose2d(128, 128, 1, 1, 0, bias=False))

        self.conv3 = nn.Sequential(nn.ConvTranspose2d(128, 64, 3, 2, 1, 1, bias=False),
                                   nn.ReLU(),
                                   nn.ConvTranspose2d(64, 64, 1, 1, 0, bias=False))

        self.firstconv = nn.Sequential(convbn(32, 64, 3, 1, 1, 1),
                                       nn.ReLU(),
                                       convbn(64, 128, 3, 1, 1, 1),
                                       nn.ReLU(),
                                       nn.MaxPool2d(3, 1, 1),
                                       convbn(128, 256, 3, 1, 1, 1),
                                       nn.ReLU(),
                                       convbn(256, 256, 3, 1, 1, 1),
                                       nn.BatchNorm2d(256),
                                       nn.ReLU(),
                                       nn.MaxPool2d(3, 1, 1))

    def forward(self, left, right):
        left_feature = self.feature_extraction(left)
        right_feature = self.feature_extraction(right)

        lef_cls_first = self.cls_extraction(left).unsqueeze(2)
        rig_cls_first = self.cls_extraction(right).unsqueeze(2)

        cls_first = torch.cat((lef_cls_first, rig_cls_first), dim=1)

        # Cost volume over signed disparities at 1/4 feature resolution:
        # features are shifted in both directions to cover negative and positive offsets
        cost = torch.zeros(left_feature.size()[0], left_feature.size()[1]*2, self.maxdisp//4, left_feature.size()[2], left_feature.size()[3]).cuda()

        for i in range(-self.maxdisp//8, self.maxdisp//8):
            if i > 0:
                cost[:, :left_feature.size()[1], i + self.maxdisp//8, :, i:] = left_feature[:, :, :, i:]
                cost[:, left_feature.size()[1]:, i + self.maxdisp//8, :, i:] = right_feature[:, :, :, :-i]
            elif i == 0:
                cost[:, :left_feature.size()[1], self.maxdisp//8, :, :] = left_feature
                cost[:, left_feature.size()[1]:, self.maxdisp//8, :, :] = right_feature
            else:
                cost[:, :left_feature.size()[1], i + self.maxdisp//8, :, :i] = left_feature[:, :, :, :i]
                cost[:, left_feature.size()[1]:, i + self.maxdisp//8, :, :i] = right_feature[:, :, :, -i:]

        # The semantic slice is prepended to the disparity slices along the depth axis
        cost = torch.cat((cls_first, cost), dim=2)
        cost = cost.contiguous()


        cost0 = self.fuse(cost)
        cost0 = self.bn(cost0)

        out1, pre1, post1 = self.Mul_fuse_3D1(cost0, None, None)
        out1 = out1 + cost0
        out1 = self.block1(out1)

        out2, pre2, post2 = self.Mul_fuse_3D2(out1, pre1, post1)
        out2 = out2 + cost0
        out2 = self.block2(out2)

        out3, pre3, post3 = self.Mul_fuse_3D3(out2, pre1, post2)
        out3 = out3 + cost0
        out3 = self.block3(out3)


        cost1 = self.classif1(out1)
        cost2 = self.classif2(out2) + cost1
        cost3 = self.classif3(out3) + cost2

        cost1 = self.classif1_1(cost1)
        cost2 = self.classif2_1(cost2) + cost1
        cost3 = self.classif3_1(cost3) + cost2

        cls1 = cost1[:, :, :1]
        cls2 = cost2[:, :, :1]
        cls3 = cost3[:, :, :1]
        cost1 = cost1[:, :, 1:]
        cost2 = cost2[:, :, 1:]
        cost3 = cost3[:, :, 1:]


        cost1 = self.classif1_last(cost1)
        cost2 = self.classif2_last(cost2) + cost1
        cost3 = self.classif3_last(cost3) + cost2

        cost1 = F.interpolate(cost1, [self.maxdisp, left.size()[2], left.size()[3]], align_corners=True, mode='trilinear')
        cost2 = F.interpolate(cost2, [self.maxdisp, left.size()[2], left.size()[3]], align_corners=True, mode='trilinear')

        cost1 = torch.squeeze(cost1, 1)
        pred1 = F.softmax(cost1, dim=1)
        pred1 = disparityregression(self.maxdisp)(pred1)

        cost2 = torch.squeeze(cost2, 1)
        pred2 = F.softmax(cost2, dim=1)
        pred2 = disparityregression(self.maxdisp)(pred2)

        cost3 = F.interpolate(cost3, [self.maxdisp, left.size()[2], left.size()[3]], align_corners=True, mode='trilinear')
        cost3 = torch.squeeze(cost3, 1)
        pred3 = F.softmax(cost3, dim=1)
        pred3 = disparityregression(self.maxdisp)(pred3)

        cls1 = cls1.squeeze(2)
        cls2 = cls2.squeeze(2)
        cls3 = cls3.squeeze(2)

        cls = cls1 + cls2 + cls3

        cls1 = self.firstconv(cls)
        cls2 = self.conv2(cls1)
        cls3 = self.conv3(cls2)

        cls1 = F.interpolate(cls1, [cls2.size()[2], cls2.size()[3]], align_corners=False, mode='bilinear')
        cls2 = self.weight1*self.fuse1(cls1) + cls2
        cls2 = F.interpolate(cls2, [cls3.size()[2], cls3.size()[3]], align_corners=False, mode='bilinear')
        cls3 = self.weight2*self.fuse2(cls2) + cls3
        cls3 = F.interpolate(cls3, [left.size()[2], left.size()[3]], align_corners=False, mode='bilinear')
        cls3 = self.fuse3(cls3)
        cls3 = self.conv1(cls3)

        return pred1, pred2, pred3, cls3
--------------------------------------------------------------------------------
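A quick shape sanity check for SSNet (a sketch; the 512×512 size follows the training crops in data.py, and a CUDA device is required because the model allocates its cost volume on the GPU). The forward pass returns three deep-supervised disparity maps plus one set of class logits:

```python
import torch
from models.model import SSNet

model = SSNet(maxdisp=48, mindisp=-48, num_classes=6).cuda().eval()

left = torch.randn(1, 3, 512, 512).cuda()
right = torch.randn(1, 3, 512, 512).cuda()

with torch.no_grad():
    pred1, pred2, pred3, cls = model(left, right)

# Full-resolution signed disparity maps and 6-class logits
print(pred3.shape)  # torch.Size([1, 512, 512])
print(cls.shape)    # torch.Size([1, 6, 512, 512])
```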

/test.py:
--------------------------------------------------------------------------------
import cv2
import torch
import numpy as np
import torchvision.transforms as transforms

from utils import parser
from models.model import SSNet


def get_transform(data):
    normal_mean_var = {'mean': [0.485, 0.456, 0.406],
                       'std': [0.229, 0.224, 0.225]}
    data = torch.from_numpy(data).float()
    transform = transforms.Compose([transforms.Normalize(**normal_mean_var)])
    return transform(data).float()


def main():
    args = parser()

    model = SSNet(args.maxdisp, args.mindisp, args.classfication)
    model.load_state_dict(torch.load(args.ckpt)['state_dict'])
    model.eval().cuda()

    # Dataset loading: set these placeholders to an epipolar left/right image pair
    left_path = 'xxxx'
    right_path = 'xxxx'

    left = cv2.imread(left_path, cv2.IMREAD_UNCHANGED).astype('float32')
    right = cv2.imread(right_path, cv2.IMREAD_UNCHANGED).astype('float32')

    left = np.moveaxis(left, -1, 0) / 255.0
    right = np.moveaxis(right, -1, 0) / 255.0

    left = get_transform(left).unsqueeze(0).float().cuda()
    right = get_transform(right).unsqueeze(0).float().cuda()

    # Inference
    with torch.no_grad():
        _, _, pred_disp, pred_cls = model(left, right)
        cv2.imwrite('pred_disp.tif', pred_disp.squeeze().cpu().numpy().astype(np.float32))
        pred_cls_np = torch.argmax(pred_cls, dim=1).squeeze().cpu().numpy().astype(np.uint8)
        cv2.imwrite('pred_cls.tif', pred_cls_np)

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
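To eyeball the two TIFFs that test.py writes, a small viewer sketch (matplotlib is assumed to be available; it is not pinned in environment.yml):

```python
import cv2
import matplotlib.pyplot as plt

# Read back the float32 disparity map and the uint8 class map
disp = cv2.imread('pred_disp.tif', cv2.IMREAD_UNCHANGED)
cls = cv2.imread('pred_cls.tif', cv2.IMREAD_UNCHANGED)

fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(disp, cmap='viridis')
ax1.set_title('disparity')
ax2.imshow(cls, cmap='tab10')
ax2.set_title('classes')
plt.show()
```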

/train.py:
--------------------------------------------------------------------------------
import os
import csv

import torch
import torch.nn.functional as F
import torch.distributed as dist
from tqdm import tqdm

from utils import create_mask, adjust_learning_rate, masked_cross_entropy_loss

def save_checkpoint(model, epoch, args, local_rank):

    savefilename = os.path.join(args.save_ckpt_path, f'train_ckpt_{epoch}.tar')
    # Unwrap DDP / DataParallel so the checkpoint loads into a bare SSNet (as in test.py)
    state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
    if not args.is_distributed or args.rank == 0:
        torch.save({'state_dict': state_dict}, savefilename)


def log_results(epoch, train_loss, valid_loss, args, fieldnames):
    with open(args.save_csv_file_path, 'a', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writerow({
            'epoch': epoch,
            'train_loss': train_loss,
            'valid_loss': valid_loss,
        })


def train_one_epoch(epoch, model, optimizer, train_loader, local_rank, args):

    model.train()
    total_train_loss = 0.0

    # Learning rate adjustment and distributed training setup
    adjust_learning_rate(optimizer, epoch)
    if args.is_distributed:
        train_loader.sampler.set_epoch(epoch)

    # Training loop
    for batch_idx, (left, right, disp, cls) in tqdm(enumerate(train_loader), total=len(train_loader)):

        left, right, disp = [
            tensor.to(local_rank).float() for tensor in
            [left, right, disp]
        ]

        cls = cls.to(local_rank).long()

        optimizer.zero_grad()

        pred_disp1, pred_disp2, pred_disp3, pred_cls = model(left, right)

        # Mask out invalid disparities before computing the loss
        disp, mask = create_mask(disp, args.maxdisp, args.mindisp)

        # Deep supervision: the three disparity outputs are weighted 0.5 / 0.7 / 1.0
        loss1 = 0.5*F.smooth_l1_loss(pred_disp1[mask], disp[mask], reduction='mean') + \
                0.7*F.smooth_l1_loss(pred_disp2[mask], disp[mask], reduction='mean') + \
                F.smooth_l1_loss(pred_disp3[mask], disp[mask], reduction='mean')
        loss2 = masked_cross_entropy_loss(pred_cls, cls)
        loss = 0.15*loss1 + loss2

        if args.is_distributed:
            dist.all_reduce(loss, op=dist.ReduceOp.SUM)
            loss = loss / args.world_size

        loss.backward()
        optimizer.step()

        if args.rank == 0:
            total_train_loss += loss.detach().cpu().numpy()

    return total_train_loss / len(train_loader)


def validate_one_epoch(epoch, model, valid_loader, local_rank, args):

    model.eval()
    total_valid_loss = 0.0

    if args.is_distributed:
        valid_loader.sampler.set_epoch(epoch)

    # Validation loop
    with torch.no_grad():
        for batch_idx, (left, right, disp, cls) in tqdm(enumerate(valid_loader), total=len(valid_loader)):

            left, right, disp = [
                tensor.to(local_rank).float() for tensor in
                [left, right, disp]
            ]

            cls = cls.to(local_rank).long()

            pred_disp1, pred_disp2, pred_disp3, pred_cls = model(left, right)

            # Mask out invalid disparities before computing the loss
            disp, mask = create_mask(disp, args.maxdisp, args.mindisp)

            # Same weighting as in training
            loss1 = 0.5*F.smooth_l1_loss(pred_disp1[mask], disp[mask], reduction='mean') + \
                    0.7*F.smooth_l1_loss(pred_disp2[mask], disp[mask], reduction='mean') + \
                    F.smooth_l1_loss(pred_disp3[mask], disp[mask], reduction='mean')
            loss2 = masked_cross_entropy_loss(pred_cls, cls)
            loss = 0.15*loss1 + loss2

            if args.is_distributed:
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss = loss / args.world_size

            if args.rank == 0:
                total_valid_loss += loss.detach().cpu().numpy()

    return total_valid_loss / len(valid_loader)
--------------------------------------------------------------------------------
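train.py only logs the combined loss. A hedged sketch of the usual stereo/segmentation metrics (endpoint error and a fixed-threshold error rate for disparity, overall pixel accuracy for the class map), written in plain torch; the function names and the 3-pixel threshold are illustrative, not the paper's evaluation code. With `create_mask` from utils.py this plugs into the validation loop as `epe = disparity_epe(pred_disp3, disp, mask)`:

```python
import torch

def disparity_epe(pred, target, mask):
    # Mean absolute disparity error over valid pixels
    return (pred[mask] - target[mask]).abs().mean()

def disparity_error_rate(pred, target, mask, threshold=3.0):
    # Fraction of valid pixels whose error exceeds the threshold
    return ((pred[mask] - target[mask]).abs() > threshold).float().mean()

def pixel_accuracy(logits, labels):
    # Overall accuracy of the argmax class map
    return (logits.argmax(dim=1) == labels).float().mean()
```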

/utils.py:
--------------------------------------------------------------------------------
import argparse
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from models.model import SSNet

def parser() -> argparse.Namespace:

    parser = argparse.ArgumentParser(description='S3Net Semantic Stereo Matching Model Configuration')

    # Model & Dataset Configuration
    parser.add_argument('--model', type=str, default='S3Net', help='Model name')
    parser.add_argument('--dataset', type=str, default='DFC2019', help='Dataset name')
    parser.add_argument('--datapath', type=str, default="./dataset/US3D", help='Dataset path')

    # Disparity Configuration
    parser.add_argument('--maxdisp', type=int, default=48, help='Maximum disparity range')
    parser.add_argument('--mindisp', type=int, default=-48, help='Minimum disparity range')
    parser.add_argument('--classfication', type=int, default=6, help='Number of classes')

    # Training Hyperparameters
    parser.add_argument('--lr', type=float, default=1e-4, help='Model learning rate')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--num_workers', type=int, default=2, help='Number of workers')
    parser.add_argument('--epochs', type=int, default=120, help='Number of training epochs')

    # Checkpoint and Saving
    parser.add_argument('--ckpt', type=str, default='ckpt.tar', help='Pretrained model checkpoint path')
    parser.add_argument('--save_ckpt_path', type=str, default=None, help='Model checkpoint saving path')
    parser.add_argument('--save_csv_file_path', type=str, default=None, help='Model training log saving path')

    # Distributed Training
    parser.add_argument('--world_size', type=int, default=1, help='Total number of distributed training processes')
    parser.add_argument('--is_distributed', type=bool, default=False, help='Enable distributed training mode')
    parser.add_argument('--local_rank', type=int, default=None, help='Local process rank for distributed training')

    # Parse arguments
    args, _ = parser.parse_known_args()
    return args

def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()

def adjust_learning_rate(optimizer, epoch):
    # Note: the schedule starts from 1e-3 and halves every 40 epochs,
    # overriding the initial --lr passed to the optimizer
    lr_model = 1e-3 * (0.5) ** (epoch // 40)

    for param_group in optimizer.param_groups:
        if param_group['name'] == 'model':
            param_group['lr'] = lr_model

def masked_cross_entropy_loss(y_pred, y_true):
    # Per-pixel cross-entropy averaged over all pixels
    loss = F.cross_entropy(y_pred, y_true, reduction='none')
    return loss.mean()


def weights_init(m):
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose3d)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)


def initialize_distributed(args):
    # Legacy helper kept for reference; main.py uses setup() below instead
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.world_size = num_gpus
    args.is_distributed = num_gpus > 1

    if args.is_distributed:
        dist.init_process_group(backend='nccl', init_method='env://')
        local_rank = dist.get_rank()
        torch.cuda.set_device(local_rank)

    else:
        if torch.cuda.is_available():
            local_rank = torch.cuda.current_device()

    args.local_rank = local_rank


def setup(args):
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    args.world_size = world_size
    args.is_distributed = world_size > 1
    num_gpus = torch.cuda.device_count()

    local_rank = min(local_rank, num_gpus - 1)

    args.local_rank = local_rank
    args.rank = rank

    torch.cuda.set_device(local_rank)

    if args.rank == 0:
        print(f"Total available GPUs: {world_size}")

    # Only initialize the process group when launched with more than one process;
    # a plain `python main.py` run has no torchrun rendezvous to connect to
    if args.is_distributed:
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=world_size,
            rank=rank
        )

    torch.manual_seed(rank)
    np.random.seed(rank)
    random.seed(rank)

def create_mask(disp, maxdisp, mindisp):
    # Valid pixels: ignore the -999 fill value, NaNs, and disparities outside [mindisp, maxdisp]
    return disp, (disp != -999) & (~torch.isnan(disp)) & (disp >= mindisp) & (disp <= maxdisp)

def initialize_model(args):

    model = SSNet(args.maxdisp, args.mindisp, args.classfication)
    model.apply(weights_init)

    model = model.to(args.local_rank)

    if args.rank == 0:
        print(f'Number of model parameters: {sum([p.data.nelement() for p in model.parameters()]) / 1e6:.2f}M')

    if args.is_distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True)
    else:
        if torch.cuda.is_available():
            model = nn.DataParallel(model)

    return model
--------------------------------------------------------------------------------