├── __init__.py ├── config_reproductive.py ├── config_eval.py ├── scheduler.py ├── README.md ├── utils.py ├── dataset.py ├── eval.py ├── model_reproductive.py ├── train.py └── transformer.py /__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /config_reproductive.py: -------------------------------------------------------------------------------- 1 | from box import Box 2 | 3 | config = { 4 | "num_devices": 4, 5 | "batch_size": 2, 6 | "num_workers": 4, 7 | "num_epochs": 150, 8 | "eval_interval": 1, 9 | "out_dir": "/path/to/out_dir/", 10 | "opt": { 11 | "learning_rate": 1e-4, 12 | "weight_decay": 1e-4, 13 | "decay_factor": 10, 14 | "steps": [50000, 100000], 15 | "warmup_steps": 250, 16 | }, 17 | "transformer":{ 18 | "hidden_dim": 256, 19 | "dropout": 0.0, 20 | "nheads": 8, 21 | "dim_feedforward": 1024, 22 | "enc_layers": 6, 23 | "dec_layers": 6, 24 | "pre_norm": False, 25 | }, 26 | "SAM_model": { 27 | "type": 'vit_b', 28 | "checkpoint": "/path/to/sam_vit_b_01ec64.pth", 29 | }, 30 | "Full_checkpoint": None, 31 | "dataset": { 32 | "train": { 33 | "root_dir": '/path/to/PartImageNet/images/train/', 34 | "annotation_file": '/path/to/PartImageNet/annotations/train/train.json' 35 | }, 36 | "val": { 37 | "root_dir": '/path/to/PartImageNet/images/val/', 38 | "annotation_file": "/path/to/PartImageNet/annotations/val/val.json" 39 | } 40 | }, 41 | "weight_adjust": { 42 | "loss_cls_weight": 5, 43 | "loss_embedding_weight": 20, 44 | "cost_cls_weight": 10, 45 | "cost_embedding_weight": 1, 46 | }, 47 | "num_proposals": 25, 48 | "num_catgories": 40, 49 | } 50 | 51 | cfg = Box(config) 52 | -------------------------------------------------------------------------------- /config_eval.py: -------------------------------------------------------------------------------- 1 | from box import Box 2 | 3 | config = { 4 | "num_devices": 4, 5 | "batch_size": 2, 6 | "num_workers": 4, 7 | "num_epochs": 150, 8 | "eval_interval": 1, 9 | "out_dir": "/path/to/out_dir/eval", 10 | "opt": { 11 | "learning_rate": 1e-4, 12 | "weight_decay": 1e-4, 13 | "decay_factor": 10, 14 | "steps": [50000, 100000], 15 | "warmup_steps": 250, 16 | }, 17 | "transformer":{ 18 | "hidden_dim": 256, 19 | "dropout": 0.0, 20 | "nheads": 8, 21 | "dim_feedforward": 1024, 22 | "enc_layers": 6, 23 | "dec_layers": 6, 24 | "pre_norm": False, 25 | }, 26 | "SAM_model": { 27 | "type": 'vit_b', 28 | "checkpoint": "/path/to/sam_vit_b_01ec64.pth", 29 | }, 30 | "Full_checkpoint": "/path/to/out_dir/last_ckpt.pth", 31 | "dataset": { 32 | "train": { 33 | "root_dir": '/path/to/PartImageNet/images/train/', 34 | "annotation_file": '/path/to/PartImageNet/annotations/train/train.json' 35 | }, 36 | "val": { 37 | "root_dir": '/path/to/PartImageNet/images/val/', 38 | "annotation_file": "/path/to/PartImageNet/annotations/val/val.json" 39 | } 40 | }, 41 | "weight_adjust": { 42 | "loss_cls_weight": 5, 43 | "loss_embedding_weight": 20, 44 | "cost_cls_weight": 10, 45 | "cost_embedding_weight": 1, 46 | }, 47 | "num_proposals": 25, 48 | "num_catgories": 40, 49 | } 50 | 51 | cfg = Box(config) 52 | -------------------------------------------------------------------------------- /scheduler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | from torch.optim.lr_scheduler import LambdaLR 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | class ConstantLRSchedule(LambdaLR): 9 | """ 
Constant learning rate schedule. 10 | """ 11 | def __init__(self, optimizer, last_epoch=-1): 12 | super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) 13 | 14 | 15 | class WarmupConstantSchedule(LambdaLR): 16 | """ Linear warmup and then constant. 17 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps. 18 | Keeps learning rate schedule equal to 1. after warmup_steps. 19 | """ 20 | def __init__(self, optimizer, warmup_steps, last_epoch=-1): 21 | self.warmup_steps = warmup_steps 22 | super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 23 | 24 | def lr_lambda(self, step): 25 | if step < self.warmup_steps: 26 | return float(step) / float(max(1.0, self.warmup_steps)) 27 | return 1. 28 | 29 | 30 | class WarmupLinearSchedule(LambdaLR): 31 | """ Linear warmup and then linear decay. 32 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 33 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps. 34 | """ 35 | def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): 36 | self.warmup_steps = warmup_steps 37 | self.t_total = t_total 38 | super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 39 | 40 | def lr_lambda(self, step): 41 | if step < self.warmup_steps: 42 | return float(step) / float(max(1, self.warmup_steps)) 43 | return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) 44 | 45 | 46 | class WarmupCosineSchedule(LambdaLR): 47 | """ Linear warmup and then cosine decay. 48 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 49 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. 50 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 51 | """ 52 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 53 | self.warmup_steps = warmup_steps 54 | self.t_total = t_total 55 | self.cycles = cycles 56 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 57 | 58 | def lr_lambda(self, step): 59 | if step < self.warmup_steps: 60 | return float(step) / float(max(1.0, self.warmup_steps)) 61 | # progress after warmup 62 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 63 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WPS-SAM: Towards Weakly-Supervised Part Segmentation with Foundation Models 2 | 3 | Official PyTorch implementation of WPS from our paper: [WPS-SAM: Towards Weakly-Supervised Part Segmentation with Foundation Models](https://arxiv.org/abs/2407.10131). **ECCV 2024**. 4 | Xinjian Wu, Ruisong Zhang, Jie Qin, Shijie Ma, Cheng-Lin Liu. 5 | 6 | ## What is WPS-SAM 7 | 8 | ![image](https://github.com/user-attachments/assets/f50cd1fe-2fd0-4102-8b1d-e29a983772fa) 9 | 10 | Segmenting and recognizing diverse object parts is crucial in computer vision and robotics. Despite significant progress in object segmentation, part-level segmentation remains underexplored due to complex boundaries and scarce annotated data. 
To address this, we propose a novel 11 | Weakly-supervised Part Segmentation (WPS) setting (as shown in the figure above) and an approach called WPS-SAM (as shown in the figure below), built on the large-scale pre-trained vision foundation model, Segment Anything Model (SAM). WPS-SAM is an end-to-end framework designed to extract prompt tokens directly from images and perform pixel-level segmentation of part regions. During its training phase, it only uses weakly supervised labels in the form of bounding boxes or points. Extensive experiments demonstrate that, through exploiting the rich knowledge embedded in pre-trained foundation models, WPS-SAM outperforms other segmentation models trained with pixel-level strong annotations. Specifically, WPS-SAM achieves 68.93% mIOU and 79.53% mACC on the PartImageNet dataset, surpassing state-of-the-art fully supervised methods by approximately 4% in terms of mIOU. 12 | 13 | ![image](https://github.com/user-attachments/assets/c89ef9b2-aa07-4558-8ff0-e31b227f744d) 14 | 15 | ## Usage 16 | 17 | ### Requirements 18 | 19 | ``` 20 | - python >= 3.8 21 | - pytorch >= 1.12.1 22 | - lightning 23 | - segmentation_models_pytorch 24 | - segment_anything 25 | - tensorboard 26 | - tensorboardX 27 | ``` 28 | 29 | Clone the repository locally: 30 | 31 | ``` 32 | git clone https://github.com/xjwu1024/WPS-SAM.git 33 | ``` 34 | 35 | ### Data Preparation 36 | 37 | Download and extract the PartImageNet dataset from [here](https://huggingface.co/datasets/turkeyju/PartImageNet/blob/main/PartImageNet_Seg.zip). The directory structure is expected to be: 38 | 39 | ``` 40 | /path/to/PartImageNet/ 41 | annotations/ 42 | test 43 | test_whole 44 | train 45 | train_whole 46 | val 47 | val_whole 48 | images/ 49 | test 50 | train 51 | val 52 | ``` 53 | 54 | More details about the dataset can be found [here](https://github.com/TACJu/PartImageNet?tab=readme-ov-file). 55 | 56 | ### Train 57 | 58 | To train WPS-SAM on PartImageNet, first set the dataset, SAM checkpoint, and output paths in `config_reproductive.py` (see the Configuration note at the end of this README), then run: 59 | 60 | ``` 61 | python train.py 62 | ``` 63 | 64 | ### Evaluation 65 | 66 | ``` 67 | python eval.py 68 | ``` 69 | 70 | ## Citation 71 | 72 | If you use WPS-SAM or this repository in your work, please cite: 73 | ``` 74 | @inproceedings{wu2024wps, 75 | title={WPS-SAM: Towards Weakly-Supervised Part Segmentation with Foundation Models}, 76 | author={Wu, Xin-Jian and Zhang, Ruisong and Qin, Jie and Ma, Shijie and Liu, Cheng-Lin}, 77 | booktitle={European Conference on Computer Vision}, 78 | pages={314--333}, 79 | year={2024}, 80 | organization={Springer} 81 | } 82 | ``` 83 | and provide a link to this repository as a footnote or a citation.
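
### Configuration

`train.py` reads its settings from `config_reproductive.py` and `eval.py` reads `config_eval.py`. The `/path/to/...` placeholders in those files must be replaced with your local SAM ViT-B checkpoint, the extracted PartImageNet directories, and a writable output directory, and `num_devices` (default 4) should match the number of GPUs available, since both scripts launch Lightning Fabric with DDP. The excerpt below is only an illustrative sketch with hypothetical local paths, not values shipped with this repository:

```
# config_reproductive.py (excerpt) -- hypothetical example paths, adjust to your machine
"num_devices": 4,                                    # number of GPUs used by Lightning Fabric (DDP)
"out_dir": "./outputs/wps_sam/",
"SAM_model": {
    "type": "vit_b",
    "checkpoint": "./weights/sam_vit_b_01ec64.pth",  # downloaded SAM ViT-B weights
},
"dataset": {
    "train": {
        "root_dir": "./data/PartImageNet/images/train/",
        "annotation_file": "./data/PartImageNet/annotations/train/train.json",
    },
    "val": {
        "root_dir": "./data/PartImageNet/images/val/",
        "annotation_file": "./data/PartImageNet/annotations/val/val.json",
    },
},
```

For evaluation, `Full_checkpoint` in `config_eval.py` should point to a checkpoint saved by `train.py` (e.g. `last_ckpt.pth` or `best_ckpt.pth` in `out_dir`).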
84 | 85 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import torch 5 | from box import Box 6 | from dataset import COCODataset 7 | from model import Model 8 | from torchvision.utils import draw_bounding_boxes 9 | from torchvision.utils import draw_segmentation_masks 10 | from tqdm import tqdm 11 | 12 | 13 | class AverageMeter: 14 | """Computes and stores the average and current value.""" 15 | 16 | def __init__(self): 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | self.avg = self.sum / self.count 30 | 31 | 32 | def calc_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor): 33 | pred_mask = (pred_mask >= 0.5).float() 34 | intersection = torch.sum(torch.mul(pred_mask, gt_mask), dim=(1, 2)) 35 | union = torch.sum(pred_mask, dim=(1, 2)) + torch.sum(gt_mask, dim=(1, 2)) - intersection 36 | epsilon = 1e-7 37 | batch_iou = intersection / (union + epsilon) 38 | 39 | batch_iou = batch_iou.unsqueeze(1) 40 | return batch_iou 41 | 42 | 43 | def draw_image(image, masks, boxes, labels, alpha=0.4): 44 | image = torch.from_numpy(image).permute(2, 0, 1) 45 | if boxes is not None: 46 | image = draw_bounding_boxes(image, boxes, colors=['red'] * len(boxes), labels=labels, width=2) 47 | if masks is not None: 48 | image = draw_segmentation_masks(image, masks=masks, colors=['red'] * len(masks), alpha=alpha) 49 | return image.numpy().transpose(1, 2, 0) 50 | 51 | 52 | def visualize(cfg: Box): 53 | model = Model(cfg) 54 | model.setup() 55 | model.eval() 56 | model.cuda() 57 | dataset = COCODataset(root_dir=cfg.dataset.val.root_dir, 58 | annotation_file=cfg.dataset.val.annotation_file, 59 | transform=None) 60 | predictor = model.get_predictor() 61 | os.makedirs(cfg.out_dir, exist_ok=True) 62 | 63 | for image_id in tqdm(dataset.image_ids): 64 | image_info = dataset.coco.loadImgs(image_id)[0] 65 | image_path = os.path.join(dataset.root_dir, image_info['file_name']) 66 | image_output_path = os.path.join(cfg.out_dir, image_info['file_name']) 67 | image = cv2.imread(image_path) 68 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 69 | ann_ids = dataset.coco.getAnnIds(imgIds=image_id) 70 | anns = dataset.coco.loadAnns(ann_ids) 71 | bboxes = [] 72 | for ann in anns: 73 | x, y, w, h = ann['bbox'] 74 | bboxes.append([x, y, x + w, y + h]) 75 | bboxes = torch.as_tensor(bboxes, device=model.model.device) 76 | transformed_boxes = predictor.transform.apply_boxes_torch(bboxes, image.shape[:2]) 77 | predictor.set_image(image) 78 | masks, _, _ = predictor.predict_torch( 79 | point_coords=None, 80 | point_labels=None, 81 | boxes=transformed_boxes, 82 | multimask_output=False, 83 | ) 84 | image_output = draw_image(image, masks.squeeze(1), boxes=None, labels=None) 85 | cv2.imwrite(image_output_path, image_output) 86 | 87 | 88 | if __name__ == "__main__": 89 | from config import cfg 90 | visualize(cfg) 91 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import torchvision.transforms as transforms 7 | from pycocotools.coco import COCO 8 | from segment_anything.utils.transforms import 
ResizeLongestSide 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data import Dataset 11 | 12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 13 | 14 | import pickle 15 | 16 | class COCODataset(Dataset): 17 | 18 | def __init__(self, root_dir, annotation_file, transform=None): 19 | self.root_dir = root_dir 20 | self.transform = transform 21 | self.coco = COCO(annotation_file) 22 | self.image_ids = list(self.coco.imgs.keys()) 23 | 24 | # Filter out image_ids without any annotations 25 | self.image_ids = [image_id for image_id in self.image_ids if len(self.coco.getAnnIds(imgIds=image_id)) > 0] 26 | 27 | def __len__(self): 28 | return len(self.image_ids) 29 | 30 | def __getitem__(self, idx): 31 | image_id = self.image_ids[idx] 32 | image_info = self.coco.loadImgs(image_id)[0] 33 | image_path = os.path.join(self.root_dir, image_info['file_name']) 34 | image = cv2.imread(image_path) 35 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 36 | ann_ids = self.coco.getAnnIds(imgIds=image_id) 37 | anns = self.coco.loadAnns(ann_ids) 38 | bboxes = [] 39 | masks = [] 40 | category_ids = [] 41 | 42 | for ann in anns: 43 | if len(ann['segmentation'][0])>4: 44 | x, y, w, h = ann['bbox'] 45 | bboxes.append([x, y, x + w, y + h]) 46 | category_ids.append(ann['category_id']) 47 | mask = self.coco.annToMask(ann) 48 | masks.append(mask) 49 | # breakpoint() 50 | if image is None: 51 | print(image_info['file_name']) 52 | print(image) 53 | 54 | if self.transform: 55 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 56 | 57 | bboxes = np.stack(bboxes, axis=0) 58 | masks = np.stack(masks, axis=0) 59 | 60 | return image, torch.tensor(bboxes), torch.tensor(masks).float(), torch.tensor(category_ids) 61 | 62 | def collate_fn(batch): 63 | images, bboxes, masks, category_ids = zip(*batch) 64 | images = torch.stack(images) 65 | return images, bboxes, masks, category_ids 66 | 67 | 68 | class ResizeAndPad: 69 | 70 | def __init__(self, target_size): 71 | self.target_size = target_size 72 | self.transform = ResizeLongestSide(target_size) 73 | self.to_tensor = transforms.ToTensor() 74 | self.normalization = transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) 75 | 76 | def __call__(self, image, masks, bboxes): 77 | # Resize image and masks 78 | og_h, og_w, _ = image.shape 79 | image = self.transform.apply_image(image) 80 | masks = [torch.tensor(self.transform.apply_image(mask)) for mask in masks] 81 | image = self.to_tensor(image) 82 | 83 | # Pad image and masks to form a square 84 | _, h, w = image.shape 85 | max_dim = max(w, h) 86 | pad_w = (max_dim - w) // 2 87 | pad_h = (max_dim - h) // 2 88 | 89 | padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h) 90 | image = transforms.Pad(padding)(image) 91 | masks = [transforms.Pad(padding)(mask) for mask in masks] 92 | 93 | # Adjust bounding boxes 94 | bboxes = self.transform.apply_boxes(bboxes, (og_h, og_w)) 95 | bboxes = [[bbox[0] + pad_w, bbox[1] + pad_h, bbox[2] + pad_w, bbox[3] + pad_h] for bbox in bboxes] 96 | 97 | image = self.normalization(image) 98 | 99 | return image, masks, bboxes 100 | 101 | 102 | def load_datasets(cfg, img_size): 103 | 104 | transform = ResizeAndPad(img_size) 105 | 106 | train = COCODataset(root_dir=cfg.dataset.train.root_dir, 107 | annotation_file=cfg.dataset.train.annotation_file, 108 | transform=transform) 109 | val = COCODataset(root_dir=cfg.dataset.val.root_dir, 110 | annotation_file=cfg.dataset.val.annotation_file, 111 | transform=transform) 112 | 
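    # NOTE: collate_fn (defined above) stacks only the images into a single batch tensor;
    # boxes, masks and category ids stay as per-image tuples because the number of
    # annotated parts varies from image to image.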
train_dataloader = DataLoader(train, 113 | batch_size=cfg.batch_size, 114 | shuffle=True, 115 | num_workers=cfg.num_workers, 116 | collate_fn=collate_fn, 117 | persistent_workers=True, 118 | pin_memory=True) 119 | val_dataloader = DataLoader(val, 120 | batch_size=cfg.batch_size, 121 | shuffle=False, 122 | num_workers=cfg.num_workers, 123 | collate_fn=collate_fn, 124 | persistent_workers=True, 125 | pin_memory=True) 126 | return train_dataloader, val_dataloader 127 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import lightning as L 5 | import segmentation_models_pytorch as smp 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from box import Box 10 | from dataset import load_datasets 11 | from lightning.fabric.fabric import _FabricOptimizer 12 | from lightning.fabric.loggers import TensorBoardLogger 13 | 14 | from model_reproductive import Model 15 | from torch.utils.data import DataLoader 16 | from utils import AverageMeter 17 | 18 | from config_eval import cfg 19 | import logging 20 | import cv2 21 | from scheduler import WarmupCosineSchedule 22 | 23 | torch.set_float32_matmul_precision('high') 24 | torch.multiprocessing.set_sharing_strategy('file_system') 25 | 26 | def validate(fabric: L.Fabric, model: Model, val_dataloader: DataLoader, epoch: int = 0, best_ious=0): 27 | model.eval() 28 | ious = AverageMeter() 29 | accs = AverageMeter() 30 | f1_scores = AverageMeter() 31 | null_imgs = 0 32 | failed_parts = 0 33 | ok_parts = 0 34 | 35 | with torch.no_grad(): 36 | for iter, data in enumerate(val_dataloader): 37 | 38 | images, bboxes, batch_gt_masks, batch_category_ids = data 39 | batch_pred_masks, _, batch_teacher_masks, _, batch_logits, batch_pred_indices = model(images, batch_category_ids, bboxes) 40 | 41 | for logits, pred_masks, pred_indices, category_ids, gt_masks in zip(batch_logits, batch_pred_masks, batch_pred_indices, batch_category_ids, batch_gt_masks): 42 | 43 | if len(pred_masks)==0: 44 | null_imgs = null_imgs + 1 45 | iou = 0 46 | f1 = 0 47 | acc = 0 48 | accs.update(acc, n=1) 49 | ious.update(iou, n=1) 50 | f1_scores.update(f1, n=1) 51 | continue 52 | 53 | category_masks_dict = {} 54 | for category, mask in zip(category_ids, gt_masks): 55 | if category.item() not in category_masks_dict: 56 | category_masks_dict[category.item()] = mask 57 | else: 58 | category_masks_dict[category.item()] = torch.add(category_masks_dict[category.item()], mask) 59 | 60 | category_masks_dict_pred = {} 61 | _, out_class = logits.max(-1) 62 | pred_category = out_class[pred_indices] 63 | for category, mask in zip(pred_category, pred_masks): 64 | if category.item() not in category_masks_dict_pred: 65 | category_masks_dict_pred[category.item()] = mask 66 | else: 67 | category_masks_dict_pred[category.item()] = torch.add(category_masks_dict_pred[category.item()], mask) 68 | 69 | for category in category_masks_dict: 70 | 71 | if category in category_masks_dict_pred: 72 | ok_parts = ok_parts + 1 73 | stats = smp.metrics.get_stats( 74 | category_masks_dict_pred[category], 75 | category_masks_dict[category].int(), 76 | mode='binary', 77 | threshold=0.5, 78 | ) 79 | iou = smp.metrics.iou_score(*stats, reduction="micro-imagewise") 80 | acc = smp.metrics.accuracy(*stats, reduction="micro-imagewise") 81 | f1 = smp.metrics.f1_score(*stats, reduction="micro-imagewise") 82 | ious.update(iou, n=1) 83 | accs.update(acc, 
n=1) 84 | f1_scores.update(f1, n=1) 85 | else: 86 | failed_parts = failed_parts + 1 87 | iou = 0 88 | f1 = 0 89 | acc = 0 90 | accs.update(acc, n=1) 91 | ious.update(iou, n=1) 92 | f1_scores.update(f1, n=1) 93 | 94 | fabric.print(f'Val: [{epoch}] - [{iter}/{len(val_dataloader)}]: Mean IoU: [{ious.avg:.4f}] -- Mean ACC: [{accs.avg:.4f}] -- Mean F1: {f1_scores.avg:.4f}]') 95 | logging.info(f'Val: [{epoch}] - [{iter}/{len(val_dataloader)}]: Mean IoU: [{ious.avg:.4f}] -- Mean ACC: [{accs.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]') 96 | 97 | fabric.print(f'Validation [{epoch}]: Mean IoU: [{ious.avg:.4f}] -- Mean ACC: [{accs.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}] -- null imgs: [{null_imgs}] -- failed parts: [{failed_parts}] -- seg imgs: [{ok_parts}]') 98 | logging.info(f'Validation [{epoch}]: Mean IoU: [{ious.avg:.4f}] -- Mean ACC: [{accs.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}] -- null imgs: [{null_imgs}] -- failed parts: [{failed_parts}] -- seg imgs: [{ok_parts}]') 99 | 100 | logging.info('==============================next epoch=============================================') 101 | return best_ious 102 | 103 | def configure_opt(cfg: Box, model: Model): 104 | 105 | def lr_lambda(step): 106 | return 1.0 107 | 108 | def get_parameters(): 109 | params = [] 110 | for name, param in model.named_parameters(): 111 | if not name.startswith('SAM_mode'): 112 | params.append(param) 113 | return params 114 | 115 | optimizer = torch.optim.Adam(get_parameters(), lr=cfg.opt.learning_rate, weight_decay=cfg.opt.weight_decay) 116 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 117 | 118 | return optimizer, scheduler 119 | 120 | 121 | def main(cfg: Box) -> None: 122 | 123 | print('ready!') 124 | if not os.path.exists(cfg.out_dir): 125 | os.mkdir(cfg.out_dir) 126 | 127 | os.system('cp config_eval.py '+ cfg.out_dir) 128 | 129 | log_file = os.path.join(cfg.out_dir, "log.txt") 130 | logging.basicConfig(filename=log_file, level=logging.INFO) 131 | 132 | fabric = L.Fabric(accelerator="cuda", 133 | devices=cfg.num_devices, 134 | strategy="ddp", 135 | loggers=[TensorBoardLogger(cfg.out_dir, name="lightning-sam")]) 136 | fabric.launch() 137 | fabric.seed_everything(1337 + fabric.global_rank) 138 | 139 | if fabric.global_rank == 0: 140 | os.makedirs(cfg.out_dir, exist_ok=True) 141 | 142 | model = Model(cfg) 143 | 144 | train_data, val_data = load_datasets(cfg, model.SAM_model.image_encoder.img_size) 145 | train_data = fabric._setup_dataloader(train_data) 146 | val_data = fabric._setup_dataloader(val_data) 147 | 148 | optimizer, scheduler = configure_opt(cfg, model) 149 | model, optimizer = fabric.setup(model, optimizer) 150 | 151 | validate(fabric, model, val_data, epoch=0, best_ious=0) 152 | logging.shutdown() 153 | 154 | if __name__ == "__main__": 155 | main(cfg) 156 | 157 | -------------------------------------------------------------------------------- /model_reproductive.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from segment_anything import sam_model_registry 4 | import torch 5 | from scipy.optimize import linear_sum_assignment 6 | import numpy as np 7 | 8 | from transformer import build_transformer 9 | 10 | class LayerNorm2d(nn.Module): 11 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 12 | super().__init__() 13 | self.weight = nn.Parameter(torch.ones(num_channels)) 14 | self.bias = nn.Parameter(torch.zeros(num_channels)) 15 | self.eps = eps 16 | 17 | def 
forward(self, x: torch.Tensor) -> torch.Tensor: 18 | u = x.mean(1, keepdim=True) 19 | s = (x - u).pow(2).mean(1, keepdim=True) 20 | x = (x - u) / torch.sqrt(s + self.eps) 21 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 22 | return x 23 | 24 | 25 | def matcher(source_label, source_embedding, logits, target_embedding, cost_cls_weight, cost_embedding_weight): 26 | 27 | source_embedding = source_embedding.reshape(source_embedding.shape[0], -1) 28 | target_embedding = target_embedding.detach().reshape(target_embedding.shape[0], -1) 29 | 30 | # Compute the classification cost 31 | out_prob = logits.softmax(-1) 32 | cost_class = -out_prob[:, source_label] 33 | 34 | # Compute the embedding cost 35 | cost_embedding = torch.cdist(target_embedding, source_embedding, p=2) 36 | 37 | # Final cost matrix [M, N] 38 | C = (cost_cls_weight * cost_class + cost_embedding_weight * cost_embedding).cpu() 39 | 40 | pred_indices, gt_indices = linear_sum_assignment(C.detach()) 41 | sorted_gt_indices = np.argsort(gt_indices) 42 | sorted_pred_indices = pred_indices[sorted_gt_indices] 43 | 44 | return sorted_pred_indices 45 | 46 | class Model(nn.Module): 47 | 48 | def __init__(self, cfg): 49 | super().__init__() 50 | self.cfg = cfg 51 | 52 | self.SAM_model = sam_model_registry[self.cfg.SAM_model.type](checkpoint=self.cfg.SAM_model.checkpoint) 53 | self.conv_layers = nn.Sequential( 54 | nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), # shape: 64*64→32*32 55 | LayerNorm2d(256), 56 | nn.GELU(), 57 | nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), # shape: 32*32→16*16 58 | LayerNorm2d(256), 59 | nn.GELU() 60 | ) 61 | self.prompter = build_transformer(cfg.transformer) 62 | self.num_queries = cfg.num_proposals 63 | self.dim_embedding = 256 64 | self.num_classes = cfg.num_catgories 65 | self.pos_embed = nn.Parameter(torch.zeros(1, self.dim_embedding, 16, 16)) 66 | self.query_embed = nn.Embedding(self.num_queries, self.dim_embedding) 67 | self.class_embed = nn.Linear(self.dim_embedding, self.num_classes + 1) 68 | self.coords_embed = nn.Sequential( 69 | nn.Linear(in_features=256, out_features=512), 70 | nn.GELU(), 71 | nn.Linear(in_features=512, out_features=512), 72 | ) 73 | 74 | if self.cfg.Full_checkpoint is not None: 75 | with open(self.cfg.Full_checkpoint, "rb") as f: 76 | state_dict = torch.load(f) 77 | self.load_state_dict(state_dict, strict=True) 78 | 79 | for param in self.SAM_model.image_encoder.parameters(): 80 | param.requires_grad = False 81 | for param in self.SAM_model.prompt_encoder.parameters(): 82 | param.requires_grad = False 83 | for param in self.SAM_model.mask_decoder.parameters(): 84 | param.requires_grad = False 85 | 86 | def forward(self, images, batch_category_ids, batch_bboxes): 87 | bs, _, H, W = images.shape 88 | with torch.no_grad(): 89 | batch_image_embeddings = self.SAM_model.image_encoder(images) 90 | 91 | batch_feature_maps = self.conv_layers(batch_image_embeddings) 92 | hs = self.prompter(src=batch_feature_maps, mask=None, query_embed=self.query_embed.weight, pos_embed=self.pos_embed)[0][-1] 93 | batch_logits = self.class_embed(hs) 94 | batch_student_embeddings = self.coords_embed(hs).reshape(bs, self.num_queries, 2, 256) 95 | batch_teacher_embeddings = [] 96 | batch_pred_indices = [] 97 | batch_pred_masks = [] 98 | batch_teacher_masks = [] 99 | 100 | for image_embeddings, category_ids, bboxes, student_embeddings, logits in zip(batch_image_embeddings, batch_category_ids, batch_bboxes, batch_student_embeddings, batch_logits): 101 | 102 | # bbox_supervised 
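            # The frozen SAM prompt encoder converts the ground-truth boxes (the weak labels)
            # into sparse "teacher" prompt embeddings; during training these act only as
            # regression targets for the student embeddings predicted by the prompter.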
103 | with torch.no_grad(): 104 | teacher_embeddings, dense_embeddings = self.SAM_model.prompt_encoder( 105 | points=None, 106 | boxes=bboxes, 107 | masks=None, 108 | ) 109 | 110 | # point_supervised 111 | # coords = torch.cat((((bboxes[:, 0]+bboxes[:, 2])/2).unsqueeze(1), ((bboxes[:, 1]+bboxes[:, 3])/2).unsqueeze(1)), dim=1).unsqueeze(1) 112 | # labels = torch.ones(coords.shape[0], 1, device=coords.device, dtype=int) 113 | # points = [coords, labels] 114 | # with torch.no_grad(): 115 | # teacher_embeddings, dense_embeddings = self.SAM_model.prompt_encoder( 116 | # points=points, 117 | # boxes=None, 118 | # masks=None, 119 | # ) 120 | 121 | if self.training: 122 | batch_teacher_embeddings.append(teacher_embeddings) 123 | pred_indices = matcher(source_label=category_ids, source_embedding=teacher_embeddings, logits=logits, target_embedding=student_embeddings, cost_cls_weight=self.cfg.weight_adjust.cost_cls_weight, cost_embedding_weight=self.cfg.weight_adjust.cost_embedding_weight) 124 | batch_pred_indices.append(pred_indices) 125 | 126 | else: 127 | _, pred_class = logits.max(-1) 128 | pred_indices = torch.nonzero(pred_class!=self.num_classes).squeeze(-1) 129 | 130 | if pred_indices.shape[0]>0: 131 | with torch.no_grad(): 132 | low_res_masks, _ = self.SAM_model.mask_decoder( 133 | image_embeddings=image_embeddings.unsqueeze(0), 134 | image_pe=self.SAM_model.prompt_encoder.get_dense_pe(), 135 | sparse_prompt_embeddings=student_embeddings[pred_indices], 136 | dense_prompt_embeddings=dense_embeddings[0].unsqueeze(0).repeat(pred_indices.shape[0], 1, 1, 1), 137 | multimask_output=False, 138 | ) 139 | pred_masks = F.interpolate( 140 | low_res_masks, 141 | (H, W), 142 | mode="bilinear", 143 | align_corners=False, 144 | ) 145 | pred_masks = torch.sigmoid(pred_masks) 146 | batch_pred_masks.append(pred_masks.squeeze(1)) 147 | batch_pred_indices.append(pred_indices) 148 | 149 | return batch_pred_masks, batch_teacher_embeddings, batch_teacher_masks, batch_student_embeddings, batch_logits, batch_pred_indices 150 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import lightning as L 5 | import segmentation_models_pytorch as smp 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from box import Box 10 | from dataset import load_datasets 11 | from lightning.fabric.fabric import _FabricOptimizer 12 | from lightning.fabric.loggers import TensorBoardLogger 13 | 14 | from model_reproductive import Model 15 | from torch.utils.data import DataLoader 16 | from utils import AverageMeter 17 | 18 | from config_reproductive import cfg 19 | import logging 20 | import cv2 21 | from scheduler import WarmupCosineSchedule 22 | 23 | torch.set_float32_matmul_precision('high') 24 | torch.multiprocessing.set_sharing_strategy('file_system') 25 | 26 | def validate(fabric: L.Fabric, model: Model, val_dataloader: DataLoader, epoch: int = 0, best_ious=0): 27 | model.eval() 28 | ious = AverageMeter() 29 | accs = AverageMeter() 30 | f1_scores = AverageMeter() 31 | cls_acc = AverageMeter() 32 | 33 | null_imgs = 0 34 | failed_imgs = 0 35 | matched_imgs = 0 36 | 37 | with torch.no_grad(): 38 | for iter, data in enumerate(val_dataloader): 39 | 40 | images, bboxes, batch_gt_masks, batch_category_ids = data 41 | batch_pred_masks, _, batch_teacher_masks, _, batch_logits, batch_pred_indices = model(images, batch_category_ids, bboxes) 42 | 43 | for 
logits, pred_masks, pred_indices, category_ids, gt_masks in zip(batch_logits, batch_pred_masks, batch_pred_indices, batch_category_ids, batch_gt_masks): 44 | if len(pred_masks)==0: 45 | null_imgs = null_imgs + 1 46 | iou = 0 47 | f1 = 0 48 | acc = 0 49 | accs.update(acc, n=1) 50 | ious.update(iou, n=1) 51 | f1_scores.update(f1, n=1) 52 | continue 53 | 54 | category_masks_dict = {} 55 | for category, mask in zip(category_ids, gt_masks): 56 | if category.item() not in category_masks_dict: 57 | category_masks_dict[category.item()] = mask 58 | else: 59 | category_masks_dict[category.item()] = torch.add(category_masks_dict[category.item()], mask) 60 | 61 | category_masks_dict_pred = {} 62 | _, out_class = logits.max(-1) 63 | pred_category = out_class[pred_indices] 64 | for category, mask in zip(pred_category, pred_masks): 65 | if category.item() not in category_masks_dict_pred: 66 | category_masks_dict_pred[category.item()] = mask 67 | else: 68 | category_masks_dict_pred[category.item()] = torch.add(category_masks_dict_pred[category.item()], mask) 69 | 70 | for category in category_masks_dict: 71 | if category in category_masks_dict_pred: 72 | matched_imgs = matched_imgs + 1 73 | stats = smp.metrics.get_stats( 74 | category_masks_dict_pred[category], 75 | category_masks_dict[category].int(), 76 | mode='binary', 77 | threshold=0.5, 78 | ) 79 | iou = smp.metrics.iou_score(*stats, reduction="micro-imagewise") 80 | acc = smp.metrics.accuracy(*stats, reduction="micro-imagewise") 81 | f1 = smp.metrics.f1_score(*stats, reduction="micro-imagewise") 82 | ious.update(iou, n=1) 83 | accs.update(acc, n=1) 84 | f1_scores.update(f1, n=1) 85 | else: 86 | failed_imgs = failed_imgs + 1 87 | iou = 0 88 | f1 = 0 89 | acc = 0 90 | accs.update(acc, n=1) 91 | ious.update(iou, n=1) 92 | f1_scores.update(f1, n=1) 93 | 94 | fabric.print(f'Val: [{epoch}] - [{iter}/{len(val_dataloader)}]: Mean IoU: [{ious.avg:.4f}] -- Mean ACC: [{accs.avg:.4f}] -- Mean F1: {f1_scores.avg:.4f}]') 95 | logging.info(f'Val: [{epoch}] - [{iter}/{len(val_dataloader)}]: Mean IoU: [{ious.avg:.4f}] -- Mean ACC: [{accs.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]') 96 | 97 | fabric.print(f'Validation [{epoch}]: Mean IoU: [{ious.avg:.4f}] -- Mean ACC: [{accs.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}] -- matched imgs: [{matched_imgs}] -- null imgs: [{null_imgs}] -- failed imgs: [{failed_imgs}]') 98 | logging.info(f'Validation [{epoch}]: Mean IoU: [{ious.avg:.4f}] -- Mean ACC: [{accs.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}] -- matched imgs: [{matched_imgs}] -- null imgs: [{null_imgs}] -- failed imgs: [{failed_imgs}]') 99 | 100 | fabric.print(f"Saving last checkpoint to {cfg.out_dir}") 101 | state_dict = model.state_dict() 102 | if fabric.global_rank == 0: 103 | torch.save(state_dict, os.path.join(cfg.out_dir, f"last_ckpt.pth")) 104 | model.train() 105 | 106 | if ious.avg > best_ious: 107 | fabric.print(f"Cool! 
Saving checkpoint to {cfg.out_dir}") 108 | state_dict = model.state_dict() 109 | if fabric.global_rank == 0: 110 | torch.save(state_dict, os.path.join(cfg.out_dir, f"best_ckpt.pth")) 111 | best_ious = ious.avg 112 | fabric.print(f'best performance: Mean IoU: [{ious.avg:.4f}] -- Mean ACC: [{accs.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]') 113 | logging.info(f'best performance: Mean IoU: [{ious.avg:.4f}] -- Mean ACC: [{accs.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]') 114 | 115 | logging.info('==============================next epoch=============================================') 116 | return best_ious 117 | 118 | def prompter_kd( 119 | cfg: Box, 120 | fabric: L.Fabric, 121 | model: Model, 122 | optimizer: _FabricOptimizer, 123 | scheduler: _FabricOptimizer, 124 | train_dataloader: DataLoader, 125 | val_dataloader: DataLoader, 126 | ): 127 | """The prompter kd loop.""" 128 | 129 | best_ious = 0 130 | for epoch in range(1, cfg.num_epochs): 131 | batch_time = AverageMeter() 132 | data_time = AverageMeter() 133 | cls_losses = AverageMeter() 134 | embedding_losses = AverageMeter() 135 | Acc = AverageMeter() 136 | Recall = AverageMeter() 137 | end = time.time() 138 | 139 | for iter, data in enumerate(train_dataloader): 140 | data_time.update(time.time() - end) 141 | images, bboxes, _, batch_category_ids = data 142 | batch_size = images.size(0) 143 | 144 | _, batch_teacher_embeddings, _, batch_student_embeddings, batch_logits, batch_pred_indices = model(images, batch_category_ids, bboxes) 145 | 146 | loss_ce = torch.tensor(0., device=fabric.device) 147 | loss_embedding = torch.tensor(0., device=fabric.device) 148 | acc = torch.tensor(0., device=fabric.device) 149 | recall = torch.tensor(0., device=fabric.device) 150 | 151 | for teacher_embeddings, student_embeddings, logits, pred_indices, category_ids in zip(batch_teacher_embeddings, batch_student_embeddings, batch_logits, batch_pred_indices, batch_category_ids): 152 | loss_embedding += F.smooth_l1_loss(teacher_embeddings, student_embeddings[pred_indices]) # L1_smoth_loss 153 | expended_labels = torch.full((logits.size(0), ), cfg.num_catgories , dtype=torch.int64, device=category_ids.device) 154 | expended_labels[pred_indices] = category_ids 155 | loss_ce += F.cross_entropy(logits, expended_labels) 156 | _, pred_class = logits.max(1) 157 | acc += sum((pred_class==expended_labels).int())/logits.size(0) 158 | recall += sum(pred_class[pred_indices]==category_ids.int())/category_ids.size(0) 159 | 160 | loss_ce = loss_ce / batch_size 161 | loss_embedding = loss_embedding / batch_size 162 | acc = acc / batch_size 163 | recall = recall / batch_size 164 | 165 | loss_matcher = cfg.weight_adjust.loss_cls_weight * loss_ce + cfg.weight_adjust.loss_embedding_weight * loss_embedding 166 | 167 | optimizer.zero_grad() 168 | fabric.backward(loss_matcher) 169 | optimizer.step() 170 | scheduler.step() 171 | batch_time.update(time.time() - end) 172 | end = time.time() 173 | 174 | cls_losses.update(loss_ce.item(), batch_size) 175 | embedding_losses.update(loss_embedding.item(), batch_size) 176 | Acc.update(acc.item(), batch_size) 177 | Recall.update(recall.item(), batch_size) 178 | 179 | fabric.print(f'Epoch: [{epoch}][{iter+1}/{len(train_dataloader)}]' 180 | f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]' 181 | f' | embedding Loss [{embedding_losses.val:.4f} ({embedding_losses.avg:.4f})]' 182 | f' | cls Loss [{cls_losses.val:.4f} ({cls_losses.avg:.4f})]' 183 | f' | matcher Acc [{Acc.val:.4f} ({Acc.avg:.4f})]' 184 | f' | matcher recall [{Recall.val:.4f} 
({Recall.avg:.4f})]' 185 | ) 186 | logging.info(f'Epoch: [{epoch}][{iter+1}/{len(train_dataloader)}]' 187 | f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]' 188 | f' | embedding Loss [{embedding_losses.val:.4f} ({embedding_losses.avg:.4f})]' 189 | f' | cls Loss [{cls_losses.val:.4f} ({cls_losses.avg:.4f})]' 190 | f' | matcher Acc [{Acc.val:.4f} ({Acc.avg:.4f})]' 191 | f' | matcher recall [{Recall.val:.4f} ({Recall.avg:.4f})]' 192 | ) 193 | 194 | if epoch % cfg.eval_interval == 0: 195 | best_ious = validate(fabric, model, val_dataloader, epoch, best_ious) 196 | fabric.print(f'best Mean IoU: [{best_ious:.4f}]') 197 | 198 | 199 | def configure_opt(cfg: Box, model: Model): 200 | 201 | def lr_lambda(step): 202 | if step < cfg.opt.warmup_steps: 203 | return step / cfg.opt.warmup_steps 204 | elif step < cfg.opt.steps[0]: 205 | return 1.0 206 | elif step < cfg.opt.steps[1]: 207 | return 1 / cfg.opt.decay_factor 208 | else: 209 | return 1 / (cfg.opt.decay_factor**2) 210 | # return 1.0 211 | 212 | def get_parameters(): 213 | params = [] 214 | for name, param in model.named_parameters(): 215 | if not name.startswith('SAM_mode'): 216 | params.append(param) 217 | return params 218 | 219 | optimizer = torch.optim.Adam(get_parameters(), lr=cfg.opt.learning_rate, weight_decay=cfg.opt.weight_decay) 220 | # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 221 | scheduler = WarmupCosineSchedule(optimizer, warmup_steps=cfg.opt.warmup_steps, t_total=38370) 222 | 223 | return optimizer, scheduler 224 | 225 | 226 | def main(cfg: Box) -> None: 227 | 228 | print('ready!') 229 | if not os.path.exists(cfg.out_dir): 230 | os.mkdir(cfg.out_dir) 231 | 232 | os.system('cp config_reproductive.py '+ cfg.out_dir) 233 | 234 | log_file = os.path.join(cfg.out_dir, "log.txt") 235 | logging.basicConfig(filename=log_file, level=logging.INFO) 236 | 237 | fabric = L.Fabric(accelerator="cuda", 238 | devices=cfg.num_devices, 239 | strategy="ddp", 240 | loggers=[TensorBoardLogger(cfg.out_dir, name="lightning-sam")]) 241 | fabric.launch() 242 | fabric.seed_everything(1337 + fabric.global_rank) 243 | 244 | if fabric.global_rank == 0: 245 | os.makedirs(cfg.out_dir, exist_ok=True) 246 | 247 | model = Model(cfg) 248 | 249 | train_data, val_data = load_datasets(cfg, model.SAM_model.image_encoder.img_size) 250 | train_data = fabric._setup_dataloader(train_data) 251 | val_data = fabric._setup_dataloader(val_data) 252 | 253 | optimizer, scheduler = configure_opt(cfg, model) 254 | model, optimizer = fabric.setup(model, optimizer) 255 | # breakpoint() 256 | # total = sum([param.nelement() for param in model.parameters() if param.requires_grad]) 257 | torch.backends.cudnn.benchmark = True 258 | prompter_kd(cfg, fabric, model, optimizer, scheduler, train_data, val_data) 259 | logging.shutdown() 260 | 261 | if __name__ == "__main__": 262 | main(cfg) 263 | 264 | -------------------------------------------------------------------------------- /transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | DETR Transformer class. 
4 | 5 | Copy-paste from torch.nn.Transformer with modifications: 6 | * positional encodings are passed in MHattention 7 | * extra LN at the end of encoder is removed 8 | * decoder returns a stack of activations from all decoding layers 9 | """ 10 | import copy 11 | from typing import Optional, List 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn, Tensor 16 | 17 | 18 | class Transformer(nn.Module): 19 | 20 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 21 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 22 | activation="relu", normalize_before=False, 23 | return_intermediate_dec=False): 24 | super().__init__() 25 | 26 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 27 | dropout, activation, normalize_before) 28 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 29 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 30 | 31 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 32 | dropout, activation, normalize_before) 33 | decoder_norm = nn.LayerNorm(d_model) 34 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 35 | return_intermediate=return_intermediate_dec) 36 | 37 | self._reset_parameters() 38 | 39 | self.d_model = d_model 40 | self.nhead = nhead 41 | 42 | def _reset_parameters(self): 43 | for p in self.parameters(): 44 | if p.dim() > 1: 45 | nn.init.xavier_uniform_(p) 46 | 47 | def forward(self, src, mask, query_embed, pos_embed): 48 | # flatten NxCxHxW to HWxNxC 49 | # breakpoint() 50 | bs, c, h, w = src.shape 51 | src = src.flatten(2).permute(2, 0, 1) 52 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 53 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 54 | if mask is not None: 55 | mask = mask.flatten(1) 56 | 57 | tgt = torch.zeros_like(query_embed) 58 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 59 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, 60 | pos=pos_embed, query_pos=query_embed) # [6, 100, 8, 256]? 
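        # hs has shape [num_decoder_layers, num_queries, batch, d_model] because
        # return_intermediate_dec=True; the transpose below reorders it to
        # [num_decoder_layers, batch, num_queries, d_model].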
61 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 62 | 63 | 64 | class TransformerEncoder(nn.Module): 65 | 66 | def __init__(self, encoder_layer, num_layers, norm=None): 67 | super().__init__() 68 | self.layers = _get_clones(encoder_layer, num_layers) 69 | self.num_layers = num_layers 70 | self.norm = norm 71 | 72 | def forward(self, src, 73 | mask: Optional[Tensor] = None, 74 | src_key_padding_mask: Optional[Tensor] = None, 75 | pos: Optional[Tensor] = None): 76 | output = src 77 | # breakpoint() 78 | for layer in self.layers: 79 | output = layer(output, src_mask=mask, 80 | src_key_padding_mask=src_key_padding_mask, pos=pos) 81 | 82 | if self.norm is not None: 83 | output = self.norm(output) 84 | 85 | return output 86 | 87 | 88 | class TransformerDecoder(nn.Module): 89 | 90 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 91 | super().__init__() 92 | self.layers = _get_clones(decoder_layer, num_layers) 93 | self.num_layers = num_layers 94 | self.norm = norm 95 | self.return_intermediate = return_intermediate 96 | 97 | def forward(self, tgt, memory, 98 | tgt_mask: Optional[Tensor] = None, 99 | memory_mask: Optional[Tensor] = None, 100 | tgt_key_padding_mask: Optional[Tensor] = None, 101 | memory_key_padding_mask: Optional[Tensor] = None, 102 | pos: Optional[Tensor] = None, 103 | query_pos: Optional[Tensor] = None): 104 | output = tgt 105 | 106 | intermediate = [] 107 | 108 | for layer in self.layers: 109 | output = layer(output, memory, tgt_mask=tgt_mask, 110 | memory_mask=memory_mask, 111 | tgt_key_padding_mask=tgt_key_padding_mask, 112 | memory_key_padding_mask=memory_key_padding_mask, 113 | pos=pos, query_pos=query_pos) 114 | if self.return_intermediate: 115 | intermediate.append(self.norm(output)) 116 | 117 | if self.norm is not None: 118 | output = self.norm(output) 119 | if self.return_intermediate: 120 | intermediate.pop() 121 | intermediate.append(output) 122 | 123 | if self.return_intermediate: 124 | return torch.stack(intermediate) 125 | 126 | return output.unsqueeze(0) 127 | 128 | 129 | class TransformerEncoderLayer(nn.Module): 130 | 131 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 132 | activation="relu", normalize_before=False): 133 | super().__init__() 134 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 135 | # Implementation of Feedforward model 136 | self.linear1 = nn.Linear(d_model, dim_feedforward) 137 | self.dropout = nn.Dropout(dropout) 138 | self.linear2 = nn.Linear(dim_feedforward, d_model) 139 | 140 | self.norm1 = nn.LayerNorm(d_model) 141 | self.norm2 = nn.LayerNorm(d_model) 142 | self.dropout1 = nn.Dropout(dropout) 143 | self.dropout2 = nn.Dropout(dropout) 144 | 145 | self.activation = _get_activation_fn(activation) 146 | self.normalize_before = normalize_before 147 | 148 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 149 | return tensor if pos is None else tensor + pos 150 | 151 | def forward_post(self, 152 | src, 153 | src_mask: Optional[Tensor] = None, 154 | src_key_padding_mask: Optional[Tensor] = None, 155 | pos: Optional[Tensor] = None): 156 | # breakpoint() 157 | q = k = self.with_pos_embed(src, pos) 158 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 159 | key_padding_mask=src_key_padding_mask)[0] 160 | src = src + self.dropout1(src2) 161 | src = self.norm1(src) 162 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 163 | src = src + self.dropout2(src2) 164 | src = self.norm2(src) 165 | return src 
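    # Pre-norm variant: LayerNorm is applied before self-attention and the feed-forward block
    # instead of after; it is selected when the transformer is built with normalize_before=True
    # (i.e. cfg.transformer.pre_norm).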
166 | 167 | def forward_pre(self, src, 168 | src_mask: Optional[Tensor] = None, 169 | src_key_padding_mask: Optional[Tensor] = None, 170 | pos: Optional[Tensor] = None): 171 | src2 = self.norm1(src) 172 | q = k = self.with_pos_embed(src2, pos) 173 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 174 | key_padding_mask=src_key_padding_mask)[0] 175 | src = src + self.dropout1(src2) 176 | src2 = self.norm2(src) 177 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 178 | src = src + self.dropout2(src2) 179 | return src 180 | 181 | def forward(self, src, 182 | src_mask: Optional[Tensor] = None, 183 | src_key_padding_mask: Optional[Tensor] = None, 184 | pos: Optional[Tensor] = None): 185 | if self.normalize_before: 186 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 187 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 188 | 189 | 190 | class TransformerDecoderLayer(nn.Module): 191 | 192 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 193 | activation="relu", normalize_before=False): 194 | super().__init__() 195 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 196 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 197 | # Implementation of Feedforward model 198 | self.linear1 = nn.Linear(d_model, dim_feedforward) 199 | self.dropout = nn.Dropout(dropout) 200 | self.linear2 = nn.Linear(dim_feedforward, d_model) 201 | 202 | self.norm1 = nn.LayerNorm(d_model) 203 | self.norm2 = nn.LayerNorm(d_model) 204 | self.norm3 = nn.LayerNorm(d_model) 205 | self.dropout1 = nn.Dropout(dropout) 206 | self.dropout2 = nn.Dropout(dropout) 207 | self.dropout3 = nn.Dropout(dropout) 208 | 209 | self.activation = _get_activation_fn(activation) 210 | self.normalize_before = normalize_before 211 | 212 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 213 | return tensor if pos is None else tensor + pos 214 | 215 | def forward_post(self, tgt, memory, 216 | tgt_mask: Optional[Tensor] = None, 217 | memory_mask: Optional[Tensor] = None, 218 | tgt_key_padding_mask: Optional[Tensor] = None, 219 | memory_key_padding_mask: Optional[Tensor] = None, 220 | pos: Optional[Tensor] = None, 221 | query_pos: Optional[Tensor] = None): 222 | q = k = self.with_pos_embed(tgt, query_pos) 223 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 224 | key_padding_mask=tgt_key_padding_mask)[0] 225 | tgt = tgt + self.dropout1(tgt2) 226 | tgt = self.norm1(tgt) 227 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 228 | key=self.with_pos_embed(memory, pos), 229 | value=memory, attn_mask=memory_mask, 230 | key_padding_mask=memory_key_padding_mask)[0] 231 | tgt = tgt + self.dropout2(tgt2) 232 | tgt = self.norm2(tgt) 233 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 234 | tgt = tgt + self.dropout3(tgt2) 235 | tgt = self.norm3(tgt) 236 | return tgt 237 | 238 | def forward_pre(self, tgt, memory, 239 | tgt_mask: Optional[Tensor] = None, 240 | memory_mask: Optional[Tensor] = None, 241 | tgt_key_padding_mask: Optional[Tensor] = None, 242 | memory_key_padding_mask: Optional[Tensor] = None, 243 | pos: Optional[Tensor] = None, 244 | query_pos: Optional[Tensor] = None): 245 | tgt2 = self.norm1(tgt) 246 | q = k = self.with_pos_embed(tgt2, query_pos) 247 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 248 | key_padding_mask=tgt_key_padding_mask)[0] 249 | tgt = tgt + self.dropout1(tgt2) 250 | tgt2 = self.norm2(tgt) 251 | tgt2 = 
self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 252 | key=self.with_pos_embed(memory, pos), 253 | value=memory, attn_mask=memory_mask, 254 | key_padding_mask=memory_key_padding_mask)[0] 255 | tgt = tgt + self.dropout2(tgt2) 256 | tgt2 = self.norm3(tgt) 257 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 258 | tgt = tgt + self.dropout3(tgt2) 259 | return tgt 260 | 261 | def forward(self, tgt, memory, 262 | tgt_mask: Optional[Tensor] = None, 263 | memory_mask: Optional[Tensor] = None, 264 | tgt_key_padding_mask: Optional[Tensor] = None, 265 | memory_key_padding_mask: Optional[Tensor] = None, 266 | pos: Optional[Tensor] = None, 267 | query_pos: Optional[Tensor] = None): 268 | if self.normalize_before: 269 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 270 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 271 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 272 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 273 | 274 | 275 | def _get_clones(module, N): 276 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 277 | 278 | 279 | def build_transformer(args): 280 | return Transformer( 281 | d_model=args.hidden_dim, 282 | dropout=args.dropout, 283 | nhead=args.nheads, 284 | dim_feedforward=args.dim_feedforward, 285 | num_encoder_layers=args.enc_layers, 286 | num_decoder_layers=args.dec_layers, 287 | normalize_before=args.pre_norm, 288 | return_intermediate_dec=True, 289 | ) 290 | 291 | 292 | def _get_activation_fn(activation): 293 | """Return an activation function given a string""" 294 | if activation == "relu": 295 | return F.relu 296 | if activation == "gelu": 297 | return F.gelu 298 | if activation == "glu": 299 | return F.glu 300 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 301 | --------------------------------------------------------------------------------
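
As a reference for how these pieces fit together, the sketch below drives `build_transformer` the same way `model_reproductive.Model` does (hidden dim 256, 25 learned queries, a 16x16 feature map after the two stride-2 convolutions). It is an illustrative, standalone example: a random tensor stands in for the downsampled SAM image embedding and a zero tensor for the learned positional embedding; it is not part of the training pipeline.

```
import torch
from torch import nn

from config_reproductive import cfg
from transformer import build_transformer

prompter = build_transformer(cfg.transformer)        # d_model=256, 6 encoder / 6 decoder layers
query_embed = nn.Embedding(cfg.num_proposals, 256)   # 25 learned part queries
pos_embed = torch.zeros(1, 256, 16, 16)              # stands in for the learned positional embedding

src = torch.randn(2, 256, 16, 16)                    # stands in for the downsampled SAM image embedding
hs, memory = prompter(src=src, mask=None,
                      query_embed=query_embed.weight,
                      pos_embed=pos_embed)

print(hs.shape)      # torch.Size([6, 2, 25, 256]) -- per-decoder-layer query embeddings
print(memory.shape)  # torch.Size([2, 256, 16, 16]) -- encoded feature map
# Model uses the last decoder layer, prompter(...)[0][-1], as the student prompt
# embeddings fed to the classification head and the coordinate/embedding head.
```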