# Repository: Image-segmentation-using-pytorch
#   Files: README.md, test.py, load_img.py, main.py, model.py
#   README: dataset on Baidu Netdisk:
#       https://pan.baidu.com/s/1Ou4bsZkCxo6nuJ1T6Mq_HQ   (extraction code: rja0)
#
# /test.py
"""Run a saved segmentation model over a folder of images and write PNG masks."""
import os

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class TestDataset(Dataset):
    """Images from a single directory, resized to 224x224 for inference.

    Args:
        test_img_path: Directory whose entries are all treated as image files.
        transform: Optional callable applied to the resized image array.
    """

    def __init__(self, test_img_path, transform=None):
        # os.listdir returns entries in arbitrary order; sort so that the
        # index used to name the output files is reproducible between runs.
        self.test_img = sorted(os.listdir(test_img_path))
        self.transform = transform
        self.images = [os.path.join(test_img_path, name) for name in self.test_img]

    def __getitem__(self, item):
        """Return the (optionally transformed) image at position ``item``.

        Raises:
            OSError: If OpenCV cannot read the file.
        """
        img_path = self.images[item]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError('Cannot read image: {}'.format(img_path))
        img = cv2.resize(img, (224, 224))
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.test_img)


test_img_path = 'data/test/last'
checkpoint_path = 'checkpoints/model_epoch_50.pth'


def main():
    """Segment every image under ``test_img_path`` and save masks to ``result/``."""
    # NOTE(review): cv2.imread yields BGR while these are the usual ImageNet
    # RGB statistics. Training uses the same pipeline, so it is left as is.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    bag = TestDataset(test_img_path, transform)
    dataloader = DataLoader(bag, batch_size=1, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): the checkpoint is a whole pickled module (see torch.save in
    # main.py); unpickling executes code, so only load trusted checkpoints.
    net = torch.load(checkpoint_path, map_location=device)
    net = net.to(device)
    # Inference mode: BatchNorm must use its running statistics rather than
    # the statistics of a single-image batch.
    net.eval()

    os.makedirs('result', exist_ok=True)

    with torch.no_grad():
        for idx, img in enumerate(dataloader):
            img = img.to(device)
            output = torch.sigmoid(net(img))

            output_np = output.cpu().numpy()
            # Channel 0 scores "white" pixels (see load_img.py), so argmin
            # yields 1 where channel 1 is the smaller one, i.e. white.
            output_np = np.argmin(output_np, axis=1)

            img_arr = (np.squeeze(output_np) * 255).astype(np.uint8)
            out_path = 'result/%03d.png' % idx
            if not cv2.imwrite(out_path, img_arr):
                # cv2.imwrite reports failure through its return value.
                raise OSError('Cannot write image: {}'.format(out_path))
            print(out_path)


if __name__ == '__main__':
    main()
# /load_img.py
"""Training dataset pairing images with their binary segmentation masks."""
import os

import cv2
import numpy as np
from torch.utils.data import Dataset


def _encode_mask(lab):
    """Turn a 2-D array of 0/1 labels into a ``[2, H, W]`` float32 target.

    Channel 0 is 1 where the label is 1 and channel 1 is 1 where the label
    is 0. This equals the original ``abs(np.eye(2)[lab] - 1)`` followed by
    ``transpose(2, 0, 1)``, without a Python-level loop over image rows.
    """
    return np.stack((lab == 1, lab == 0)).astype('float32')


class MyDataset(Dataset):
    """Image/mask pairs read from ``<train_path>/last`` and ``<train_path>/last_msk``.

    Args:
        train_path: Directory containing the ``last`` and ``last_msk`` folders.
        transform: Optional callable applied to the resized image only.

    Raises:
        ValueError: If the two folders hold a different number of entries.
    """

    def __init__(self, train_path, transform=None):
        image_dir = os.path.join(train_path, 'last')
        label_dir = os.path.join(train_path, 'last_msk')
        # os.listdir returns entries in arbitrary order, so pairing the two
        # listings by index is only meaningful once both are sorted.
        # NOTE(review): assumes an image and its mask sort to the same
        # position (e.g. identical file names) - confirm against the dataset.
        self.images = sorted(os.listdir(image_dir))
        self.labels = sorted(os.listdir(label_dir))
        if len(self.images) != len(self.labels):
            raise ValueError('Number does not match')
        self.transform = transform
        # Paths of each image and its label
        self.images_and_labels = [
            (os.path.join(image_dir, img_name), os.path.join(label_dir, lab_name))
            for img_name, lab_name in zip(self.images, self.labels)
        ]

    def __getitem__(self, item):
        """Return ``(img, lab)`` where ``lab`` is a ``[2, 224, 224]`` float32 array.

        Raises:
            OSError: If OpenCV cannot read the image or the mask.
        """
        img_path, lab_path = self.images_and_labels[item]
        img = cv2.imread(img_path)
        lab = cv2.imread(lab_path, 0)  # flag 0: read as single-channel grayscale
        if img is None:
            raise OSError('Cannot read image: {}'.format(img_path))
        if lab is None:
            raise OSError('Cannot read image: {}'.format(lab_path))
        img = cv2.resize(img, (224, 224))
        lab = cv2.resize(lab, (224, 224))
        # Map to 0/1: the uint8 cast truncates, so only pixels equal to 255
        # become 1 and every other value becomes 0.
        lab = (lab / 255).astype('uint8')
        # Inverted one-hot encoding. Per the original author's note, 1 (255,
        # white) is background and 0 (black) is the target.
        lab = _encode_mask(lab)  # [2, 224, 224]
        if self.transform is not None:
            img = self.transform(img)
        return img, lab

    def __len__(self):
        return len(self.images)


if __name__ == '__main__':
    # Small visual check of the mask encoding on one training mask.
    img = cv2.imread('data/train/last_msk/150.jpg', 0)
    img = cv2.resize(img, (16, 16))
    img2 = img / 255
    img3 = img2.astype('uint8')
    hot1 = np.eye(2)[img3]
    hot2 = np.abs(hot1 - 1)
    print(hot2.shape)
    print(hot2.transpose(2, 0, 1))
# /main.py
"""Train the segmentation network and save a checkpoint every 10 epochs."""
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

import model
from load_img import MyDataset


batchsize = 8
epochs = 50
train_data_path = 'data/train'


def main():
    """Run the full training loop over ``train_data_path``."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    bag = MyDataset(train_data_path, transform)
    dataloader = DataLoader(bag, batch_size=batchsize, shuffle=True)

    # Fall back to the CPU so the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = model.Net().to(device)
    # The network returns raw scores; sigmoid is applied below before BCELoss.
    criterion = nn.BCELoss()
    optimizer = optim.SGD(net.parameters(), lr=1e-2, momentum=0.7)

    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs('checkpoints', exist_ok=True)

    net.train()
    for epoch in range(1, epochs + 1):
        for batch_idx, (img, lab) in enumerate(dataloader):
            img, lab = img.to(device), lab.to(device)
            output = torch.sigmoid(net(img))
            loss = criterion(output, lab)

            # The per-step argmin of predictions and labels was dropped: its
            # results were never read and it forced a GPU-to-CPU copy on
            # every iteration.

            if batch_idx % 20 == 0:
                # Samples seen so far; also correct when the final batch is
                # smaller than batchsize.
                seen = batch_idx * batchsize + len(img)
                print('Epoch:[{}/{}]\tStep:[{}/{}]\tLoss:{:.6f}'.format(
                    epoch, epochs, seen, len(dataloader.dataset), loss.item()
                ))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 10 == 0:
            # The whole module is saved (not just its state_dict) because
            # test.py calls the object returned by torch.load directly.
            torch.save(net, 'checkpoints/model_epoch_{}.pth'.format(epoch))
            print('checkpoints/model_epoch_{}.pth saved!'.format(epoch))


if __name__ == '__main__':
    main()
# /model.py
"""Encoder-decoder convolutional network for per-pixel classification."""
import torch
from torch import nn


def _conv_bn_relu(in_channels, out_channels):
    """Return the layers of one 3x3 conv -> BatchNorm -> in-place ReLU unit."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(True),
    ]


def _encoder_stage(channels):
    """Build an encoder stage from consecutive channel counts.

    ``channels = (a, b, c)`` yields conv(a->b), conv(b->c) and then a 2x2 max
    pool that halves the spatial size.
    """
    layers = []
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        layers.extend(_conv_bn_relu(c_in, c_out))
    layers.append(nn.MaxPool2d(2, 2))
    return nn.Sequential(*layers)


def _decoder_stage(in_channels, out_channels):
    """Build a decoder stage: a transposed conv that doubles the spatial size."""
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                           kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(True),
    )


class Net(nn.Module):
    """Five down-sampling stages, five up-sampling stages and a 1x1 classifier.

    Layer order inside every ``nn.Sequential`` is the same as in the original
    hand-written definition, so ``state_dict`` keys are unchanged.

    Args:
        in_channels: Number of channels of the input image (default 3).
        num_classes: Number of output score maps (default 2: target and
            non-target).
    """

    def __init__(self, in_channels=3, num_classes=2):
        super(Net, self).__init__()
        self.encode1 = _encoder_stage((in_channels, 64))
        self.encode2 = _encoder_stage((64, 128))
        self.encode3 = _encoder_stage((128, 256, 256))
        self.encode4 = _encoder_stage((256, 512, 512))
        self.encode5 = _encoder_stage((512, 512, 512))
        self.decode1 = _decoder_stage(512, 256)
        self.decode2 = _decoder_stage(256, 128)
        self.decode3 = _decoder_stage(128, 64)
        self.decode4 = _decoder_stage(64, 32)
        self.decode5 = _decoder_stage(32, 16)
        self.classifier = nn.Conv2d(16, num_classes, kernel_size=1)

    def forward(self, x):  # b: batch_size
        """Return raw per-pixel class scores (no sigmoid/softmax applied).

        The shapes in the comments are for a 224x224 input with the default
        arguments. Each encoder stage halves the height and width (rounding
        down) and each decoder stage doubles them, so the output matches the
        input size only when both are multiples of 32.
        """
        out = self.encode1(x)       # [b, 3, 224, 224] => [b, 64, 112, 112]
        out = self.encode2(out)     # [b, 64, 112, 112] => [b, 128, 56, 56]
        out = self.encode3(out)     # [b, 128, 56, 56] => [b, 256, 28, 28]
        out = self.encode4(out)     # [b, 256, 28, 28] => [b, 512, 14, 14]
        out = self.encode5(out)     # [b, 512, 14, 14] => [b, 512, 7, 7]
        out = self.decode1(out)     # [b, 512, 7, 7] => [b, 256, 14, 14]
        out = self.decode2(out)     # [b, 256, 14, 14] => [b, 128, 28, 28]
        out = self.decode3(out)     # [b, 128, 28, 28] => [b, 64, 56, 56]
        out = self.decode4(out)     # [b, 64, 56, 56] => [b, 32, 112, 112]
        out = self.decode5(out)     # [b, 32, 112, 112] => [b, 16, 224, 224]
        # [b, 16, 224, 224] => [b, 2, 224, 224]; 2 is the number of classes
        # (target and non-target).
        out = self.classifier(out)
        return out


if __name__ == '__main__':
    img = torch.randn(2, 3, 224, 224)
    net = Net()
    sample = net(img)
    print(sample.shape)