├── test.jpg ├── README.md ├── LICENSE ├── Infer.py └── Train.py /test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagieppel/Train-Semantic-Segmentation-Net-with-Pytorch-In-50-Lines-Of-Code/HEAD/test.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Train neural network for semantic segmentation (deep lab V3) with pytorch in 50 lines of code 2 | 3 | Train a semantic segmentation net using the LabPics dataset: [https://zenodo.org/record/3697452/files/LabPicsV1.zip?download=1](https://zenodo.org/record/3697452/files/LabPicsV1.zip?download=1) in less than 50 lines of code (not including spaces) 4 | 5 | The full tutorial that goes with the code can be found here: 6 | https://medium.com/@sagieppel/train-neural-net-for-semantic-segmentation-with-pytorch-in-50-lines-of-code-830c71a6544f 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 sagieppel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Infer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torchvision.models.segmentation 3 | import torch 4 | import torchvision.transforms as tf 5 | import matplotlib.pyplot as plt 6 | modelPath = "3000.torch" # Path to trained model 7 | imagePath = "test.jpg" # Test image 8 | height=width=900 9 | transformImg = tf.Compose([tf.ToPILImage(), tf.Resize((height, width)), tf.ToTensor(),tf.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))]) # tf.Resize((300,600)),tf.RandomRotation(145)])# 10 | 11 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') # Check if there is GPU if not set trainning to CPU (very slow) 12 | Net = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True) # Load net 13 | Net.classifier[4] = torch.nn.Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1)) # Change final layer to 3 classes 14 | Net = Net.to(device) # Set net to GPU or CPU 15 | Net.load_state_dict(torch.load(modelPath)) # Load trained model 16 | Net.eval() # Set to evaluation mode 17 | Img = cv2.imread(imagePath) # load test image 18 | height_orgin , widh_orgin ,d = Img.shape # Get image original size 19 | plt.imshow(Img[:,:,::-1]) # Show image 20 | plt.show() 21 | Img = transformImg(Img) # Transform to pytorch 22 | Img = torch.autograd.Variable(Img, requires_grad=False).to(device).unsqueeze(0) 23 | with 
torch.no_grad(): 24 | Prd = Net(Img)['out'] # Run net 25 | Prd = tf.Resize((height_orgin,widh_orgin))(Prd[0]) # Resize to origninal size 26 | seg = torch.argmax(Prd, 0).cpu().detach().numpy() # Get prediction classes 27 | plt.imshow(seg) # display image 28 | plt.show() 29 | -------------------------------------------------------------------------------- /Train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import torchvision.models.segmentation 5 | import torch 6 | import torchvision.transforms as tf 7 | 8 | Learning_Rate=1e-5 9 | width=height=900 # image width and height 10 | batchSize=3 11 | 12 | TrainFolder="LabPics/Simple/Train//" 13 | ListImages=os.listdir(os.path.join(TrainFolder, "Image")) # Create list of images 14 | #----------------------------------------------Transform image------------------------------------------------------------------- 15 | transformImg=tf.Compose([tf.ToPILImage(),tf.Resize((height,width)),tf.ToTensor(),tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) 16 | transformAnn=tf.Compose([tf.ToPILImage(),tf.Resize((height,width),tf.InterpolationMode.NEAREST),tf.ToTensor()]) 17 | #---------------------Read image --------------------------------------------------------- 18 | def ReadRandomImage(): # First lets load random image and the corresponding annotation 19 | idx=np.random.randint(0,len(ListImages)) # Select random image 20 | Img=cv2.imread(os.path.join(TrainFolder, "Image", ListImages[idx]))[:,:,0:3] 21 | Filled = cv2.imread(os.path.join(TrainFolder, "Semantic/16_Filled", ListImages[idx].replace("jpg","png")),0) 22 | Vessel = cv2.imread(os.path.join(TrainFolder, "Semantic/1_Vessel", ListImages[idx].replace("jpg","png")),0) 23 | AnnMap = np.zeros(Img.shape[0:2],np.float32) 24 | if Vessel is not None: AnnMap[ Vessel == 1 ] = 1 25 | if Filled is not None: AnnMap[ Filled == 1 ] = 2 26 | Img=transformImg(Img) 27 | 
AnnMap=transformAnn(AnnMap) 28 | return Img,AnnMap 29 | #--------------Load batch of images----------------------------------------------------- 30 | def LoadBatch(): # Load batch of images 31 | images = torch.zeros([batchSize,3,height,width]) 32 | ann = torch.zeros([batchSize, height, width]) 33 | for i in range(batchSize): 34 | images[i],ann[i]=ReadRandomImage() 35 | return images, ann 36 | #--------------Load and set net and optimizer------------------------------------- 37 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 38 | Net = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True) # Load net 39 | Net.classifier[4] = torch.nn.Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1)) # Change final layer to 3 classes 40 | Net=Net.to(device) 41 | optimizer=torch.optim.Adam(params=Net.parameters(),lr=Learning_Rate) # Create adam optimizer 42 | #----------------Train-------------------------------------------------------------------------- 43 | for itr in range(10000): # Training loop 44 | images,ann=LoadBatch() # Load taining batch 45 | images=torch.autograd.Variable(images,requires_grad=False).to(device) # Load image 46 | ann = torch.autograd.Variable(ann, requires_grad=False).to(device) # Load annotation 47 | Pred=Net(images)['out'] # make prediction 48 | Net.zero_grad() 49 | criterion = torch.nn.CrossEntropyLoss() # Set loss function 50 | Loss=criterion(Pred,ann.long()) # Calculate cross entropy loss 51 | Loss.backward() # Backpropogate loss 52 | optimizer.step() # Apply gradient descent change to weight 53 | seg = torch.argmax(Pred[0], 0).cpu().detach().numpy() # Get prediction classes 54 | print(itr,") Loss=",Loss.data.cpu().numpy()) 55 | if itr % 1000 == 0: #Save model weight once every 60k steps permenant file 56 | print("Saving Model" +str(itr) + ".torch") 57 | torch.save(Net.state_dict(), str(itr) + ".torch") 58 | --------------------------------------------------------------------------------