├── .gitignore
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── test.py
│   ├── data.py
│   ├── utils.py
│   ├── train.py
│   └── sam_wrapper.py
├── README.md
└── main.py

/.gitignore:
--------------------------------------------------------------------------------
**/__pycache__
.DS_Store
logs
data
checkpoint
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
python==3.9.0
imageio==2.19.3
matplotlib==3.5.2
numpy==1.21.5
tqdm==4.64.1
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
from .sam_wrapper import SAMWrapper
from .train import train
from .test import test
from .utils import compute_avg_bbox
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
import os

import torch
from tqdm.auto import tqdm

from .data import get_dataloader
from .utils import draw_mask_onimage


def test(args, model):
    dataloader = get_dataloader(os.path.join(args.base_dir, 'testing'), args.mode)

    for i, (X, _) in enumerate(tqdm(dataloader)):
        X_orig = X.copy()

        # inference only: no gradients needed, keeps GPU memory flat
        with torch.no_grad():
            _, pred_mask = model(X, None)

        draw_mask_onimage(X_orig, pred_mask.squeeze(), os.path.join(args.results_dir, f'{i}.jpg'))
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
import os

import imageio
from torch.utils.data import Dataset, DataLoader


class SegDataset(Dataset):
    def __init__(self, base_dir, mode):
        self.data_dir = os.path.join(base_dir, 'image')
        self.mask_dir = os.path.join(base_dir, 'gt_image') if mode == 'train' else None

        # Cache a sorted listing once: os.listdir order is arbitrary, and
        # re-listing the directory on every __getitem__ call is wasteful.
        self.files = sorted(f for f in os.listdir(self.data_dir) if not f.startswith('.'))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file = self.files[idx]
        x = imageio.imread(os.path.join(self.data_dir, file))

        # KITTI road naming: image 'um_000000.png' pairs with mask 'um_road_000000.png'
        m_file = file.split('_')
        m_file.insert(1, 'road')

        if self.mask_dir:
            mask = imageio.imread(os.path.join(self.mask_dir, '_'.join(m_file)))
        else:
            mask = None

        return x, mask


def trivial_collate(batch):
    # batch_size is 1 and samples are raw numpy arrays of varying size,
    # so unwrap the single sample instead of stacking into a tensor
    return batch[0]


def get_dataloader(base_data_dir, mode):
    dataset = SegDataset(base_data_dir, mode)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=(mode == 'train'),  # keep test output order deterministic
        num_workers=8,
        collate_fn=trivial_collate,
    )
    return dataloader
--------------------------------------------------------------------------------
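A minimal usage sketch of the loader; the dataset path `data/kitti` and the filenames are hypothetical placeholders, assuming a KITTI-style dataset laid out as in the README:

```python
# Minimal sketch (hypothetical paths): shows what SegDataset yields and the
# image -> ground-truth filename pairing it assumes.
from src.data import get_dataloader

loader = get_dataloader('data/kitti/training', 'train')
X, gt_mask = next(iter(loader))
print(X.shape, gt_mask.shape)  # e.g. (375, 1242, 3) (375, 1242, 3)
# image/um_000000.png is paired with gt_image/um_road_000000.png
```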
/src/utils.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import torch
import imageio
import matplotlib.pyplot as plt


def compute_avg_bbox(args):
    mask_dir = os.path.join(args.base_dir, 'training', 'gt_image')

    all_bbox = np.empty((0, 4))
    im_sizes = np.empty((0, 2))
    for mask_file in os.listdir(mask_dir):
        if '.DS_Store' in mask_file:
            continue

        # KITTI road ground truth encodes the road class in the blue channel
        gt_mask = imageio.imread(os.path.join(mask_dir, mask_file))
        gt_mask = torch.tensor(gt_mask)[..., 2] / 255.

        # store sizes as (W, H)
        im_sizes = np.vstack((im_sizes, np.array(gt_mask.shape[::-1])))

        # torch.where returns (row, col) indices; the box is (xmin, ymin, xmax, ymax)
        rows, cols = torch.where(gt_mask == 1)
        all_bbox = np.vstack((all_bbox, np.array([[cols.min().item(), rows.min().item(),
                                                   cols.max().item(), rows.max().item()]])))

    # saved layout: [W, H, xmin, ymin, xmax, ymax]
    save = np.hstack((np.mean(im_sizes, axis=0).astype(int), np.mean(all_bbox, axis=0)))
    np.save(os.path.join(args.output_dir, 'avg_bboxes.npy'), save)


def draw_mask_onimage(X, mask, path):
    mask = mask.detach().cpu().numpy()

    plt.figure()
    plt.imshow(X)

    # overlay the mask as a translucent red RGBA layer
    color = np.array([255 / 255, 50 / 255, 50 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    plt.imshow(mask)

    plt.savefig(path)
    plt.close()  # called many times per run; don't leak figures
--------------------------------------------------------------------------------
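The saved array is consumed by `SAMWrapper` at test time. For reference, a minimal sketch of reading it back (the experiment name is a placeholder):

```python
import numpy as np

# Layout written by compute_avg_bbox: [W, H, xmin, ymin, xmax, ymax]
avg = np.load('logs/your_exp_name/avg_bboxes.npy')  # placeholder experiment name
avg_size, avg_box = avg[:2].astype(int), avg[2:]
print('mean training image size (W, H):', avg_size)
print('mean box (xmin, ymin, xmax, ymax):', avg_box)
```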
/src/train.py:
--------------------------------------------------------------------------------
import os

from tqdm.auto import tqdm
import torch
import torch.nn as nn

from .data import get_dataloader
from .utils import draw_mask_onimage


def train(args, model):
    dataloader = get_dataloader(os.path.join(args.base_dir, 'training'), args.mode)

    # only the mask decoder is fine-tuned; the encoders stay frozen (see sam_wrapper.py)
    optimizer = torch.optim.Adam(model.sam_model.mask_decoder.parameters(), lr=args.lr)
    loss_fn = nn.MSELoss()

    accum_iter = 10
    best_model_loss = 1e10
    best_model_ckpt = None
    best_model_epoch = 0
    for ep in range(1, args.epochs + 1):

        total_loss = 0
        for i, (X, gt_mask) in enumerate(tqdm(dataloader)):

            X_orig = X.copy()
            gt_mask, pred_mask = model(X, gt_mask)

            # train step: gradients accumulate over accum_iter iterations
            # (summed, not averaged), giving an effective batch of accum_iter
            loss = loss_fn(pred_mask.squeeze(), gt_mask)
            total_loss += loss.item()
            loss.backward()

            if (i + 1) % accum_iter == 0 or (i + 1) == len(dataloader):
                optimizer.step()
                optimizer.zero_grad()

            if i % args.save_every == 0:
                draw_mask_onimage(X_orig, pred_mask.squeeze(), os.path.join(args.results_dir, f'ep{ep}_{i}.jpg'))
                draw_mask_onimage(X_orig, gt_mask, os.path.join(args.results_dir, f'ep{ep}_{i}_gt.jpg'))
                print(f'LOSS {loss.item()}')

            del gt_mask, pred_mask, loss
            torch.cuda.empty_cache()

        if ep % args.ckpt_every == 0:
            torch.save(model.sam_model.state_dict(), os.path.join(args.checkpoint_dir, f'sam_ckpt_{ep}.pth'))

        avg_loss = total_loss / len(dataloader.dataset)
        print(f'EPOCH {ep} | AVERAGE LOSS {avg_loss}')
        if avg_loss < best_model_loss:
            best_model_loss = avg_loss
            best_model_epoch = ep
            # snapshot detached copies: dict.copy() is shallow and would alias
            # the live tensors, which keep changing as training continues
            best_model_ckpt = {k: v.detach().cpu().clone() for k, v in model.sam_model.state_dict().items()}

    torch.save(best_model_ckpt, os.path.join(args.checkpoint_dir, 'best_model.pth'))
    print(f'BEST MODEL EPOCH {best_model_epoch} | LOSS {best_model_loss}')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SAM-finetune
Finetune Meta's SegmentAnything

This repository contains a wrapper for fine-tuning Meta's SegmentAnything (SAM) model on a custom dataset
for a single class. SAM is a promptable segmentation model: given point or box prompts, it segments
arbitrary objects in an image. This repo leverages SAM's prompt encoder to fine-tune it to one class,
using bounding boxes around the ground-truth masks as prompts.

Currently supports the ViT-H SAM backbone (default) only; tested on the KITTI road segmentation dataset.


## Setup

```
conda create --name sam_finetune --file requirements.txt
```

Install [pytorch](https://pytorch.org/) with your version of CUDA, for instance:

```
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
```

Install SegmentAnything:

```
pip install git+https://github.com/facebookresearch/segment-anything.git
```
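Finally, download the pretrained ViT-H SAM checkpoint into `checkpoint/`; when fine-tuning from scratch, `main.py` looks for `checkpoint/sam_vit_h_4b8939.pth`:

```
mkdir -p checkpoint
wget -P checkpoint https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
```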
## Data

Run

```
mkdir data
```

and arrange your dataset according to the following layout:

```
├── main.py
├── src/
├── data/
│   ├── your_dataset_name/
│   │   ├── training/
│   │   │   ├── image/       # training RGB images
│   │   │   ├── gt_image/    # training ground truth masks
│   │   ├── testing/
│   │   │   ├── image/       # testing RGB images
│   │   │   ├── gt_image/    # testing ground truth masks (optional)
```

## Training

```
python main.py --mode train --exp_name your_exp_name --base_dir your_dataset_name
```

Configurable arguments:
* --ckpt_every : how often (in epochs) to save a checkpoint
* --save_every : how often (in iterations) to save training visualizations
* --lr : learning rate
* --epochs : number of epochs


## Testing

**For best results, run the following first to compute an average bbox from the training set to guide test-time prompting:**

```
python main.py --mode bbox --exp_name your_exp_name --base_dir your_dataset_name
```

**Evaluation**
```
python main.py --mode test --exp_name your_exp_name --base_dir your_dataset_name
```

Configurable argument:
* --ckpt_name : checkpoint to load; if not provided, loads `best_model.pth`
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import argparse
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import numpy as np

from src import SAMWrapper, train, test, compute_avg_bbox

SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

parser = argparse.ArgumentParser()
parser.add_argument('--base_dir', type=str, required=True, help='dataset name under data/')
parser.add_argument('--mode', type=str, required=True, help='train | test | bbox')
parser.add_argument('--exp_name', type=str, required=True)
parser.add_argument('--ckpt_every', type=int, default=10)
parser.add_argument('--save_every', type=int, default=50)
parser.add_argument('--lr', type=float, default=1e-6)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--ckpt_name', type=str, default=None)

if __name__ == "__main__":
    args = parser.parse_args()

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using device {args.device}')

    args.base_dir = os.path.join('data', args.base_dir)
    args.output_dir = os.path.join('logs', args.exp_name)
    args.checkpoint_dir = os.path.join('checkpoint', args.exp_name)
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    if args.mode == 'train':
        print('TRAIN MODE')
        args.results_dir = os.path.join(args.output_dir, 'train_results')
        os.makedirs(args.results_dir, exist_ok=True)

        if args.ckpt_name:
            model = SAMWrapper(os.path.join(args.checkpoint_dir, args.ckpt_name), args.device)
        else:
            print('Finetuning from SAM checkpoint and reinitializing MLP parameters')
            model = SAMWrapper(os.path.join('checkpoint', 'sam_vit_h_4b8939.pth'), args.device, from_scratch=True)
        model = model.to(args.device).train()
        train(args, model)

    elif args.mode == 'test':
        print('TEST MODE')
        args.results_dir = os.path.join(args.output_dir, 'test_results')
        os.makedirs(args.results_dir, exist_ok=True)

        try:
            avg_bbox = np.load(os.path.join(args.output_dir, 'avg_bboxes.npy'))
        except FileNotFoundError:
            avg_bbox = None
            print('No average bboxes available. Run with `--mode bbox` first for better results')

        if args.ckpt_name:
            model = SAMWrapper(os.path.join(args.checkpoint_dir, args.ckpt_name), args.device, avg_box=avg_bbox)
        else:
            print('Loading best model')
            model = SAMWrapper(os.path.join(args.checkpoint_dir, 'best_model.pth'), args.device, avg_box=avg_bbox)
        model = model.to(args.device).eval()
        test(args, model)

    elif args.mode == 'bbox':
        os.makedirs(args.output_dir, exist_ok=True)
        compute_avg_bbox(args)

    else:
        print(f'{args.mode} not supported, please specify mode [train | test | bbox]')
--------------------------------------------------------------------------------
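Checkpoints saved by `train` are complete SAM state dicts, so they should also load into the stock `segment_anything` API. A minimal sketch of box-prompted inference with a fine-tuned checkpoint (the paths and the example box are hypothetical placeholders):

```python
# Minimal sketch: run a fine-tuned checkpoint through segment_anything's own
# SamPredictor. Checkpoint path, image path, and box values are placeholders.
import numpy as np
import imageio
from segment_anything import SamPredictor, sam_model_registry

sam = sam_model_registry['vit_h'](checkpoint='checkpoint/your_exp_name/best_model.pth')
predictor = SamPredictor(sam.to('cuda'))

image = imageio.imread('data/your_dataset_name/testing/image/um_000000.png')
predictor.set_image(image)

box = np.array([100, 200, 1100, 370])  # (xmin, ymin, xmax, ymax) prompt
masks, scores, _ = predictor.predict(box=box, multimask_output=False)
print(masks.shape, scores)  # (1, H, W) boolean mask and its predicted IoU score
```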
/src/sam_wrapper.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import threshold, normalize
from segment_anything.utils.transforms import ResizeLongestSide
from segment_anything import sam_model_registry


class SAMWrapper(nn.Module):
    def __init__(self, ckpt_path, device, from_scratch=False, avg_box=None):
        super().__init__()
        self.device = device
        self.avg_bbox = avg_box  # [W, H, xmin, ymin, xmax, ymax] from compute_avg_bbox

        self.sam_model = sam_model_registry['vit_h'](checkpoint=ckpt_path)
        if from_scratch:
            # reinitialize the mask decoder's output hypernetwork MLPs:
            # ModuleList -> MLP -> Linear layers
            for layer in self.sam_model.mask_decoder.output_hypernetworks_mlps.children():
                for cc in layer.children():
                    for c in cc.children():
                        try:
                            c.reset_parameters()
                        except AttributeError:
                            print(f'cannot reset parameters: {c}')

        self.transform = ResizeLongestSide(self.sam_model.image_encoder.img_size)

    def resize_bbox(self, target_size):
        # rescale the stored average box from the average training image size
        # to target_size = (W, H)
        h_scale = target_size[1] / self.avg_bbox[1]
        w_scale = target_size[0] / self.avg_bbox[0]

        self.avg_bbox[[2, 4]] *= w_scale  # xmin, xmax
        self.avg_bbox[[3, 5]] *= h_scale  # ymin, ymax
        self.avg_bbox[:2] = target_size

    def forward(self, X, gt_mask):

        # preprocessing
        original_size = X.shape[:2]
        X = self.transform.apply_image(X)
        X = torch.as_tensor(X, device=self.device)
        X = X.permute(2, 0, 1).contiguous()[None, ...]
        input_size = tuple(X.shape[-2:])
        X = self.sam_model.preprocess(X)

        if gt_mask is not None:
            # KITTI road ground truth encodes the road class in the blue channel
            gt_mask = torch.tensor(gt_mask)[..., 2] / 255.

            # box prompt around the ground-truth mask: (xmin, ymin, xmax, ymax)
            rows, cols = torch.where(gt_mask == 1)
            bbox = np.array([[cols.min().item(), rows.min().item(), cols.max().item(), rows.max().item()]])
            bbox = self.transform.apply_boxes(bbox, original_size)
            bbox_tensor = torch.as_tensor(bbox, dtype=torch.float, device=self.device)
            gt_mask = gt_mask.to(self.device)
        elif self.avg_bbox is not None:
            # no ground truth at test time: prompt with the average training box,
            # rescaled if this image's size differs from the training average
            if abs(original_size[0] - self.avg_bbox[1]) > 10 or abs(original_size[1] - self.avg_bbox[0]) > 10:
                self.resize_bbox(original_size[::-1])
            bbox = self.transform.apply_boxes(self.avg_bbox[2:], original_size)
            bbox_tensor = torch.as_tensor(bbox, dtype=torch.float, device=self.device)
        else:
            bbox_tensor = None

        # model: the image and prompt encoders run frozen, only the mask
        # decoder participates in training
        with torch.no_grad():
            image_embedding = self.sam_model.image_encoder(X)
            sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
                points=None, boxes=bbox_tensor, masks=None
            )

        low_res_masks, iou_predictions = self.sam_model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        upscaled_masks = self.sam_model.postprocess_masks(
            low_res_masks, input_size, original_size
        )
        # threshold clamps negative logits to 0; normalize over the singleton
        # channel dim then maps each positive value to 1, so the result is an
        # approximately binary {0, 1} mask
        binary_mask = normalize(threshold(upscaled_masks, 0.0, 0))

        return gt_mask, binary_mask
--------------------------------------------------------------------------------
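Note that `test` only saves visualizations, even when the testing split includes `gt_image`. If you want a number, a minimal scoring sketch (the `mask_iou` helper is hypothetical, not part of this repo):

```python
# Hypothetical helper (not part of this repo): intersection-over-union between
# a predicted mask and a ground-truth mask, both (H, W) tensors in [0, 1].
import torch

def mask_iou(pred_mask, gt_mask, thresh=0.5):
    p = pred_mask > thresh
    g = gt_mask > thresh
    inter = (p & g).sum().item()
    union = (p | g).sum().item()
    return inter / union if union > 0 else 1.0
```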