├── IC_test.py ├── IC_train.py ├── README.md ├── SDD_test.py ├── SDD_train.py ├── checkpoints ├── icModule.pth └── sdd.pth ├── scripts ├── DataAug.py ├── binary.py ├── bwfunction.m └── statistic.py ├── torch2trt ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── calibration.cpython-36.pyc │ ├── module_test.cpython-36.pyc │ └── torch2trt.cpython-36.pyc ├── calibration.py ├── converters │ ├── AdaptiveAvgPool2d.py │ ├── BatchNorm1d.py │ ├── BatchNorm2d.py │ ├── Conv1d.py │ ├── Conv2d.py │ ├── ConvTranspose2d.py │ ├── Identity.py │ ├── Linear.py │ ├── LogSoftmax.py │ ├── ReLU.py │ ├── ReLU6.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── AdaptiveAvgPool2d.cpython-36.pyc │ │ ├── BatchNorm1d.cpython-36.pyc │ │ ├── BatchNorm2d.cpython-36.pyc │ │ ├── Conv1d.cpython-36.pyc │ │ ├── Conv2d.cpython-36.pyc │ │ ├── ConvTranspose2d.cpython-36.pyc │ │ ├── Linear.cpython-36.pyc │ │ ├── LogSoftmax.cpython-36.pyc │ │ ├── ReLU6.cpython-36.pyc │ │ ├── activation.cpython-36.pyc │ │ ├── adaptive_avg_pool2d.cpython-36.pyc │ │ ├── adaptive_max_pool2d.cpython-36.pyc │ │ ├── add.cpython-36.pyc │ │ ├── avg_pool2d.cpython-36.pyc │ │ ├── cat.cpython-36.pyc │ │ ├── chunk.cpython-36.pyc │ │ ├── clamp.cpython-36.pyc │ │ ├── div.cpython-36.pyc │ │ ├── dummy_converters.cpython-36.pyc │ │ ├── getitem.cpython-36.pyc │ │ ├── identity.cpython-36.pyc │ │ ├── instance_norm.cpython-36.pyc │ │ ├── max.cpython-36.pyc │ │ ├── max_pool2d.cpython-36.pyc │ │ ├── mean.cpython-36.pyc │ │ ├── min.cpython-36.pyc │ │ ├── mul.cpython-36.pyc │ │ ├── normalize.cpython-36.pyc │ │ ├── pad.cpython-36.pyc │ │ ├── permute.cpython-36.pyc │ │ ├── pow.cpython-36.pyc │ │ ├── prelu.cpython-36.pyc │ │ ├── prod.cpython-36.pyc │ │ ├── relu.cpython-36.pyc │ │ ├── sigmoid.cpython-36.pyc │ │ ├── softmax.cpython-36.pyc │ │ ├── split.cpython-36.pyc │ │ ├── sub.cpython-36.pyc │ │ ├── sum.cpython-36.pyc │ │ ├── tanh.cpython-36.pyc │ │ ├── transpose.cpython-36.pyc │ │ └── unary.cpython-36.pyc │ ├── activation.py │ ├── adaptive_avg_pool2d.py │ ├── adaptive_max_pool2d.py │ ├── add.py │ ├── avg_pool2d.py │ ├── cat.py │ ├── chunk.py │ ├── clamp.py │ ├── div.py │ ├── dummy_converters.py │ ├── getitem.py │ ├── instance_norm.py │ ├── interpolate │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── interpolate.cpython-36.pyc │ │ │ └── interpolate_pb2.cpython-36.pyc │ │ ├── interpolate.cpp │ │ ├── interpolate.pb.cc │ │ ├── interpolate.pb.h │ │ ├── interpolate.proto │ │ ├── interpolate.py │ │ └── interpolate_pb2.py │ ├── max.py │ ├── max_pool2d.py │ ├── mean.py │ ├── min.py │ ├── mul.py │ ├── normalize.py │ ├── pad.py │ ├── permute.py │ ├── pow.py │ ├── prelu.py │ ├── prod.py │ ├── sigmoid.py │ ├── softmax.py │ ├── split.py │ ├── sub.py │ ├── sum.py │ ├── tanh.py │ ├── transpose.py │ ├── unary.py │ └── view.py ├── init.py ├── libtorch2trt.so ├── module_test.py ├── test.py ├── torch2trt.py └── utils.py └── torchtrt2trt /IC_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision as tv 6 | from torch.autograd import Variable 7 | from PIL import Image,ImageStat 8 | from collections import OrderedDict 9 | import operator 10 | import os 11 | import re 12 | import time 13 | import numpy as np 14 | from torch.nn.parallel import DistributedDataParallel 15 | import torch.distributed as dist 16 | import torch.utils.data.distributed 17 | import sys 18 | sys.path.append('..') 19 | from 
torch2trt import torch2trt
20 | #define argument
21 | class BaseOptions():
22 |     def __init__(self):
23 |         self.parser = argparse.ArgumentParser()
24 |         self.initialized = False
25 |     def initialize(self):
26 |         self.parser.add_argument('--data_path',default='train',help="path to the images to run inference on")
27 |         self.parser.add_argument('--save_path',type=str,default='test_result/icModule',help='path to save results')
28 |         self.parser.add_argument('--image_size',type=int,default=896,help="the size to resize to")
29 |         self.parser.add_argument('--load_model',type=str,help="path to the trained model")
30 |         self.parser.add_argument('--acc',action='store_true',help='use TensorRT to accelerate inference')
31 |     def parse(self):
32 |         if not self.initialized:
33 |             self.initialize()
34 |         self.opt=self.parser.parse_args()
35 |         args=vars(self.opt)
36 |         print('-----------Options----------')
37 |         for k,v in sorted(args.items()):
38 |             print('%s:%s'%(str(k),str(v)))
39 |         print('-------------End------------')
40 |         return self.opt
41 |
42 | #define dataset
43 | def make_dataset(dir):
44 |     images = []
45 |     path1 = []
46 |     path2 = []
47 |     path3 = []
48 |     path4 = []
49 |     path5 = []
50 |     assert os.path.isdir(dir), '%s is not a valid directory' % dir
51 |     for root,dirs,fnames in os.walk(dir):
52 |         if len(dirs)==0:
53 |             subFolder=[]
54 |             for fname in fnames:
55 |                 path=os.path.join(root,fname)
56 |                 subFolder.append(path)
57 |             images.append(subFolder)
58 |         else:
59 |             for dirr in dirs:
60 |                 subFolder = []
61 |                 for fname in sorted(os.listdir(dir+'/'+dirr)):
62 |                     path = os.path.join(dir+'/'+dirr+'/', fname)
63 |                     subFolder.append(path)
64 |                 images.append(subFolder)
65 |         break
66 |     for i in range(0, len(images)):
67 |         #data augmentation: each window of five consecutive frames is also used in reverse order
68 |         for j in range(4, len(images[i])):
69 |             path1.extend([images[i][j-4],images[i][j]])
70 |             path2.extend([images[i][j-3],images[i][j-1]])
71 |             path3.extend([images[i][j-2],images[i][j-2]])
72 |             path4.extend([images[i][j-1],images[i][j-3]])
73 |             path5.extend([images[i][j],images[i][j-4]])
74 |     return path1, path2, path3, path4, path5
75 |
76 | def transform(image):
77 |
78 |     transforms = tv.transforms.Compose([
79 |         tv.transforms.Resize((opt.image_size, opt.image_size)),
80 |         tv.transforms.ToTensor(),
81 |     ])
82 |     image = transforms(image)
83 |
84 |     return image
85 |
86 | class MyDataset(torch.utils.data.Dataset):
87 |     def __init__(self):
88 |         self.opt=opt
89 |         self.data_path=os.path.join(opt.data_path)
90 |         self.path1, self.path2, self.path3, self.path4, self.path5 = make_dataset(
91 |             self.data_path)
92 |     def __getitem__(self,index):
93 |
94 |         A_path1 = self.path1[index]
95 |         A_path2 = self.path2[index]
96 |         A_path3 = self.path3[index]
97 |         A_path4 = self.path4[index]
98 |         A_path5 = self.path5[index]
99 |         A_image1 = Image.open(A_path1).convert('L')
100 |         A_image2 = Image.open(A_path2).convert('L')
101 |         A_image3 = Image.open(A_path3).convert('L')
102 |         A_image4 = Image.open(A_path4).convert('L')
103 |         A_image5 = Image.open(A_path5).convert('L')
104 |         image = torch.zeros(5, opt.image_size, opt.image_size)
105 |         image[0] = transform(A_image1)
106 |         image[1] = transform(A_image2)
107 |         image[2] = transform(A_image3)
108 |         image[3] = transform(A_image4)
109 |         image[4] = transform(A_image5)
110 |
111 |         label_name = re.split(r'[/]', A_path5)
112 |         label_name = os.path.join(label_name[-2],label_name[-1])
113 |         return image,label_name
114 |
115 |     def __len__(self):
116 |         return len(self.path1)
117 |
118 | class ResidualBlock(nn.Module):
119 |     def __init__(self, in_channels, out_channels, stride=1):
120 |         super(ResidualBlock, 
self).__init__()
121 |         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
122 |         self.bn1 = nn.BatchNorm2d(out_channels)
123 |         self.relu = nn.ReLU(inplace=True)
124 |         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
125 |         self.bn2 = nn.BatchNorm2d(out_channels)
126 |         if in_channels != out_channels:
127 |             self.downsample = nn.Sequential(
128 |                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2),
129 |                 nn.BatchNorm2d(out_channels)
130 |             )
131 |         else:
132 |             self.downsample = None
133 |
134 |     def forward(self, x):
135 |         identity = x
136 |         out = self.conv1(x)
137 |         out = self.bn1(out)
138 |         out = self.relu(out)
139 |         out = self.conv2(out)
140 |         out = self.bn2(out)
141 |
142 |         if self.downsample is not None:
143 |             identity = self.downsample(x)
144 |         out += identity
145 |         out = self.relu(out)
146 |         return out
147 |
148 | class PostProcess(nn.Module):
149 |     def __init__(self):
150 |         super(PostProcess,self).__init__()
151 |         self.layer1=ResidualBlock(5, 32, 2)
152 |         self.layer2=ResidualBlock(32, 64, 2)
153 |         self.layer3=ResidualBlock(64, 128, 2)
154 |         self.pad1 = nn.Sequential(
155 |             nn.Conv2d(128,64,3,1,1),
156 |             nn.BatchNorm2d(64),
157 |             nn.ReLU()
158 |         )
159 |         self.pad2 = nn.Sequential(
160 |             nn.Conv2d(64,32,3,1,1),
161 |             nn.BatchNorm2d(32),
162 |             nn.ReLU()
163 |         )
164 |         self.pad3 = nn.Sequential(
165 |             nn.Conv2d(32,1,3,1,1),
166 |         )
167 |     def forward(self,x):
168 |         x=self.layer1(x)
169 |         x=self.layer2(x)
170 |         x=self.layer3(x)
171 |         x=F.interpolate(x,scale_factor=2,mode='nearest')
172 |         x=self.pad1(x)
173 |         x=F.interpolate(x,scale_factor=2,mode='nearest')
174 |         x=self.pad2(x)
175 |         x=F.interpolate(x,scale_factor=2,mode='nearest')
176 |         x=self.pad3(x)
177 |         x=torch.sigmoid(x)
178 |
179 |         return x
180 | def test():
181 |     print('---------------------start------------------------')
182 |     ICNet=PostProcess().eval()  # eval mode so BatchNorm uses running statistics at inference
183 |     ICNet.cuda()
184 |     map_location = lambda storage, loc: storage
185 |     ICNet.load_state_dict(torch.load(opt.load_model, map_location=map_location),False)
186 |     dataset = MyDataset()
187 |     dataloader = torch.utils.data.DataLoader(dataset,
188 |                                              batch_size=1,
189 |                                              shuffle=False,
190 |                                              num_workers=1,
191 |                                              drop_last=False
192 |                                              )
193 |
194 |     for ii, (image, labels) in enumerate(dataloader):
195 |         # network output
196 |
197 |         label_name = labels[0]
198 |         save_path = os.path.join(opt.save_path,re.split(r'[/]',label_name)[-2])
199 |         if not os.path.exists(save_path):
200 |             os.makedirs(save_path)
201 |
202 |         inputs = Variable(image).cuda()
203 |         output = ICNet(inputs)
204 |         output_numpy = output.data[0,0,:,:].cpu().float().numpy()
205 |         output_numpy = output_numpy * 255.0
206 |         output_numpy = output_numpy.astype(np.uint8)
207 |         output_PIL = Image.fromarray(output_numpy, mode='L')
208 |         output_PIL.save('%s/%s'%(opt.save_path,label_name))
209 |
210 |     print('--------------------complete!-----------------------')
211 |
212 | def test_acc():
213 |     print('---------------------start------------------------')
214 |     x=torch.ones(1,5,896,896).cuda()  # dummy input that fixes the TensorRT engine's input shape
215 |     ICNet=PostProcess().eval()
216 |     ICNet.cuda()
217 |     map_location = lambda storage, loc: storage
218 |     ICNet.load_state_dict(torch.load(opt.load_model, map_location=map_location),False)
219 |     model_trt = torch2trt(ICNet,[x])
220 |     dataset = MyDataset()
221 |     dataloader = torch.utils.data.DataLoader(dataset,
222 |                                              batch_size=1,
223 |                                              shuffle=False,
224 |                                              num_workers=1,
225 |                                              drop_last=False
226 |                                              )
227 |
228 |     for ii, (image, labels) in enumerate(dataloader):
229 |         # network 
output
230 |
231 |         label_name = labels[0]
232 |         save_path = os.path.join(opt.save_path,re.split(r'[/]',label_name)[-2])
233 |         if not os.path.exists(save_path):
234 |             os.makedirs(save_path)
235 |
236 |         inputs = Variable(image).cuda()
237 |         output = model_trt(inputs)
238 |         output_numpy = output.data[0,0,:,:].cpu().float().numpy()
239 |         output_numpy = output_numpy * 255.0
240 |         output_numpy = output_numpy.astype(np.uint8)
241 |         output_PIL = Image.fromarray(output_numpy, mode='L')
242 |         output_PIL.save('%s/%s'%(opt.save_path,label_name))
243 |
244 |     print('--------------------complete!-----------------------')
245 |
246 | if __name__ == "__main__":
247 |     opt=BaseOptions().parse()
248 |     if opt.acc:
249 |         test_acc()
250 |     else:
251 |         test()
252 |
--------------------------------------------------------------------------------
/IC_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision as tv
6 | from torch.autograd import Variable
7 | from PIL import Image,ImageStat
8 | from collections import OrderedDict
9 | import operator
10 | import os
11 | import re
12 | import time
13 | import numpy as np
14 | from torch.nn.parallel import DistributedDataParallel
15 | import torch.distributed as dist
16 | import torch.utils.data.distributed
17 |
18 | #define argument
19 | class BaseOptions():
20 |     def __init__(self):
21 |         self.parser = argparse.ArgumentParser()
22 |         self.initialized = False
23 |     def initialize(self):
24 |         self.parser.add_argument('--data_path',default='train',help="path to store training data")
25 |         self.parser.add_argument('--label_path',default='label',help="path to store labels")
26 |         self.parser.add_argument('--save_path',type=str,default='IC_checkpoints',help='path to save checkpoints')
27 |         self.parser.add_argument('--num_workers',default=0,type=int)
28 |         self.parser.add_argument('--batch_size',type=int,default=16)
29 |         self.parser.add_argument('--lr',type=float,default=0.001)
30 |         self.parser.add_argument('--max_epoch',type=int,default=200)
31 |         self.parser.add_argument('--image_size',type=int,default=896,help="the size to resize to")
32 |         self.parser.add_argument('--vis',default=True,help="whether to use visdom")
33 |         self.parser.add_argument('--gpus',action='store_true',default=True,help='whether to use gpu')
34 |         self.parser.add_argument('--load_model',type=str,help="path to load a pre-trained model")
35 |         self.parser.add_argument('--local_rank',type=int)
36 |     def parse(self):
37 |         if not self.initialized:
38 |             self.initialize()
39 |         self.opt=self.parser.parse_args()
40 |         args=vars(self.opt)
41 |         print('-----------Options----------')
42 |         for k,v in sorted(args.items()):
43 |             print('%s:%s'%(str(k),str(v)))
44 |         print('-------------End------------')
45 |         return self.opt
46 |
47 | #define dataset
48 | def make_dataset(dir):
49 |     images = []
50 |     path1 = []
51 |     path2 = []
52 |     path3 = []
53 |     path4 = []
54 |     path5 = []
55 |     assert os.path.isdir(dir), '%s is not a valid directory' % dir
56 |     for root,dirs,fnames in os.walk(dir):
57 |         if len(dirs)==0:
58 |             subFolder=[]
59 |             for fname in fnames:
60 |                 path=os.path.join(root,fname)
61 |                 subFolder.append(path)
62 |             images.append(subFolder)
63 |         else:
64 |             for dirr in dirs:
65 |                 subFolder = []
66 |                 for fname in sorted(os.listdir(dir+'/'+dirr)):
67 |                     path = os.path.join(dir+'/'+dirr+'/', fname)
68 |                     subFolder.append(path)
69 |                 images.append(subFolder)
70 |         break
71 |     for 
i in range(0, len(images)): 72 | #data augmentation 73 | for j in range(4, len(images[i])): 74 | path1.extend([images[i][j-4],images[i][j]]) 75 | path2.extend([images[i][j-3],images[i][j-1]]) 76 | path3.extend([images[i][j-2],images[i][j-2]]) 77 | path4.extend([images[i][j-1],images[i][j-3]]) 78 | path5.extend([images[i][j],images[i][j-4]]) 79 | return path1, path2, path3, path4, path5 80 | 81 | def transform(image): 82 | 83 | transforms = tv.transforms.Compose([ 84 | tv.transforms.Resize((opt.image_size, opt.image_size)), 85 | tv.transforms.ToTensor(), 86 | ]) 87 | image = transforms(image) 88 | 89 | return image 90 | 91 | class MyDataset(torch.utils.data.Dataset): 92 | def __init__(self): 93 | self.opt=opt 94 | self.data_path=os.path.join(opt.data_path) 95 | self.path1, self.path2, self.path3, self.path4, self.path5 = make_dataset( 96 | self.data_path) 97 | def __getitem__(self,index): 98 | 99 | A_path1 = self.path1[index] 100 | A_path2 = self.path2[index] 101 | A_path3 = self.path3[index] 102 | A_path4 = self.path4[index] 103 | A_path5 = self.path5[index] 104 | A_image1 = Image.open(A_path1).convert('L') 105 | A_image2 = Image.open(A_path2).convert('L') 106 | A_image3 = Image.open(A_path3).convert('L') 107 | A_image4 = Image.open(A_path4).convert('L') 108 | A_image5 = Image.open(A_path5).convert('L') 109 | image = torch.zeros(5, opt.image_size, opt.image_size) 110 | image[0] = transform(A_image1) 111 | image[1] = transform(A_image2) 112 | image[2] = transform(A_image3) 113 | image[3] = transform(A_image4) 114 | image[4] = transform(A_image5) 115 | 116 | label_name = re.split(r'[/]', A_path5) 117 | label_name = os.path.join(label_name[-2],label_name[-1]) 118 | return image,label_name 119 | 120 | def __len__(self): 121 | return len(self.path1) 122 | 123 | class ResidualBlock(nn.Module): 124 | def __init__(self, in_channels, out_channels, stride=1): 125 | super(ResidualBlock, self).__init__() 126 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1) 127 | self.bn1 = nn.BatchNorm2d(out_channels) 128 | self.relu = nn.ReLU(inplace=True) 129 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 130 | self.bn2 = nn.BatchNorm2d(out_channels) 131 | if in_channels != out_channels: 132 | self.downsample = nn.Sequential( 133 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2), 134 | nn.BatchNorm2d(out_channels) 135 | ) 136 | else: 137 | self.downsample = None 138 | 139 | def forward(self, x): 140 | identity = x 141 | out = self.conv1(x) 142 | out = self.bn1(out) 143 | out = self.relu(out) 144 | out = self.conv2(out) 145 | out = self.bn2(out) 146 | 147 | if self.downsample is not None: 148 | identity = self.downsample(x) 149 | out += identity 150 | out = self.relu(out) 151 | return out 152 | 153 | class PostProcess(nn.Module): 154 | def __init__(self): 155 | super (PostProcess,self).__init__() 156 | self.layer1=ResidualBlock(5, 32, 2) 157 | self.layer2=ResidualBlock(32, 64, 2) 158 | self.layer3=ResidualBlock(64, 128, 2) 159 | self.pad1 = nn.Sequential( 160 | nn.Conv2d(128,64,3,1,1), 161 | nn.BatchNorm2d(64), 162 | nn.ReLU() 163 | ) 164 | self.pad2 = nn.Sequential( 165 | nn.Conv2d(64,32,3,1,1), 166 | nn.BatchNorm2d(32), 167 | nn.ReLU() 168 | ) 169 | self.pad3 = nn.Sequential( 170 | nn.Conv2d(32,1,3,1,1), 171 | ) 172 | def forward(self,x): 173 | x=self.layer1(x) 174 | x=self.layer2(x) 175 | x=self.layer3(x) 176 | x=F.interpolate(x,scale_factor=2,mode='nearest') 177 | x=self.pad1(x) 178 | 
x=F.interpolate(x,scale_factor=2,mode='nearest')
179 |         x=self.pad2(x)
180 |         x=F.interpolate(x,scale_factor=2,mode='nearest')
181 |         x=self.pad3(x)
182 |
183 |         return x  # logits: no sigmoid here, because training uses BCEWithLogitsLoss
184 |
185 | # visualize
186 | class Visualizer():
187 |     def __init__(self, opt):
188 |         self.display_id = 1
189 |         self.win_size = 256
190 |         self.name = 'detection loss'
191 |         if self.display_id:
192 |             import visdom
193 |             self.vis = visdom.Visdom(env='main', port=8097)
194 |     def plot_current_errors(self, epoch, count_ratio, opt, errors):
195 |         if not hasattr(self, 'plot_data'):
196 |             self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
197 |         self.plot_data['X'].append(epoch + count_ratio)
198 |         for k in self.plot_data['legend']:
199 |             value = errors[k].cpu().numpy()  # keep `errors` intact while iterating
200 |             self.plot_data['Y'].append([value])
201 |         self.vis.line(
202 |             X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1) if len(self.plot_data['X'])==1 else np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1).squeeze(),
203 |             Y=np.array(self.plot_data['Y']) if len(self.plot_data['Y'])==1 else np.array(self.plot_data['Y']).squeeze(),
204 |             opts={
205 |                 'title': self.name + ' loss over time',
206 |                 'legend': self.plot_data['legend'],
207 |                 'xlabel': 'epoch',
208 |                 'ylabel': 'loss'},
209 |             win=self.display_id)
210 | def print_current_errors(epoch, i, errors,t):
211 |     message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i,t)
212 |     for k, v in errors.items():
213 |         message += '%s: %.8f ' % (k, v)
214 |     print(message)
215 |
216 | def train():
217 |     dist.init_process_group(backend='nccl', init_method='env://')
218 |     device = torch.device('cuda', opt.local_rank)
219 |     ICNet=PostProcess()
220 |     ICNet = nn.SyncBatchNorm.convert_sync_batchnorm(ICNet)
221 |     criterion=torch.nn.BCEWithLogitsLoss()
222 |     optimizer=torch.optim.Adam(ICNet.parameters(),opt.lr,betas=(0.9, 0.999))
223 |     if opt.vis:
224 |         vis=Visualizer(opt)
225 |     if opt.gpus:
226 |         ICNet=ICNet.to(device)
227 |         criterion=criterion.to(device)
228 |         ICNet = torch.nn.parallel.DistributedDataParallel(ICNet, device_ids=[opt.local_rank], output_device=opt.local_rank,find_unused_parameters=True)
229 |     dataset = MyDataset()
230 |     train_sampler=torch.utils.data.distributed.DistributedSampler(dataset)
231 |     dataloader=torch.utils.data.DataLoader(dataset,
232 |                                            batch_size=opt.batch_size,
233 |                                            shuffle=(train_sampler is None),
234 |                                            num_workers=opt.num_workers,
235 |                                            drop_last=True,
236 |                                            pin_memory=True,
237 |                                            sampler=train_sampler)  # pass the DistributedSampler so each process sees its own shard
238 |     for epoch in range(1,opt.max_epoch+1):
239 |         epoch_iter=0
240 |         for ii,(images,labels) in enumerate(dataloader):
241 |             iter_start_time=time.time()
242 |             epoch_iter+=opt.batch_size
243 |             inputs=Variable(images)
244 |             optimizer.zero_grad()
245 |             outputs=ICNet(inputs.to(device))
246 |             target=np.zeros((opt.batch_size,1,opt.image_size,opt.image_size))
247 |             for l in range(0,opt.batch_size):
248 |                 label=Image.open('%s/%s'%(opt.label_path,labels[l])).convert('L')
249 |                 label=tv.transforms.Resize((opt.image_size,opt.image_size))(label)
250 |                 target[l,0,:,:]=label
251 |             target=torch.Tensor(target)
252 |             pre_outputs=Variable(target).to(device)
253 |             loss = criterion(outputs, pre_outputs)
254 |             loss.backward()
255 |             optimizer.step()
256 |             errors = get_current_errors(loss)
257 |             if opt.local_rank==0:
258 |                 if (ii+1)% 100 == 0:
259 |                     ti = (time.time() - iter_start_time) / opt.batch_size
260 |                     print_current_errors(epoch, epoch_iter, errors, ti)
261 |                 if opt.vis and (ii+1)% 100 == 0:
262 |                     with open('ICmodule_loss.txt','a') as f:
263 |                         vdl = 'epoch:%d ICmodule_loss:%.10f'%(epoch, loss)
264 |                         
f.write(vdl + '\n')
265 |                         f.close()
266 |                     load_rate = float(epoch_iter)/dataset.__len__()
267 |                     vis.plot_current_errors(epoch, load_rate, opt, errors)
268 |         if opt.local_rank==0:  # save a checkpoint after every epoch
269 |             torch.save(ICNet.module.state_dict(), './%s/%s.pth'%(opt.save_path,str(epoch)))
270 |     print('complete!')
271 |
272 | def get_current_errors(loss):
273 |     return OrderedDict([('ResnetLoss', loss.data)])
274 |
275 | if __name__ == "__main__":
276 |     opt=BaseOptions().parse()
277 |     train()
278 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Infrared Small and Dim Target Detection
4 |
5 | *Our implementation contains two parts: SDDNet and the IC-Module.*
6 | *SDDNet: Small and Dim target Detection network. It is trained with binarized segmentation maps as labels.*
7 | *IC-Module: Inter-frame Correlation Module. It uses inter-frame correlation information to reduce the false-alarm rate.*
8 |
9 | ## Environments
10 | pytorch==1.3.0
11 |
12 | torchvision==0.5.0
13 |
14 | python==3.7
15 |
16 | visdom
17 |
18 | torch2trt (refer to: https://github.com/NVIDIA-AI-IOT/torch2trt)
19 |
20 | ## Inference
21 | To run inference, type the following commands:
22 | ```python
23 | python SDD_test.py --data_path (images to run inference on) --load_model (trained model)
24 | python IC_test.py --data_path (images to run inference on) --load_model (trained model)
25 | ```
26 | The results are saved to the ```test_result/SDD``` or ```test_result/icModule``` folder if no other path is specified.
27 |
28 | **More parameters:**
29 |
30 |   --save_path : path to save the results.
31 |   --acc : pass this flag to accelerate inference with TensorRT.
32 | ## Training
33 | DDP mode is adopted for both SDDNet and the IC-Module, and 4 GPUs are used for training.
34 |
35 | To run the training scripts, type the following commands:
36 |
37 | ```python
38 | python -m torch.distributed.launch --nproc_per_node 4 SDD_train.py --data_path (training dataset path) --label_path (label path) & nohup visdom
39 | python -m torch.distributed.launch --nproc_per_node 4 IC_train.py --data_path (training dataset path) --label_path (label path) & nohup visdom
40 | ```
41 | The trained models are saved to the sdd_checkpoints or IC_checkpoints folder if no other path is specified.
42 |
43 | **More Parameters:**
44 |
45 |   --save_path : path to save checkpoints.
46 |   --vis : whether to visualize with visdom, default True. 
47 |   --load_model : path to load a pre-trained model.
48 | ## Others (in the scripts folder)
49 | 1. If you want to train the model on your own custom dataset, you should prepare binarized segmentation maps as labels. The script ```binary.py``` converts label images to binary images.
50 |
51 | 2. The script ```DataAug.py``` augments the prepared dataset before training; the default folders for the training data and labels are 'train' and 'label', respectively.
52 |
53 | 3. We also provide a MATLAB function to binarize the result images; call ```bwfunction``` to use it.
54 |
55 | 4. The script ```statistic.py``` calculates the PD and FA of the inference results; the same statistics were used in the comparative experiments. 
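Concretely, ```statistic.py``` matches connected components between each predicted mask and its ground-truth mask: PD = (ground-truth targets overlapped by at least one detection) / (total ground-truth targets), and FA = (pixels of detections that overlap no target) / (number of images × width × height).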
56 | ```python
57 | python statistic.py --image_path (path to the saved results) --label_path (path to the ground-truth labels) --width 896 --height 896
58 | ```
59 |
60 |
--------------------------------------------------------------------------------
/SDD_test.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import sys
3 | sys.path.append('..')
4 | from torch2trt import torch2trt
5 | import argparse
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | import re
10 | import torchvision as tv
11 | from torch.autograd import Variable
12 | import torch.nn.functional as F
13 | import os
14 | from PIL import Image
15 | import cv2
16 | from collections import OrderedDict
17 | import time
18 | from scipy.io import loadmat
19 | import operator
20 | from torch.nn.parallel import DistributedDataParallel
21 | import torch.distributed as dist
22 | import torch.utils.data.distributed
23 |
24 | #define argument
25 | class BaseOptions():
26 |     def __init__(self):
27 |         self.parser = argparse.ArgumentParser()
28 |         self.initialized = False
29 |
30 |     def initialize(self):
31 |         self.parser.add_argument('--data_path', default='train', help='# path to the images to run inference on')
32 |         self.parser.add_argument('--image_size', type=int, default=896, help='# the image size to resize to')
33 |         self.parser.add_argument('--load_model', type=str, default=None, help='# path to the trained model')
34 |         self.parser.add_argument('--save_path', type=str, default='test_result/SDD', help='# path to save results')
35 |         self.parser.add_argument('--acc', action='store_true', help='# use TensorRT to accelerate inference')
36 |     def parse(self):
37 |         if not self.initialized:
38 |             self.initialize()
39 |         self.opt = self.parser.parse_args()
40 |         args = vars(self.opt)
41 |         print('------------ Options -------------')
42 |         for k, v in sorted(args.items()):
43 |             print('%s: %s' % (str(k), str(v)))
44 |         print('-------------- End ----------------')
45 |         return self.opt
46 |
47 |
48 | #define dataset
49 | image_EXTENSIONS = [
50 |     '.jpg', '.JPG', '.jpeg', '.JPEG',
51 |     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
52 | ]
53 |
54 |
55 | def is_image_file(filename):
56 |     return any(filename.endswith(extension) for extension in image_EXTENSIONS)
57 |
58 | def make_dataset(dir):
59 |     images = []
60 |     assert os.path.isdir(dir), '%s is not a valid directory' % dir
61 |     for root, dirs, fnames in sorted(os.walk(dir)):
62 |         for fname in fnames:
63 |             path = os.path.join(root, fname)
64 |             images.append(path)
65 |
66 |     return images
67 |
68 | def transform(image):
69 |
70 |     transforms = tv.transforms.Compose([
71 |         tv.transforms.Resize((opt.image_size,opt.image_size)),
72 |         tv.transforms.ToTensor(),
73 |     ])
74 |     image=transforms(image)
75 |     mean=image.mean()
76 |     std=image.std()
77 |     if opt.acc:
78 |         image=tv.transforms.Normalize(np.array([mean],dtype=float),np.array([std],dtype=float))(image)
79 |     else:
80 |         image=tv.transforms.Normalize([mean],[std])(image)
81 |     return image
82 | class MyDataset(torch.utils.data.Dataset):
83 |     def __init__(self, validata=False):
84 |         self.opt = opt
85 |
86 |         if validata:
87 |             self.root = opt.validata_path
88 |             self.dir= os.path.join(opt.validata_path)
89 |         else:
90 |             self.root = opt.data_path
91 |             self.dir = os.path.join(opt.data_path)
92 |         self.path = make_dataset(self.dir)
93 |
94 |
95 |     def __getitem__(self, index) :
96 |         image_path = self.path[index]
97 |         label_name = re.split(r'[/]', image_path)
98 |         label_name = os.path.join(label_name[-2],label_name[-1])
99 |         image = 
Image.open(image_path).convert('L') 100 | image = transform(image) 101 | return image, label_name 102 | 103 | def __len__(self): 104 | 105 | return len(self.path) 106 | 107 | 108 | #define network 109 | 110 | class ResidualBlock(nn.Module): 111 | def __init__(self, in_channels, out_channels, stride=1): 112 | super(ResidualBlock, self).__init__() 113 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1) 114 | self.bn1 = nn.BatchNorm2d(out_channels) 115 | self.relu = nn.ReLU(inplace=True) 116 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 117 | self.bn2 = nn.BatchNorm2d(out_channels) 118 | if in_channels != out_channels: 119 | self.downsample = nn.Sequential( 120 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2), 121 | nn.BatchNorm2d(out_channels) 122 | ) 123 | else: 124 | self.downsample = None 125 | 126 | def forward(self, x): 127 | identity = x 128 | out = self.conv1(x) 129 | out = self.bn1(out) 130 | out = self.relu(out) 131 | out = self.conv2(out) 132 | out = self.bn2(out) 133 | 134 | if self.downsample is not None: 135 | identity = self.downsample(x) 136 | out += identity 137 | out = self.relu(out) 138 | return out 139 | 140 | class ResNet(nn.Module): 141 | def __init__(self): 142 | super(ResNet, self).__init__() 143 | self.first = nn.Sequential( 144 | nn.Conv2d(1, 32, 7, 2, 3), 145 | nn.BatchNorm2d(32), 146 | nn.ReLU(), 147 | nn.Conv2d(32, 32, 3, 2, 1), 148 | nn.BatchNorm2d(32), 149 | nn.ReLU() 150 | ) 151 | self.layer1 = self.make_layer(32, 32, 3, 1) #in_channel=64,out_channel=64,block_num=3,stride=1 152 | self.layer2 = self.make_layer(32, 64, 4, 2) #in_channel=64,out_channel=128,block_num=4,stride=2 153 | self.layer3 = self.make_layer(64, 128, 6, 2) #in_channel=128,out_channel=256,block_num=6,stride=2 154 | self.layer4 = self.make_layer(128, 256, 3, 2) #in_channel=256,out_channel=512,block_num=3,stride=2 155 | self.layer5 = ResidualBlock(256, 512, 2) 156 | self.layer6 = ResidualBlock(512,1024,2) 157 | self.end = nn.Sequential( 158 | nn.Upsample(scale_factor=2, mode='nearest'), 159 | nn.Conv2d(1024, 512, 3, 1,1), 160 | nn.BatchNorm2d(512), 161 | nn.ReLU(), 162 | nn.Upsample(scale_factor=2, mode='nearest'), 163 | nn.Conv2d(512, 256, 3, 1,1), 164 | nn.BatchNorm2d(256), 165 | nn.ReLU(), 166 | nn.Upsample(scale_factor=2, mode='nearest'), 167 | nn.Conv2d(256, 128, 3, 1,1), 168 | nn.BatchNorm2d(128), 169 | nn.ReLU(), 170 | nn.Upsample(scale_factor=2, mode='nearest'), 171 | nn.Conv2d(128, 64, 3, 1,1), 172 | nn.BatchNorm2d(64), 173 | nn.ReLU(), 174 | nn.Upsample(scale_factor=2, mode='nearest'), 175 | nn.Conv2d(64, 32, 3, 1,1), 176 | nn.BatchNorm2d(32), 177 | nn.ReLU(), 178 | nn.Upsample(scale_factor=2, mode='nearest'), 179 | nn.Conv2d(32, 32, 3, 1,1), 180 | nn.BatchNorm2d(32), 181 | nn.ReLU(), 182 | nn.Upsample(scale_factor=2,mode='nearest'), 183 | nn.Conv2d(32, 1, 3, 1,1), 184 | ) 185 | 186 | def make_layer(self, in_channels, out_channels, block_num, stride): 187 | layers = [] 188 | layers.append(ResidualBlock(in_channels, out_channels, stride)) 189 | for i in range(1, block_num): 190 | layers.append(ResidualBlock(out_channels, out_channels, 1)) 191 | 192 | return nn.Sequential(*layers) 193 | 194 | def forward(self, x): 195 | x = self.first(x) 196 | x = self.layer1(x) 197 | x = self.layer2(x) 198 | x = self.layer3(x) 199 | x = self.layer4(x) 200 | x = self.layer5(x) 201 | x = self.layer6(x) 202 | x = self.end(x) 203 | x = torch.sigmoid(x) 204 | return x 205 | 206 | def test(): 207 | 
print('---------------------start------------------------')
208 |
209 |     SDDNet = ResNet().eval()
210 |     SDDNet.cuda()
211 |
212 |     map_location = lambda storage, loc: storage
213 |     SDDNet.load_state_dict(torch.load(opt.load_model, map_location=map_location),False)
214 |     dataset = MyDataset()
215 |     dataloader = torch.utils.data.DataLoader(dataset,
216 |                                              batch_size=1,
217 |                                              shuffle=False,
218 |                                              num_workers=1,
219 |                                              drop_last=False
220 |                                              )
221 |
222 |     for ii, (image, labels) in enumerate(dataloader):
223 |         label_name = labels[0]
224 |         save_path = os.path.join(opt.save_path,re.split(r'[/]',label_name)[-2])
225 |         if not os.path.exists(save_path):
226 |             os.makedirs(save_path)
227 |         inputs = Variable(image).cuda()
228 |         output = SDDNet(inputs)
229 |         output_numpy = output.data[0,0,:,:].cpu().float().numpy()
230 |         output_numpy = output_numpy * 255.0
231 |         output_numpy = output_numpy.astype(np.uint8)
232 |         output_PIL = Image.fromarray(output_numpy, mode='L')
233 |         output_PIL.save('%s/%s'%(opt.save_path,label_name))
234 |     print('--------------------complete!-----------------------')
235 |
236 |
237 | def acc_test():
238 |     print('---------------------start------------------------')
239 |     SDDNet = ResNet().eval()
240 |     SDDNet.cuda()
241 |     x=torch.ones(1,1,896,896).cuda()  # dummy input that fixes the engine's input shape
242 |     map_location = lambda storage, loc: storage
243 |     SDDNet.load_state_dict(torch.load(opt.load_model, map_location=map_location),False)
244 |     model_trt = torch2trt(SDDNet,[x])  # convert to TensorRT only after the trained weights are loaded
245 |     dataset = MyDataset()
246 |     dataloader = torch.utils.data.DataLoader(dataset,
247 |                                              batch_size=1,
248 |                                              shuffle=False,
249 |                                              num_workers=1,
250 |                                              drop_last=False
251 |                                              )
252 |     for ii, (image, labels) in enumerate(dataloader):
253 |         label_name = labels[0]
254 |         save_path = os.path.join(opt.save_path,re.split(r'[/]',label_name)[-2])
255 |         if not os.path.exists(save_path):
256 |             os.makedirs(save_path)
257 |         inputs = Variable(image).cuda()
258 |         output = model_trt(inputs)
259 |         output_numpy = output.data[0,0,:,:].cpu().float().numpy()
260 |         output_numpy = output_numpy * 255.0
261 |         output_numpy = output_numpy.astype(np.uint8)
262 |         output_PIL = Image.fromarray(output_numpy, mode='L')
263 |         output_PIL.save('%s/%s'%(opt.save_path,label_name))
264 |     print('--------------------complete!-----------------------')
265 |
266 |
267 | if __name__ == "__main__":
268 |     opt = BaseOptions().parse()
269 |     if opt.acc:
270 |         acc_test()
271 |     else:
272 |         test()
273 |
--------------------------------------------------------------------------------
/SDD_train.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 |
3 | import argparse
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | import re
8 | import torchvision as tv
9 | from torch.autograd import Variable
10 | import torch.nn.functional as F
11 | import os
12 | from PIL import Image
13 | import cv2
14 | from collections import OrderedDict
15 | import time
16 | from scipy.io import loadmat
17 | import operator
18 | from torch.nn.parallel import DistributedDataParallel
19 | import torch.distributed as dist
20 | import torch.utils.data.distributed
21 |
22 | #define argument
23 | class BaseOptions():
24 |     def __init__(self):
25 |         self.parser = argparse.ArgumentParser()
26 |         self.initialized = False
27 |
28 |     def initialize(self):
29 |         self.parser.add_argument('--data_path', default='train',help='# path to store training data ')
30 |         self.parser.add_argument('--save_path',type=str,default='sdd_checkpoints',help='# path to save checkpoints') 
31 | self.parser.add_argument('--label_path', default='label',help='# path to store label') 32 | self.parser.add_argument('--batch_size', type=int, default=64) 33 | self.parser.add_argument('--num_workers', default=2, type=int) 34 | self.parser.add_argument('--image_size', type=int, default=896, help='# the image size to resize') 35 | self.parser.add_argument('--max_epoch', type=int, default=300, help='# epoch count') 36 | self.parser.add_argument('--lr', type=float, default=0.0001, help='# learning rate') 37 | self.parser.add_argument('--gpus', action='store_true', default=True, help='# whether to use gpu') 38 | self.parser.add_argument('--vis', default=True, help='# whether to use visdom visulizer') 39 | self.parser.add_argument('--env', type=str, default='main', help='# visdom env') 40 | self.parser.add_argument('--print_every', type=int, default=50, help='# batchsize interval to print error') 41 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 42 | self.parser.add_argument('--display_port', type=int, default=8097, help='# visdom port of the web display') 43 | self.parser.add_argument('--load_model', type=str, default=None, help='# path to load the pre-trained model') 44 | self.parser.add_argument('--local_rank', type=int,help='# use multi GPU to train') 45 | 46 | def parse(self): 47 | if not self.initialized: 48 | self.initialize() 49 | self.opt = self.parser.parse_args() 50 | args = vars(self.opt) 51 | if self.opt.local_rank==0: 52 | print('------------ Options -------------') 53 | for k, v in sorted(args.items()): 54 | print('%s: %s' % (str(k), str(v))) 55 | print('-------------- End ----------------') 56 | return self.opt 57 | 58 | 59 | #define dataset 60 | image_EXTENSIONS = [ 61 | '.jpg', '.JPG', '.jpeg', '.JPEG', 62 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 63 | ] 64 | 65 | 66 | def is_image_file(filename): 67 | return any(filename.endswith(extension) for extension in image_EXTENSIONS) 68 | 69 | def make_dataset(dir): 70 | images = [] 71 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 72 | for root, dirs, fnames in sorted(os.walk(dir)): 73 | for fname in fnames: 74 | path = os.path.join(root, fname) 75 | images.append(path) 76 | 77 | return images 78 | 79 | def transform(image): 80 | 81 | transforms = tv.transforms.Compose([ 82 | tv.transforms.Resize((opt.image_size,opt.image_size)), 83 | tv.transforms.ToTensor(), 84 | ]) 85 | image=transforms(image) 86 | mean=image.mean() 87 | std=image.std() 88 | image=tv.transforms.Normalize([mean],[std])(image) 89 | return image 90 | class MyDataset(torch.utils.data.Dataset): 91 | def __init__(self, validata=False): 92 | self.opt = opt 93 | 94 | if validata: 95 | self.root = opt.validata_path 96 | self.dir= os.path.join(opt.validata_path) 97 | else: 98 | self.root = opt.data_path 99 | self.dir = os.path.join(opt.data_path) 100 | self.path = make_dataset(self.dir) 101 | 102 | 103 | def __getitem__(self, index) : 104 | image_path = self.path[index] 105 | label_name = re.split(r'[/]', image_path) 106 | label_name = os.path.join(label_name[-2],label_name[-1]) 107 | image = Image.open(image_path).convert('L') 108 | image = transform(image) 109 | return image, label_name 110 | 111 | def __len__(self): 112 | 113 | return len(self.path) 114 | 115 | 116 | #define network 117 | 118 | class ResidualBlock(nn.Module): 119 | def __init__(self, in_channels, out_channels, stride=1): 120 | super(ResidualBlock, self).__init__() 121 | self.conv1 = nn.Conv2d(in_channels, out_channels, 
kernel_size=3, stride=stride, padding=1) 122 | self.bn1 = nn.BatchNorm2d(out_channels) 123 | self.relu = nn.ReLU(inplace=True) 124 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 125 | self.bn2 = nn.BatchNorm2d(out_channels) 126 | if in_channels != out_channels: 127 | self.downsample = nn.Sequential( 128 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2), 129 | nn.BatchNorm2d(out_channels) 130 | ) 131 | else: 132 | self.downsample = None 133 | 134 | def forward(self, x): 135 | identity = x 136 | out = self.conv1(x) 137 | out = self.bn1(out) 138 | out = self.relu(out) 139 | out = self.conv2(out) 140 | out = self.bn2(out) 141 | 142 | if self.downsample is not None: 143 | identity = self.downsample(x) 144 | out += identity 145 | out = self.relu(out) 146 | return out 147 | 148 | class ResNet(nn.Module): 149 | def __init__(self): 150 | super(ResNet, self).__init__() 151 | self.first = nn.Sequential( 152 | nn.Conv2d(1, 32, 7, 2, 3), 153 | nn.BatchNorm2d(32), 154 | nn.ReLU(), 155 | nn.Conv2d(32, 32, 3, 2, 1), 156 | nn.BatchNorm2d(32), 157 | nn.ReLU() 158 | ) 159 | self.layer1 = self.make_layer(32, 32, 3, 1) #in_channel=64,out_channel=64,block_num=3,stride=1 160 | self.layer2 = self.make_layer(32, 64, 4, 2) #in_channel=64,out_channel=128,block_num=4,stride=2 161 | self.layer3 = self.make_layer(64, 128, 6, 2) #in_channel=128,out_channel=256,block_num=6,stride=2 162 | self.layer4 = self.make_layer(128, 256, 3, 2) #in_channel=256,out_channel=512,block_num=3,stride=2 163 | self.layer5 = ResidualBlock(256, 512, 2) 164 | self.layer6 = ResidualBlock(512,1024,2) 165 | self.end = nn.Sequential( 166 | nn.Upsample(scale_factor=2, mode='nearest'), 167 | nn.Conv2d(1024, 512, 3, 1,1), 168 | nn.BatchNorm2d(512), 169 | nn.ReLU(), 170 | nn.Upsample(scale_factor=2, mode='nearest'), 171 | nn.Conv2d(512, 256, 3, 1,1), 172 | nn.BatchNorm2d(256), 173 | nn.ReLU(), 174 | nn.Upsample(scale_factor=2, mode='nearest'), 175 | nn.Conv2d(256, 128, 3, 1,1), 176 | nn.BatchNorm2d(128), 177 | nn.ReLU(), 178 | nn.Upsample(scale_factor=2, mode='nearest'), 179 | nn.Conv2d(128, 64, 3, 1,1), 180 | nn.BatchNorm2d(64), 181 | nn.ReLU(), 182 | nn.Upsample(scale_factor=2, mode='nearest'), 183 | nn.Conv2d(64, 32, 3, 1,1), 184 | nn.BatchNorm2d(32), 185 | nn.ReLU(), 186 | nn.Upsample(scale_factor=2, mode='nearest'), 187 | nn.Conv2d(32, 32, 3, 1,1), 188 | nn.BatchNorm2d(32), 189 | nn.ReLU(), 190 | nn.Upsample(scale_factor=2,mode='nearest'), 191 | nn.Conv2d(32, 1, 3, 1,1), 192 | ) 193 | 194 | def make_layer(self, in_channels, out_channels, block_num, stride): 195 | layers = [] 196 | layers.append(ResidualBlock(in_channels, out_channels, stride)) 197 | for i in range(1, block_num): 198 | layers.append(ResidualBlock(out_channels, out_channels, 1)) 199 | 200 | return nn.Sequential(*layers) 201 | 202 | def forward(self, x): 203 | x = self.first(x) 204 | x = self.layer1(x) 205 | x = self.layer2(x) 206 | x = self.layer3(x) 207 | x = self.layer4(x) 208 | x = self.layer5(x) 209 | x = self.layer6(x) 210 | x = self.end(x) 211 | return x 212 | 213 | 214 | # visulize 215 | class Visualizer(): 216 | def __init__(self, opt): 217 | self.display_id = 1 218 | self.win_size = 256 219 | self.name = 'detection loss' 220 | if opt.vis==True: 221 | import visdom 222 | self.vis = visdom.Visdom(env=opt.env, port=opt.display_port) 223 | 224 | def plot_current_errors(self, epoch, count_ratio, opt, errors): 225 | if not hasattr(self, 'plot_data'): 226 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} 227 | 
self.plot_data['X'].append(epoch + count_ratio)
228 |         for k in self.plot_data['legend']:
229 |             value = errors[k].cpu().numpy()  # keep `errors` intact while iterating
230 |             self.plot_data['Y'].append([value])
231 |         self.vis.line(
232 |             X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1) if len(self.plot_data['X'])==1 else np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1).squeeze(),
233 |             Y=np.array(self.plot_data['Y']) if len(self.plot_data['Y'])==1 else np.array(self.plot_data['Y']).squeeze(),
234 |             opts={
235 |                 'title': self.name + ' loss over time',
236 |                 'legend': self.plot_data['legend'],
237 |                 'xlabel': 'epoch',
238 |                 'ylabel': 'loss'},
239 |             win=self.display_id)
240 | def print_current_errors(epoch, i, errors,t):
241 |     message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i,t)
242 |     for k, v in errors.items():
243 |         message += '%s: %.8f ' % (k, v)
244 |     print(message)
245 |
246 | #train network
247 | def train():
248 |     if opt.vis:
249 |         vis = Visualizer(opt)
250 |
251 |     SDDNet = ResNet()
252 |     SDDNet = nn.SyncBatchNorm.convert_sync_batchnorm(SDDNet)
253 |     dist.init_process_group(backend='nccl', init_method='env://')
254 |
255 |     if opt.load_model:
256 |         map_location = lambda storage, loc: storage
257 |         SDDNet.load_state_dict(torch.load(opt.load_model, map_location=map_location))
258 |
259 |     criterion=nn.BCEWithLogitsLoss()
260 |     optimizer = torch.optim.Adam(SDDNet.parameters(), opt.lr, betas=(0.5, 0.999))
261 |
262 |     if opt.gpus:
263 |         device = torch.device('cuda', opt.local_rank)
264 |         SDDNet=SDDNet.to(device)
265 |         criterion=criterion.to(device)
266 |         SDDNet = torch.nn.parallel.DistributedDataParallel(SDDNet, device_ids=[opt.local_rank], output_device=opt.local_rank,find_unused_parameters=True)
267 |     dataset = MyDataset()
268 |     train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
269 |     dataloader = torch.utils.data.DataLoader(dataset,
270 |                                              pin_memory=True,
271 |                                              batch_size=opt.batch_size,
272 |                                              shuffle=(train_sampler is None),
273 |                                              num_workers=opt.num_workers,
274 |                                              sampler=train_sampler,
275 |                                              drop_last=True
276 |                                              )
277 |
278 |     for epoch in range(1,opt.max_epoch+1):
279 |         epoch_iter = 0
280 |         for ii,(image, labels) in enumerate(dataloader):
281 |             iter_start_time = time.time()
282 |             epoch_iter += opt.batch_size
283 |             inputs = Variable(image)
284 |             optimizer.zero_grad()
285 |             outputs = SDDNet(inputs.to(device))
286 |             target = np.zeros((opt.batch_size,1,opt.image_size, opt.image_size))
287 |             for l in range(0,opt.batch_size):
288 |                 mask = labels[l]
289 |                 mask = Image.open('%s/%s'%(opt.label_path,mask)).convert('L')
290 |                 mask = tv.transforms.Resize((opt.image_size,opt.image_size))(mask)
291 |                 target[l,0,:,:] = mask
292 |             target = torch.Tensor(target)
293 |             pre_outputs = Variable(target).to(device)
294 |             weights=torch.empty_like(target).fill_(0.0141)
295 |             weights[target==1]=0.9859
296 |             criterion=nn.BCEWithLogitsLoss(weight=weights).to(device)  # per-pixel weights: the rare target pixels are weighted ~70x the background
297 |             loss = criterion(outputs, pre_outputs)
298 |             loss.backward()
299 |             optimizer.step()
300 |             if opt.local_rank==0:
301 |                 errors = get_current_errors(loss)
302 |                 if (ii+1)% opt.print_every == 0:
303 |
304 |                     ti = (time.time() - iter_start_time) / opt.batch_size
305 |                     print_current_errors(epoch, epoch_iter, errors, ti)
306 |
307 |                 if opt.vis and (ii+1)% opt.print_every == 0:
308 |                     with open('training_loss.txt','a') as f:
309 |                         vdl = 'epoch:%d training loss:%.10f'%(epoch, loss)
310 |                         f.write(vdl + '\n')
311 |                         f.close()
312 |                     load_rate = float(epoch_iter)/dataset.__len__()
313 |                     vis.plot_current_errors(epoch, 
load_rate, opt, errors)
314 |         if opt.local_rank ==0 and epoch % 2 ==0:
315 |             torch.save(SDDNet.module.state_dict(), './%s/%s.pth'%(opt.save_path, str(epoch)))
316 |
317 |     print('complete!')
318 |
319 |
320 | def get_current_errors(loss):
321 |     return OrderedDict([('ResnetLoss', loss.data)])
322 |
323 | if __name__ == "__main__":
324 |
325 |
326 |     opt = BaseOptions().parse()
327 |     train()
328 |
329 |
--------------------------------------------------------------------------------
/checkpoints/icModule.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LittlePieces/ObjectDetection/55ee1c120818f2f37d5ece5d4623fe38deb547e7/checkpoints/icModule.pth
--------------------------------------------------------------------------------
/checkpoints/sdd.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:95cfce1b00024e3dceaf095c0591c9bc3a29e2c589eeffa816daf119c3c75e10
3 | size 120140450
4 |
--------------------------------------------------------------------------------
/scripts/DataAug.py:
--------------------------------------------------------------------------------
1 |
2 | from torchvision import transforms
3 | from PIL import Image,ImageStat,ImageEnhance
4 | import cv2
5 | import numpy
6 | import re
7 | import os
8 | IMG_EXTENSIONS = [
9 |     '.jpg', '.JPG', '.jpeg', '.JPEG',
10 |     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 |     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def make_dataset(dir):
18 |     images = []
19 |
20 |     assert os.path.isdir(dir), '%s is not a valid directory' % dir
21 |     for root, dirs, fnames in sorted(os.walk(dir)):
22 |         for dirr in dirs:
23 |             if not os.path.exists('train/flip_'+dirr):
24 |                 os.makedirs('train/flip_'+dirr)
25 |             if not os.path.exists('label/flip_'+dirr):
26 |                 os.makedirs('label/flip_'+dirr)
27 |             for fname in sorted(os.listdir(dir+'/'+dirr)):
28 |                 if is_image_file(fname):
29 |                     path = os.path.join(dir+'/'+dirr+'/', fname)
30 |
31 |                     image=Image.open(path).convert('L')
32 |                     label=Image.open('relabel/'+dirr+'/'+fname).convert('L')
33 |                     image=transforms.RandomHorizontalFlip(p=1)(image)
34 |                     label=transforms.RandomHorizontalFlip(p=1)(label)  # flip the label together with the image so the pair stays aligned
35 |                     image.save('train/flip_'+dirr+'/'+fname)
36 |                     label.save('label/flip_'+dirr+'/'+fname)
37 | if __name__=="__main__":
38 |     make_dataset('train/')
--------------------------------------------------------------------------------
/scripts/binary.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | path = os.walk("label/")
5 | for root,dirs,files in path:
6 |     for filename in files:
7 |         I=cv2.imread(os.path.join(root,filename),0)
8 |         I=np.where(I==0,0,1)
9 |         cv2.imwrite(os.path.join(root,filename),I)
10 |         print('saved to '+filename)
11 |
12 |
--------------------------------------------------------------------------------
/scripts/bwfunction.m:
--------------------------------------------------------------------------------
1 | function bwfunction(src1, src2)
2 | fold = dir(fullfile(src1,'*.bmp'));
3 | refold = src2;
4 | for kk = 1:length(fold)
5 |     img = imread(strcat(fold(kk).folder,'\',fold(kk).name));
6 |     re = double(img);  % threshold the grayscale image directly
7 |     m = mean2(re);
8 |     s = std2(re);
9 |     maxv = max(re(:));
10 |     T = m + 0.5*(maxv - m);
11 |     bw = re > T;
12 |     imwrite(bw, strcat(refold,'\',fold(kk).name));
13 | end
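For readers not using MATLAB, here is a minimal Python sketch of the same thresholding rule (T = m + 0.5*(maxv - m)); the helper name `binarize_results` and the `.bmp` filter are our assumptions, mirroring `bwfunction.m` above:

```python
import os
import cv2
import numpy as np

def binarize_results(src_dir, dst_dir):
    """Binarize result images with the threshold T = mean + 0.5 * (max - mean)."""
    os.makedirs(dst_dir, exist_ok=True)
    for fname in os.listdir(src_dir):
        if not fname.endswith('.bmp'):
            continue
        img = cv2.imread(os.path.join(src_dir, fname), 0)  # grayscale
        if img is None:
            continue
        re = img.astype(np.float64)
        T = re.mean() + 0.5 * (re.max() - re.mean())
        bw = (re > T).astype(np.uint8) * 255
        cv2.imwrite(os.path.join(dst_dir, fname), bw)
```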
--------------------------------------------------------------------------------
/scripts/statistic.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import cv2
4 | import os
5 | import argparse
6 |
7 | def ROC(src1,src2,width=896,height=896):
8 |     img = cv2.imread(src1,0)
9 |     lab = cv2.imread(src2,0)
10 |     lab = cv2.resize(lab,(width,height))  # resize the label to the size of the predicted image
11 |     _,_,status,_ = cv2.connectedComponentsWithStats(img)
12 |     num2,_,status2,_ = cv2.connectedComponentsWithStats(lab)
13 |     ground_truth = num2-1  # component 0 is the background
14 |     detected_truth = 0
15 |     alarm_pixel = 0
16 |
17 |     for i in range(1,len(status)):
18 |         x=status[i][0]
19 |         y=status[i][1]
20 |         w=status[i][2]
21 |         h=status[i][3]
22 |         ROI=lab[y:y+h,x:x+w]
23 |
24 |         if np.all(ROI == 0):  # the detection overlaps no ground-truth pixel: a false alarm
25 |             alarm_pixel += status[i][4]
26 |
27 |     for j in range(1,len(status2)):
28 |         x2=status2[j][0]
29 |         y2=status2[j][1]
30 |         w2=status2[j][2]
31 |         h2=status2[j][3]
32 |         ROI2=img[y2:y2+h2,x2:x2+w2]
33 |
34 |         if np.all(ROI2 == 0) == False:  # some predicted pixel overlaps this ground-truth target
35 |             detected_truth += 1
36 |
37 |     return ground_truth,detected_truth,alarm_pixel
38 |
39 | if __name__ == "__main__":
40 |
41 |     parser = argparse.ArgumentParser()
42 |     parser.add_argument("--image_path",required=True,help='path to the predicted images')
43 |     parser.add_argument("--label_path",required=True,help='path to the ground-truth label images')
44 |     parser.add_argument("--width",type=int,required=True,help='labels should have the same size as the predicted images; if not, they are resized to (width, height)')
45 |     parser.add_argument("--height",type=int,required=True)
46 |     args = parser.parse_args()
47 |
48 |     all_GT=0
49 |     all_PD=0
50 |     all_FA=0
51 |     for fname in os.listdir(args.image_path):
52 |         GT,pred,alarm=ROC(os.path.join(args.image_path,fname), os.path.join(args.label_path,fname), args.width, args.height)
53 |         all_GT += GT
54 |         all_PD += pred
55 |         all_FA += alarm
56 |     all_pixel = len(os.listdir(args.image_path))*args.width*args.height
57 |     print('PD=',all_PD/all_GT,'FA=',all_FA/all_pixel)
58 |
59 |
60 |
61 |
62 |
--------------------------------------------------------------------------------
/torch2trt/__init__.py:
--------------------------------------------------------------------------------
1 | from .torch2trt import *
2 | from .converters import *
3 | import tensorrt as trt
4 |
5 |
6 | def load_plugins():
7 |     import os
8 |     import ctypes
9 |     ctypes.CDLL(os.path.join(os.path.dirname(__file__), 'libtorch2trt.so'))
10 |
11 |     registry = trt.get_plugin_registry()
12 |     torch2trt_creators = [c for c in registry.plugin_creator_list if c.plugin_namespace == 'torch2trt']
13 |     for c in torch2trt_creators:
14 |         registry.register_creator(c, 'torch2trt')
15 |
16 |
17 | try:
18 |     load_plugins()
19 |     PLUGINS_LOADED = True
20 | except OSError:
21 |     PLUGINS_LOADED = False
--------------------------------------------------------------------------------
/torch2trt/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LittlePieces/ObjectDetection/55ee1c120818f2f37d5ece5d4623fe38deb547e7/torch2trt/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/torch2trt/__pycache__/calibration.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LittlePieces/ObjectDetection/55ee1c120818f2f37d5ece5d4623fe38deb547e7/torch2trt/__pycache__/calibration.cpython-36.pyc 
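For orientation, a minimal usage sketch of the vendored torch2trt package follows; it mirrors how SDD_test.py and IC_test.py call it, and `MyModel` is a placeholder for any module built from the supported layers:

```python
import torch
from torch2trt import torch2trt

model = MyModel().eval().cuda()         # placeholder: any nn.Module built from supported layers
x = torch.ones(1, 1, 896, 896).cuda()   # sample input; it fixes the engine's input shape
model_trt = torch2trt(model, [x])       # trace the module and build a TensorRT engine
y_trt = model_trt(x)                    # call the engine like a regular module
```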
-------------------------------------------------------------------------------- /torch2trt/__pycache__/module_test.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LittlePieces/ObjectDetection/55ee1c120818f2f37d5ece5d4623fe38deb547e7/torch2trt/__pycache__/module_test.cpython-36.pyc -------------------------------------------------------------------------------- /torch2trt/__pycache__/torch2trt.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LittlePieces/ObjectDetection/55ee1c120818f2f37d5ece5d4623fe38deb547e7/torch2trt/__pycache__/torch2trt.cpython-36.pyc -------------------------------------------------------------------------------- /torch2trt/calibration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tensorrt as trt 3 | 4 | 5 | if trt.__version__ >= '5.1': 6 | DEFAULT_CALIBRATION_ALGORITHM = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 7 | else: 8 | DEFAULT_CALIBRATION_ALGORITHM = trt.CalibrationAlgoType.ENTROPY_CALIBRATION 9 | 10 | 11 | class TensorBatchDataset(): 12 | 13 | def __init__(self, tensors): 14 | self.tensors = tensors 15 | 16 | def __len__(self): 17 | return len(self.tensors[0]) 18 | 19 | def __getitem__(self, idx): 20 | return [t[idx] for t in self.tensors] 21 | 22 | 23 | class DatasetCalibrator(trt.IInt8Calibrator): 24 | 25 | def __init__(self, inputs, dataset, batch_size=1, algorithm=DEFAULT_CALIBRATION_ALGORITHM): 26 | super(DatasetCalibrator, self).__init__() 27 | 28 | self.dataset = dataset 29 | self.batch_size = batch_size 30 | self.algorithm = algorithm 31 | 32 | # create buffers that will hold data batches 33 | self.buffers = [] 34 | for tensor in inputs: 35 | size = (batch_size,) + tuple(tensor.shape[1:]) 36 | buf = torch.zeros(size=size, dtype=tensor.dtype, device=tensor.device).contiguous() 37 | self.buffers.append(buf) 38 | 39 | self.count = 0 40 | 41 | def get_batch(self, *args, **kwargs): 42 | if self.count < len(self.dataset): 43 | 44 | for i in range(self.batch_size): 45 | 46 | idx = self.count % len(self.dataset) # roll around if not multiple of dataset 47 | inputs = self.dataset[idx] 48 | 49 | # copy data for (input_idx, dataset_idx) into buffer 50 | for buffer, tensor in zip(self.buffers, inputs): 51 | buffer[i].copy_(tensor) 52 | 53 | self.count += 1 54 | 55 | return [int(buf.data_ptr()) for buf in self.buffers] 56 | else: 57 | return [] 58 | 59 | def get_algorithm(self): 60 | return self.algorithm 61 | 62 | def get_batch_size(self): 63 | return self.batch_size 64 | 65 | def read_calibration_cache(self, *args, **kwargs): 66 | return None 67 | 68 | def write_calibration_cache(self, cache, *args, **kwargs): 69 | pass -------------------------------------------------------------------------------- /torch2trt/converters/AdaptiveAvgPool2d.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.AdaptiveAvgPool2d.forward') 6 | def convert_AdaptiveAvgPool2d(ctx): 7 | module = ctx.method_args[0] 8 | input = ctx.method_args[1] 9 | output = ctx.method_return 10 | 11 | input_trt = trt_(ctx.network, input) 12 | 13 | output_size = module.output_size 14 | if not isinstance(output_size, tuple): 15 | output_size = (output_size, ) * 2 16 | 17 | stride = (input_trt.shape[-2] // 
output_size[-2], input_trt.shape[-1] // output_size[-1]) 18 | 19 | kernel_size = stride 20 | layer = ctx.network.add_pooling( 21 | input=input_trt, type=trt.PoolingType.AVERAGE, window_size=kernel_size) 22 | layer.stride = stride 23 | 24 | output._trt = layer.get_output(0) 25 | 26 | 27 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 28 | def test_AdaptiveAvgPool2d_1x1(): 29 | return torch.nn.AdaptiveAvgPool2d((1, 1)) 30 | 31 | 32 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 33 | def test_AdaptiveAvgPool2d_2x2(): 34 | return torch.nn.AdaptiveAvgPool2d((2, 2)) 35 | 36 | 37 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 38 | def test_AdaptiveAvgPool2d_3x3(): 39 | return torch.nn.AdaptiveAvgPool2d((3, 3)) 40 | -------------------------------------------------------------------------------- /torch2trt/converters/BatchNorm1d.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.BatchNorm1d.forward') 6 | def convert_BatchNorm2d(ctx): 7 | module = ctx.method_args[0] 8 | input = ctx.method_args[1] 9 | input_trt = trt_(ctx.network, input) 10 | output = ctx.method_return 11 | 12 | scale = module.weight.detach().cpu().numpy() / np.sqrt(module.running_var.detach().cpu().numpy() + module.eps) 13 | bias = module.bias.detach().cpu().numpy() - module.running_mean.detach().cpu().numpy() * scale 14 | power = np.ones_like(scale) 15 | 16 | # reshape to 2D 17 | layer = ctx.network.add_shuffle(input_trt) 18 | 19 | if len(input.shape) == 2: 20 | layer.reshape_dims = (input.shape[1], 1, 1) 21 | else: 22 | layer.reshape_dims = (input.shape[1], input.shape[2], 1) 23 | 24 | layer = ctx.network.add_scale(layer.get_output(0), trt.ScaleMode.CHANNEL, bias, scale, power) 25 | 26 | # reshape back to 1D 27 | layer = ctx.network.add_shuffle(layer.get_output(0)) 28 | layer.reshape_dims = tuple(output.shape[1:]) 29 | 30 | output._trt = layer.get_output(0) 31 | 32 | 33 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10)]) 34 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3)]) 35 | def test_BatchNorm1d_basic(): 36 | return torch.nn.BatchNorm1d(10) -------------------------------------------------------------------------------- /torch2trt/converters/BatchNorm2d.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | 3 | 4 | @tensorrt_converter('torch.nn.BatchNorm2d.forward') 5 | def convert_BatchNorm2d(ctx): 6 | module = ctx.method_args[0] 7 | input = ctx.method_args[1] 8 | input_trt = trt_(ctx.network, input) 9 | output = ctx.method_return 10 | 11 | scale = module.weight.detach().cpu().numpy() / np.sqrt(module.running_var.detach().cpu().numpy() + module.eps) 12 | bias = module.bias.detach().cpu().numpy() - module.running_mean.detach().cpu().numpy() * scale 13 | power = np.ones_like(scale) 14 | 15 | layer = ctx.network.add_scale(input_trt, trt.ScaleMode.CHANNEL, bias, scale, power) 16 | 17 | output._trt = layer.get_output(0) -------------------------------------------------------------------------------- /torch2trt/converters/Conv1d.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.Conv1d.forward') 6 | def 
convert_Conv1d(ctx): 7 | module = ctx.method_args[0] 8 | input = ctx.method_args[1] 9 | input_trt = trt_(ctx.network, input) 10 | output = ctx.method_return 11 | 12 | kernel_size = (module.kernel_size[0], 1) 13 | stride = (module.stride[0], 1) 14 | padding = (module.padding[0], 0) 15 | dilation = (module.dilation[0], 1) 16 | 17 | kernel = module.weight.detach().cpu().numpy()[..., None] 18 | 19 | bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype)) 20 | if module.bias is not None: 21 | bias = module.bias.detach().cpu().numpy() 22 | 23 | # reshape to 2D 24 | layer = ctx.network.add_shuffle(input_trt) 25 | layer.reshape_dims = (-1, input.shape[-1], 1) 26 | 27 | layer = ctx.network.add_convolution( 28 | input=layer.get_output(0), 29 | num_output_maps=module.out_channels, 30 | kernel_shape=kernel_size, 31 | kernel=kernel, 32 | bias=bias) 33 | layer.stride = stride 34 | layer.padding = padding 35 | layer.dilation = dilation 36 | 37 | if module.groups is not None: 38 | layer.num_groups = module.groups 39 | 40 | # reshape back to 1D 41 | layer = ctx.network.add_shuffle(layer.get_output(0)) 42 | layer.reshape_dims = (-1, output.shape[-1]) 43 | 44 | output._trt = layer.get_output(0) 45 | 46 | 47 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224)]) 48 | def test_Conv1d_basic(): 49 | return torch.nn.Conv1d(10, 5, kernel_size=1, stride=1, padding=0) 50 | 51 | 52 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224)]) 53 | def test_Conv1d_stride2(): 54 | return torch.nn.Conv1d(10, 5, kernel_size=1, stride=2, padding=0) 55 | 56 | 57 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224)]) 58 | def test_Conv1d_kernel3(): 59 | return torch.nn.Conv1d(10, 5, kernel_size=3, stride=2, padding=1) 60 | 61 | 62 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224)]) 63 | def test_Conv1d_dilation2(): 64 | return torch.nn.Conv1d(10, 5, kernel_size=3, stride=1, padding=1, dilation=2) 65 | -------------------------------------------------------------------------------- /torch2trt/converters/Conv2d.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.Conv2d.forward') 6 | def convert_Conv2d(ctx): 7 | module = ctx.method_args[0] 8 | input = ctx.method_args[1] 9 | input_trt = trt_(ctx.network, input) 10 | output = ctx.method_return 11 | 12 | kernel_size = module.kernel_size 13 | if not isinstance(kernel_size, tuple): 14 | kernel_size = (kernel_size, ) * 2 15 | 16 | stride = module.stride 17 | if not isinstance(stride, tuple): 18 | stride = (stride, ) * 2 19 | 20 | padding = module.padding 21 | if not isinstance(padding, tuple): 22 | padding = (padding, ) * 2 23 | 24 | dilation = module.dilation 25 | if not isinstance(dilation, tuple): 26 | dilation = (dilation, ) * 2 27 | 28 | kernel = module.weight.detach().cpu().numpy() 29 | 30 | bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype)) 31 | if module.bias is not None: 32 | bias = module.bias.detach().cpu().numpy() 33 | 34 | layer = ctx.network.add_convolution( 35 | input=input_trt, 36 | num_output_maps=module.out_channels, 37 | kernel_shape=kernel_size, 38 | kernel=kernel, 39 | bias=bias) 40 | layer.stride = stride 41 | layer.padding = padding 42 | layer.dilation = dilation 43 | 44 | if module.groups is not None: 45 | layer.num_groups = module.groups 46 | 47 | output._trt = layer.get_output(0) 48 | 49 | 50 | 51 | 
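# NOTE (editorial, added): the @add_module_test decorators below register each
# factory function, together with a dtype, device, and list of input shapes, in
# a global test registry (see torch2trt/module_test.py). Assuming this vendored
# copy matches upstream torch2trt, `python -m torch2trt.test --name=Conv2d`
# would build each registered module, convert it, and compare the TensorRT
# output against the PyTorch output; that invocation is the upstream one and is
# an assumption here, not something this listing confirms.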
@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)]) 52 | def test_Conv2d_basic(): 53 | return torch.nn.Conv2d(10, 5, kernel_size=1, stride=1, padding=0) 54 | 55 | 56 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)]) 57 | def test_Conv2d_stride2(): 58 | return torch.nn.Conv2d(10, 5, kernel_size=1, stride=2, padding=0) 59 | 60 | 61 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)]) 62 | def test_Conv2d_kernel3(): 63 | return torch.nn.Conv2d(10, 5, kernel_size=3, stride=2, padding=1) 64 | 65 | 66 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)]) 67 | def test_Conv2d_dilation2(): 68 | return torch.nn.Conv2d(10, 5, kernel_size=3, stride=1, padding=1, dilation=2) 69 | -------------------------------------------------------------------------------- /torch2trt/converters/ConvTranspose2d.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | 3 | 4 | @tensorrt_converter('torch.nn.ConvTranspose2d.forward') 5 | def convert_ConvTranspose2d(ctx): 6 | module = ctx.method_args[0] 7 | input = ctx.method_args[1] 8 | input_trt = trt_(ctx.network, input) 9 | output = ctx.method_return 10 | 11 | kernel_size = module.kernel_size 12 | if not isinstance(kernel_size, tuple): 13 | kernel_size = (kernel_size, ) * 2 14 | 15 | stride = module.stride 16 | if not isinstance(stride, tuple): 17 | stride = (stride, ) * 2 18 | 19 | padding = module.padding 20 | if not isinstance(padding, tuple): 21 | padding = (padding, ) * 2 22 | 23 | kernel = module.weight.detach().cpu().numpy() 24 | 25 | bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype)) 26 | if module.bias is not None: 27 | bias = module.bias.detach().cpu().numpy() 28 | 29 | layer = ctx.network.add_deconvolution( 30 | input=input_trt, 31 | num_output_maps=module.out_channels, 32 | kernel_shape=kernel_size, 33 | kernel=kernel, 34 | bias=bias) 35 | layer.stride = stride 36 | layer.padding = padding 37 | 38 | if module.groups is not None: 39 | layer.num_groups = module.groups 40 | 41 | output._trt = layer.get_output(0) -------------------------------------------------------------------------------- /torch2trt/converters/Identity.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | 3 | 4 | @tensorrt_converter('torch.nn.Dropout.forward') 5 | @tensorrt_converter('torch.nn.Dropout2d.forward') 6 | @tensorrt_converter('torch.nn.Dropout3d.forward') 7 | def convert_Identity(ctx): 8 | input = ctx.method_args[1] 9 | input_trt = trt_(ctx.network, input) 10 | output = ctx.method_return 11 | output._trt = input_trt -------------------------------------------------------------------------------- /torch2trt/converters/Linear.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.Linear.forward') 6 | def convert_Linear(ctx): 7 | module = ctx.method_args[0] 8 | input = ctx.method_args[1] 9 | input_trt = trt_(ctx.network, input) 10 | output = ctx.method_return 11 | 12 | # reshape to ...xNx1x1 13 | layer = ctx.network.add_shuffle(input_trt) 14 | layer.reshape_dims = tuple(input_trt.shape) + (1, 1) 15 | 16 | bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype)) 17 | if module.bias is not None: 18 | bias = module.bias.detach().cpu().numpy() 19 | 20 | # add fully connected 
21 | layer = ctx.network.add_fully_connected( 22 | input=layer.get_output(0), 23 | num_outputs=module.out_features, 24 | kernel=module.weight.detach().cpu().numpy(), 25 | bias=bias) 26 | 27 | # reshape back to N 28 | layer = ctx.network.add_shuffle(layer.get_output(0)) 29 | layer.reshape_dims = tuple(output.shape[1:]) 30 | 31 | output._trt = layer.get_output(0) 32 | 33 | 34 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10)]) 35 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 10)]) 36 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 10)]) 37 | def test_Linear_basic(): 38 | return torch.nn.Linear(10, 5) 39 | 40 | 41 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10)]) 42 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 10)]) 43 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 10)]) 44 | def test_Linear_no_bias(): 45 | return torch.nn.Linear(10, 5, bias=False) -------------------------------------------------------------------------------- /torch2trt/converters/LogSoftmax.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | 3 | 4 | @tensorrt_converter('torch.nn.LogSoftmax.forward') 5 | def convert_LogSoftmax(ctx): 6 | input = ctx.method_args[1] 7 | input_trt = trt_(ctx.network, input) 8 | output = ctx.method_return 9 | layer = ctx.network.add_softmax(input=input_trt) 10 | layer = ctx.network.add_unary(input=layer.get_output(0), 11 | op=trt.UnaryOperation.LOG) 12 | output._trt = layer.get_output(0) -------------------------------------------------------------------------------- /torch2trt/converters/ReLU.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | 3 | 4 | @tensorrt_converter('torch.nn.ReLU.forward') 5 | def convert_ReLU(ctx): 6 | input = ctx.method_args[1] 7 | input_trt = trt_(ctx.network, input) 8 | output = ctx.method_return 9 | layer = ctx.network.add_activation( 10 | input=input_trt, type=trt.ActivationType.RELU) 11 | output._trt = layer.get_output(0) -------------------------------------------------------------------------------- /torch2trt/converters/ReLU6.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.ReLU6.forward') 6 | def convert_ReLU6(ctx): 7 | input = ctx.method_args[1] 8 | output = ctx.method_return 9 | 10 | input_trt, trt_6 = trt_(ctx.network, input, 6) 11 | 12 | layer = ctx.network.add_activation( 13 | input=input_trt, type=trt.ActivationType.RELU) 14 | layer = ctx.network.add_elementwise( 15 | layer.get_output(0), trt_6, trt.ElementWiseOperation.MIN) 16 | 17 | output._trt = layer.get_output(0) 18 | 19 | 20 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)]) 21 | def test_relu6_basic(): 22 | return torch.nn.ReLU6() -------------------------------------------------------------------------------- /torch2trt/converters/__init__.py: -------------------------------------------------------------------------------- 1 | # dummy converters throw a warning when an unsupported method is encountered 2 | 3 | from .dummy_converters import * 4 | 5 | # supported converters will override dummy converters 6 | 7 | from .activation import * 8 | from .adaptive_avg_pool2d import * 9 | from .adaptive_max_pool2d import * 10 | from .AdaptiveAvgPool2d import * 11 | from .add import * 12 | 
from .avg_pool2d import * 13 | from .mul import * 14 | from .div import * 15 | from .BatchNorm1d import * 16 | from .BatchNorm2d import * 17 | from .cat import * 18 | from .clamp import * 19 | from .Conv1d import * 20 | from .Conv2d import * 21 | from .ConvTranspose2d import * 22 | from .getitem import * 23 | from .identity import * 24 | from .Identity import * 25 | from .instance_norm import * 26 | from .Linear import * 27 | from .LogSoftmax import * 28 | from .max_pool2d import * 29 | from .max import * 30 | from .min import * 31 | from .normalize import * 32 | from .pad import * 33 | from .permute import * 34 | from .pow import * 35 | from .prelu import * 36 | from .prod import * 37 | from .relu import * 38 | from .ReLU import * 39 | from .relu6 import * 40 | from .ReLU6 import * 41 | from .sigmoid import * 42 | from .sub import * 43 | from .sum import * 44 | from .view import * 45 | from .tanh import * 46 | from .transpose import * 47 | from .mean import * 48 | from .softmax import * 49 | from .split import * 50 | from .chunk import * 51 | from .unary import * 52 | 53 | 54 | try: 55 | from .interpolate import * 56 | except: 57 | pass 58 |
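Usage note (added): the star-imports above populate the global converter registry that torch2trt.py consults while tracing a model, and calibration.py earlier in this listing supplies the INT8 calibrator. A minimal sketch of how the pieces combine, assuming this vendored copy keeps the upstream torch2trt keyword arguments int8_mode, int8_calib_dataset, and int8_calib_algorithm (the keyword names are an assumption, not something this listing confirms):

import torch
import torchvision as tv
from torch2trt import torch2trt
from torch2trt.calibration import TensorBatchDataset, DEFAULT_CALIBRATION_ALGORITHM

model = tv.models.resnet18(pretrained=True).eval().cuda()   # stands in for any supported model
x = torch.randn(1, 3, 224, 224).cuda()

# a stack of representative inputs for entropy calibration
calib_dataset = TensorBatchDataset([torch.randn(32, 3, 224, 224).cuda()])

model_trt = torch2trt(model, [x],
                      int8_mode=True,                              # assumed upstream keyword
                      int8_calib_dataset=calib_dataset,            # wrapped in DatasetCalibrator upstream
                      int8_calib_algorithm=DEFAULT_CALIBRATION_ALGORITHM)
y_trt = model_trt(x)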
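Pattern note (added): every converter in this package has the same shape: @tensorrt_converter registers a function under the qualified name of the Python call it replaces, and the function reads ctx.method_args / ctx.method_return and appends layers to ctx.network. As a hedged illustration (not part of the repo), a sketch of a converter for torch.nn.functional.hardtanh, which otherwise only gets a dummy converter here; it assumes a TensorRT build that exposes trt.ActivationType.CLIP (5.1 and later):

from torch2trt.torch2trt import *   # tensorrt_converter, trt_, get_arg, trt

@tensorrt_converter('torch.nn.functional.hardtanh')
def convert_hardtanh(ctx):
    # hypothetical example mirroring the converters in this package
    input = get_arg(ctx, 'input', pos=0, default=None)
    min_val = get_arg(ctx, 'min_val', pos=1, default=-1.0)
    max_val = get_arg(ctx, 'max_val', pos=2, default=1.0)
    output = ctx.method_return

    input_trt = trt_(ctx.network, input)
    layer = ctx.network.add_activation(input_trt, trt.ActivationType.CLIP)
    layer.alpha = min_val   # lower clip bound
    layer.beta = max_val    # upper clip bound
    output._trt = layer.get_output(0)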
-------------------------------------------------------------------------------- /torch2trt/converters/activation.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | from .unary import UnaryModule 4 | 5 | 6 | # | RELU : Rectified Linear activation (impl 
in relu.py) 7 | # | SIGMOID : Sigmoid activation (impl in sigmoid.py) 8 | # | TANH : Hyperbolic Tangent activation (impl in tanh.py) 9 | 10 | 11 | # | LEAKY_RELU : Leaky Relu activation: f(x) = x if x >= 0, f(x) = alpha * x if x < 0 12 | 13 | 14 | @tensorrt_converter('torch.nn.functional.leaky_relu') 15 | @tensorrt_converter('torch.nn.functional.leaky_relu_') 16 | def convert_leaky_relu(ctx): 17 | input = get_arg(ctx, 'input', pos=0, default=None) 18 | negative_slope = get_arg(ctx, 'negative_slope', pos=1, default=0.01) 19 | output = ctx.method_return 20 | 21 | input_trt = trt_(ctx.network, input) 22 | layer = ctx.network.add_activation(input_trt, trt.ActivationType.LEAKY_RELU) 23 | layer.alpha = negative_slope 24 | 25 | output._trt = layer.get_output(0) 26 | 27 | 28 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 29 | def test_leaky_relu(): 30 | return UnaryModule(lambda x: torch.nn.functional.leaky_relu(x)) 31 | 32 | 33 | # | ELU : Elu activation: f(x) = x if x >= 0, f(x) = alpha * (exp(x) - 1) if x < 0 34 | 35 | 36 | @tensorrt_converter('torch.nn.functional.elu') 37 | @tensorrt_converter('torch.nn.functional.elu_') 38 | def convert_elu(ctx): 39 | input = get_arg(ctx, 'input', pos=0, default=None) 40 | alpha = get_arg(ctx, 'alpha', pos=1, default=1.0) 41 | output = ctx.method_return 42 | 43 | input_trt = trt_(ctx.network, input) 44 | layer = ctx.network.add_activation(input_trt, trt.ActivationType.ELU) 45 | layer.alpha = alpha 46 | 47 | output._trt = layer.get_output(0) 48 | 49 | 50 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 51 | def test_elu(): 52 | return UnaryModule(lambda x: torch.nn.functional.elu(x)) 53 | 54 | 55 | # | SELU : Selu activation: f(x) = beta * x if x > 0, f(x) = beta * (alpha * exp(x) - alpha) if x <= 0 56 | 57 | @tensorrt_converter('torch.selu') 58 | @tensorrt_converter('torch.selu_') 59 | @tensorrt_converter('torch.nn.functional.selu') 60 | @tensorrt_converter('torch.nn.functional.selu_') 61 | def convert_selu(ctx): 62 | input = get_arg(ctx, 'input', pos=0, default=None) 63 | alpha = get_arg(ctx, 'alpha', pos=1, default=1.0) 64 | output = ctx.method_return 65 | 66 | input_trt = trt_(ctx.network, input) 67 | layer = ctx.network.add_activation(input_trt, trt.ActivationType.SELU) 68 | layer.alpha = 1.6732632423543772848170429916717 69 | layer.beta = 1.0507009873554804934193349852946 70 | 71 | output._trt = layer.get_output(0) 72 | 73 | 74 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 75 | def test_selu(): 76 | return UnaryModule(lambda x: torch.nn.functional.selu(x)) 77 | 78 | 79 | # | SOFTSIGN : Softsign activation: f(x) = x / (1 + \|x\|) 80 | 81 | 82 | @tensorrt_converter('torch.nn.functional.softsign') 83 | def convert_softsign(ctx): 84 | input = get_arg(ctx, 'input', pos=0, default=None) 85 | output = ctx.method_return 86 | 87 | input_trt = trt_(ctx.network, input) 88 | layer = ctx.network.add_activation(input_trt, trt.ActivationType.SOFTSIGN) 89 | 90 | output._trt = layer.get_output(0) 91 | 92 | 93 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 94 | def test_softsign(): 95 | return UnaryModule(lambda x: torch.nn.functional.softsign(x)) 96 | 97 | 98 | # | SOFTPLUS : Softplus activation: f(x) = alpha * log(exp(beta * x) + 1) 99 | 100 | 101 | @tensorrt_converter('torch.nn.functional.softplus') 102 | def convert_softplus(ctx): 103 | input = get_arg(ctx, 'input', pos=0, default=None) 104 | output = ctx.method_return 105 | 106 | input_trt = trt_(ctx.network, input) 107 | layer = 
ctx.network.add_activation(input_trt, trt.ActivationType.SOFTPLUS) 108 | 109 | output._trt = layer.get_output(0) 110 | 111 | 112 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 113 | def test_softplus(): 114 | return UnaryModule(lambda x: torch.nn.functional.softplus(x)) 115 | 116 | 117 | # | CLIP : Clip activation: f(x) = max(alpha, min(beta, x)) (impl in clamp.py) 118 | 119 | # | HARD_SIGMOID : Hard sigmoid activation: f(x) = max(0, min(1, alpha * x + beta)) (not sure if there is this in Pytorch?) 120 | # | SCALED_TANH : Scaled Tanh activation: f(x) = alpha * tanh(beta * x) (not sure if there is this in Pytorch?) 121 | # | THRESHOLDED_RELU : Thresholded Relu activation: f(x) = x if x > alpha, f(x) = 0 if x <= alpha (not sure if there is this in Pytorch?) -------------------------------------------------------------------------------- /torch2trt/converters/adaptive_avg_pool2d.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from .AdaptiveAvgPool2d import * 3 | 4 | 5 | @tensorrt_converter('torch.nn.functional.adaptive_avg_pool2d') 6 | def convert_adaptive_avg_pool2d(ctx): 7 | ctx.method_args = (torch.nn.AdaptiveAvgPool2d(ctx.method_args[1]), ctx.method_args[0]) 8 | convert_AdaptiveAvgPool2d(ctx) 9 | -------------------------------------------------------------------------------- /torch2trt/converters/adaptive_max_pool2d.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.functional.adaptive_max_pool2d') 6 | def convert_adaptive_max_pool2d(ctx): 7 | input = ctx.method_args[0] 8 | output = ctx.method_return 9 | 10 | output_size = ctx.method_args[1] 11 | if isinstance(output_size, int): 12 | output_size = (output_size, ) * 2 13 | 14 | stride = (input._trt.shape[-2] // output_size[-2], input._trt.shape[-1] // output_size[-1]) 15 | 16 | kernel_size = stride 17 | layer = ctx.network.add_pooling( 18 | input=input._trt, type=trt.PoolingType.MAX, window_size=kernel_size) 19 | layer.stride = stride 20 | 21 | output._trt = layer.get_output(0) 22 | 23 | 24 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 25 | def test_adaptive_max_pool2d_1x1(): 26 | return torch.nn.AdaptiveMaxPool2d((1, 1)) 27 | 28 | 29 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 30 | def test_adaptive_max_pool2d_2x2(): 31 | return torch.nn.AdaptiveMaxPool2d((2, 2)) 32 | 33 | 34 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 35 | def test_adaptive_max_pool2d_3x3(): 36 | return torch.nn.AdaptiveMaxPool2d((3, 3)) 37 | -------------------------------------------------------------------------------- /torch2trt/converters/add.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.add') 6 | @tensorrt_converter('torch.Tensor.__iadd__') 7 | @tensorrt_converter('torch.Tensor.__add__') 8 | @tensorrt_converter('torch.Tensor.__radd__') 9 | def convert_add(ctx): 10 | input_a = ctx.method_args[0] 11 | input_b = ctx.method_args[1] 12 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 13 | output = ctx.method_return 14 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.SUM) 15 | output._trt = 
layer.get_output(0) 16 | 17 | 18 | class Add(torch.nn.Module): 19 | def __init__(self): 20 | super(Add, self).__init__() 21 | 22 | def forward(self, x, y): 23 | return x + y 24 | 25 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 26 | def test_add_basic(): 27 | return Add() 28 | 29 | 30 | class IAdd(torch.nn.Module): 31 | def __init__(self): 32 | super(IAdd, self).__init__() 33 | 34 | def forward(self, x, y): 35 | x += y 36 | return x 37 | 38 | 39 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 40 | def test_add_iadd(): 41 | return IAdd() 42 | 43 | 44 | class TorchAdd(torch.nn.Module): 45 | def __init__(self): 46 | super(TorchAdd, self).__init__() 47 | 48 | def forward(self, x, y): 49 | return torch.add(x, y) 50 | 51 | 52 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 53 | def test_add_torchadd(): 54 | return TorchAdd() 55 | 56 | 57 | class RAddInt(torch.nn.Module): 58 | def __init__(self): 59 | super(RAddInt, self).__init__() 60 | 61 | def forward(self, x): 62 | return 1 + x 63 | 64 | 65 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 66 | def test_add_radd_int(): 67 | return RAddInt() 68 | 69 | 70 | class RAddFloat(torch.nn.Module): 71 | def __init__(self): 72 | super(RAddFloat, self).__init__() 73 | 74 | def forward(self, x): 75 | return 1.0 + x 76 | 77 | 78 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 79 | def test_add_radd_float(): 80 | return RAddFloat() -------------------------------------------------------------------------------- /torch2trt/converters/avg_pool2d.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.functional.avg_pool2d') 6 | def convert_avg_pool2d(ctx): 7 | # parse args 8 | input = get_arg(ctx, 'input', pos=0, default=None) 9 | kernel_size = get_arg(ctx, 'kernel_size', pos=1, default=None) 10 | stride = get_arg(ctx, 'stride', pos=2, default=None) 11 | padding = get_arg(ctx, 'padding', pos=3, default=0) 12 | ceil_mode = get_arg(ctx, 'ceil_mode', pos=4, default=False) 13 | count_include_pad = get_arg(ctx, 'count_include_pad', pos=5, default=True) 14 | 15 | # get input trt tensor (or create constant if it doesn't exist) 16 | input_trt = trt_(ctx.network, input) 17 | 18 | output = ctx.method_return 19 | 20 | # get kernel size 21 | if not isinstance(kernel_size, tuple): 22 | kernel_size = (kernel_size, ) * 2 23 | 24 | # get stride 25 | if not isinstance(stride, tuple): 26 | stride = (stride, ) * 2 27 | 28 | # get padding 29 | if not isinstance(padding, tuple): 30 | padding = (padding, ) * 2 31 | 32 | layer = ctx.network.add_pooling( 33 | input=input_trt, type=trt.PoolingType.AVERAGE, window_size=kernel_size) 34 | 35 | layer.stride = stride 36 | layer.padding = padding 37 | layer.average_count_excludes_padding = not count_include_pad 38 | 39 | if ceil_mode: 40 | layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP 41 | 42 | output._trt = layer.get_output(0) 43 | 44 | 45 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 6)]) 46 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 5, 7)]) 47 | def test_avg_pool2d_without_ceil_mode(): 48 | return torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False) 49 | 50 | 51 | @add_module_test(torch.float32, torch.device('cuda'), 
[(1, 3, 4, 6)]) 52 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 5, 7)]) 53 | def test_avg_pool2d_with_ceil_mode(): 54 | return torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True, count_include_pad=False) # TRT does not support ceil_mode=True && count_include_pad=True 55 | -------------------------------------------------------------------------------- /torch2trt/converters/cat.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | 3 | 4 | @tensorrt_converter('torch.cat') 5 | def convert_cat(ctx): 6 | inputs = ctx.method_args[0] 7 | 8 | if 'dim' in ctx.method_kwargs: 9 | dim = ctx.method_kwargs['dim'] 10 | else: 11 | dim = ctx.method_args[1] 12 | 13 | output = ctx.method_return 14 | trt_inputs = [trt_(ctx.network, i) for i in inputs] 15 | 16 | layer = ctx.network.add_concatenation(inputs=trt_inputs) 17 | layer.axis = dim - 1 18 | output._trt = layer.get_output(0) -------------------------------------------------------------------------------- /torch2trt/converters/chunk.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | from .split import convert_split 4 | 5 | 6 | @tensorrt_converter('torch.chunk') 7 | @tensorrt_converter('torch.Tensor.chunk') 8 | def convert_chunk(ctx): 9 | convert_split(ctx) 10 | 11 | 12 | class TorchChunk(torch.nn.Module): 13 | 14 | def __init__(self, *args, **kwargs): 15 | super(TorchChunk, self).__init__() 16 | self.args = args 17 | self.kwargs = kwargs 18 | 19 | def forward(self, x): 20 | return torch.chunk(x, *self.args, **self.kwargs) 21 | 22 | 23 | class TensorChunk(torch.nn.Module): 24 | 25 | def __init__(self, *args, **kwargs): 26 | super(TensorChunk, self).__init__() 27 | self.args = args 28 | self.kwargs = kwargs 29 | 30 | def forward(self, x): 31 | return x.chunk(*self.args, **self.kwargs) 32 | 33 | 34 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 35 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 36 | def test_torch_chunk_1_1(): 37 | return TorchChunk(1, 1) 38 | 39 | 40 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 41 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 42 | def test_torch_chunk_2_1(): 43 | return TorchChunk(2, 1) 44 | 45 | 46 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 47 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 48 | def test_torch_chunk_3_1(): 49 | return TorchChunk(3, 1) 50 | 51 | 52 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 53 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 54 | def test_torch_chunk_3_2(): 55 | return TorchChunk(3, 2) 56 | 57 | 58 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 59 | def test_tensor_chunk_3_2(): 60 | return TensorChunk(3, 2) -------------------------------------------------------------------------------- /torch2trt/converters/clamp.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | def __add_clamp(network, trt_input, val, op): 6 | 7 | # create TensorRT constant for minimum value 8 | val_shape = (1, ) * len(trt_input.shape) # broadcast all dimensions 9 | val_tensor = val * torch.ones(val_shape, 
dtype=torch_dtype_from_trt(trt_input.dtype)).cpu().numpy() 10 | val_trt = network.add_constant(val_shape, val_tensor) 11 | layer = network.add_elementwise(trt_input, val_trt.get_output(0), op) 12 | 13 | return layer 14 | 15 | 16 | # CLAMP_MIN 17 | 18 | 19 | @tensorrt_converter('torch.clamp_min') 20 | @tensorrt_converter('torch.Tensor.clamp_min') 21 | def convert_clamp_min(ctx): 22 | input = ctx.method_args[0] 23 | input_trt = trt_(ctx.network, input) 24 | val = ctx.method_args[1] 25 | output = ctx.method_return 26 | 27 | layer = __add_clamp(ctx.network, input_trt, val, trt.ElementWiseOperation.MAX) 28 | 29 | output._trt = layer.get_output(0) 30 | 31 | 32 | class TorchClampMin(torch.nn.Module): 33 | def forward(self, x): 34 | return torch.clamp_min(x, -0.1) 35 | 36 | 37 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 38 | def test_torch_clamp_min(): 39 | return TorchClampMin() 40 | 41 | 42 | class TensorClampMin(torch.nn.Module): 43 | def forward(self, x): 44 | return x.clamp_min(-0.1) 45 | 46 | 47 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 48 | def test_tensor_clamp_min(): 49 | return TensorClampMin() 50 | 51 | 52 | # CLAMP_MAX 53 | 54 | 55 | @tensorrt_converter('torch.clamp_max') 56 | @tensorrt_converter('torch.Tensor.clamp_max') 57 | def convert_clamp_max(ctx): 58 | input = ctx.method_args[0] 59 | val = ctx.method_args[1] 60 | output = ctx.method_return 61 | 62 | layer = __add_clamp(ctx.network, input._trt, val, trt.ElementWiseOperation.MIN) 63 | 64 | output._trt = layer.get_output(0) 65 | 66 | 67 | class TorchClampMax(torch.nn.Module): 68 | def forward(self, x): 69 | return torch.clamp_max(x, 0.1) 70 | 71 | 72 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 73 | def test_torch_clamp_max(): 74 | return TorchClampMax() 75 | 76 | 77 | class TensorClampMax(torch.nn.Module): 78 | def forward(self, x): 79 | return x.clamp_max(0.1) 80 | 81 | 82 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 83 | def test_tensor_clamp_max(): 84 | return TensorClampMax() 85 | 86 | 87 | # CLAMP 88 | 89 | @tensorrt_converter('torch.clamp') 90 | @tensorrt_converter('torch.Tensor.clamp') 91 | def convert_clamp(ctx): 92 | input = ctx.method_args[0] 93 | output = ctx.method_return 94 | if "min" in ctx.method_kwargs and "max" in ctx.method_kwargs: 95 | min_val = ctx.method_kwargs["min"] 96 | max_val = ctx.method_kwargs["max"] 97 | layer = __add_clamp(ctx.network, input._trt, min_val, trt.ElementWiseOperation.MAX) 98 | layer = __add_clamp(ctx.network, layer.get_output(0), max_val, trt.ElementWiseOperation.MIN) 99 | elif "min" in ctx.method_kwargs: 100 | min_val = ctx.method_kwargs["min"] 101 | layer = __add_clamp(ctx.network, input._trt, min_val, trt.ElementWiseOperation.MAX) 102 | elif "max" in ctx.method_kwargs: 103 | max_val = ctx.method_kwargs["max"] 104 | layer = __add_clamp(ctx.network, input._trt, max_val, trt.ElementWiseOperation.MIN) 105 | else: 106 | min_val = ctx.method_args[1] 107 | max_val = ctx.method_args[2] 108 | layer = __add_clamp(ctx.network, input._trt, min_val, trt.ElementWiseOperation.MAX) 109 | layer = __add_clamp(ctx.network, layer.get_output(0), max_val, trt.ElementWiseOperation.MIN) 110 | 111 | output._trt = layer.get_output(0) 112 | 113 | 114 | class TorchClamp(torch.nn.Module): 115 | def forward(self, x): 116 | return torch.clamp(x, -0.1, 0.1) 117 | 118 | 119 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 120 | def test_torch_clamp(): 121 | return 
TorchClamp() 122 | 123 | 124 | class TensorClamp(torch.nn.Module): 125 | def forward(self, x): 126 | return x.clamp(-0.1, 0.1) 127 | 128 | 129 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 130 | def test_tensor_clamp(): 131 | return TensorClamp() 132 | 133 | 134 | class TorchClampOptionMax(torch.nn.Module): 135 | def forward(self, x): 136 | return torch.clamp(x, max=0.1) 137 | 138 | 139 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 140 | def test_torch_clamp_option_max(): 141 | return TorchClampOptionMax() 142 | 143 | class TorchClampOptionMin(torch.nn.Module): 144 | def forward(self, x): 145 | return torch.clamp(x, min=-0.1) 146 | 147 | 148 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 149 | def test_torch_clamp_option_min(): 150 | return TorchClampOptionMin() 151 | 152 | 153 | class TorchClampOptionMaxMin(torch.nn.Module): 154 | def forward(self, x): 155 | return torch.clamp(x, min=-0.1, max=0.1) 156 | 157 | 158 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 159 | def test_torch_clamp_option_max_min(): 160 | return TorchClampOptionMaxMin() 161 | 162 | 163 | class TensorClampOptionMax(torch.nn.Module): 164 | def forward(self, x): 165 | return x.clamp(max=0.1) 166 | 167 | 168 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 169 | def test_tensor_clamp_option_max(): 170 | return TensorClampOptionMax() 171 | 172 | class TensorClampOptionMin(torch.nn.Module): 173 | def forward(self, x): 174 | return x.clamp(min=-0.1) 175 | 176 | 177 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 178 | def test_tensor_clamp_option_min(): 179 | return TensorClampOptionMin() 180 | 181 | 182 | class TensorClampOptionMaxMin(torch.nn.Module): 183 | def forward(self, x): 184 | return x.clamp(min=-0.1, max=0.1) 185 | 186 | 187 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 188 | def test_tensor_clamp_option_max_min(): 189 | return TensorClampOptionMaxMin() -------------------------------------------------------------------------------- /torch2trt/converters/div.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.div') 6 | @tensorrt_converter('torch.Tensor.__div__') # py2 7 | @tensorrt_converter('torch.Tensor.__idiv__') # py2 8 | @tensorrt_converter('torch.Tensor.__truediv__') # py3 9 | @tensorrt_converter('torch.Tensor.__itruediv__') # py3 10 | def convert_div(ctx): 11 | input_a = ctx.method_args[0] 12 | input_b = ctx.method_args[1] 13 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 14 | output = ctx.method_return 15 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.DIV) 16 | output._trt = layer.get_output(0) 17 | 18 | 19 | @tensorrt_converter('torch.Tensor.__rdiv__') # py2 20 | @tensorrt_converter('torch.Tensor.__rtruediv__') # py3 21 | def convert_rdiv(ctx): 22 | input_a = ctx.method_args[1] # inputs switched for rdiv 23 | input_b = ctx.method_args[0] 24 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 25 | output = ctx.method_return 26 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.DIV) 27 | output._trt = layer.get_output(0) 28 | 29 | 30 | class Div(torch.nn.Module): 31 | def __init__(self): 32 | super(Div, self).__init__() 33 | 34 | def 
forward(self, x, y): 35 | return x / y 36 | 37 | 38 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 39 | def test_div_basic(): 40 | return Div() 41 | 42 | 43 | class IDiv(torch.nn.Module): 44 | def __init__(self): 45 | super(IDiv, self).__init__() 46 | 47 | def forward(self, x, y): 48 | x /= y 49 | return x 50 | 51 | 52 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 53 | def test_div_idiv(): 54 | return IDiv() 55 | 56 | 57 | class TorchDiv(torch.nn.Module): 58 | def __init__(self): 59 | super(TorchDiv, self).__init__() 60 | 61 | def forward(self, x, y): 62 | return torch.div(x, y) 63 | 64 | 65 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 66 | def test_div_torchdiv(): 67 | return TorchDiv() 68 | 69 | 70 | class RDivInt(torch.nn.Module): 71 | def __init__(self): 72 | super(RDivInt, self).__init__() 73 | 74 | def forward(self, x): 75 | return 100 / x 76 | 77 | 78 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 79 | def test_rdiv_int(): 80 | return RDivInt() 81 | 82 | 83 | class RDivFloat(torch.nn.Module): 84 | def __init__(self): 85 | super(RDivFloat, self).__init__() 86 | 87 | def forward(self, x): 88 | return 100.0 / x 89 | 90 | 91 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 92 | def test_rdiv_float(): 93 | return RDivFloat() -------------------------------------------------------------------------------- /torch2trt/converters/dummy_converters.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | 3 | 4 | def is_private(method): 5 | method = method.split('.')[-1] # remove prefix 6 | return method[0] == '_' and method[1] != '_' # single leading underscore only (skip dunders) 7 | 8 | def is_function_type(method): 9 | fntype = eval(method + '.__class__.__name__') 10 | return fntype == 'function' or fntype == 'builtin_function_or_method' or fntype == 'method_descriptor' 11 | 12 | def get_methods(namespace): 13 | methods = [] 14 | for method in dir(eval(namespace)): 15 | full_method = namespace + '.' 
+ method 16 | if not is_private(full_method) and is_function_type(full_method): 17 | methods.append(full_method) 18 | return methods 19 | 20 | 21 | TORCH_METHODS = [] 22 | TORCH_METHODS += get_methods('torch') 23 | TORCH_METHODS += get_methods('torch.Tensor') 24 | TORCH_METHODS += get_methods('torch.nn.functional') 25 | 26 | 27 | for method in TORCH_METHODS: 28 | 29 | @tensorrt_converter(method, is_real=False) 30 | def warn_method(ctx): 31 | print('Warning: Encountered known unsupported method %s' % ctx.method_str) 32 | 33 | 34 | @tensorrt_converter('torch.Tensor.dim', is_real=False) 35 | @tensorrt_converter('torch.Tensor.size', is_real=False) 36 | def dont_warn(ctx): 37 | pass -------------------------------------------------------------------------------- /torch2trt/converters/getitem.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | def slice_to_trt(dim_size, dim_slice): 6 | 7 | start = 0 if dim_slice.start is None else dim_slice.start 8 | stop = dim_size if dim_slice.stop is None else dim_slice.stop 9 | stride = 1 if dim_slice.step is None else dim_slice.step 10 | 11 | size = (stop - start - 1) // stride + 1 12 | 13 | return start, size, stride 14 | 15 | 16 | def num_slice_types(slices): 17 | num_slice = 0 18 | for s in slices: 19 | if isinstance(s, slice) or isinstance(s, int): 20 | num_slice += 1 21 | return num_slice 22 | 23 | 24 | @tensorrt_converter('torch.Tensor.__getitem__') 25 | def convert_tensor_getitem(ctx): 26 | input = ctx.method_args[0] 27 | slices = ctx.method_args[1] 28 | output = ctx.method_return 29 | 30 | input_trt = input._trt 31 | 32 | # Step 1 - Replace ellipsis with expanded slices 33 | 34 | num_ellipsis = input.ndim - num_slice_types(slices) 35 | 36 | new_slices = [] 37 | for s in slices: 38 | 39 | if s == Ellipsis: 40 | while num_ellipsis > 0: 41 | new_slices.append(slice(None, None, None)) 42 | num_ellipsis -= 1 43 | elif isinstance(s, slice): 44 | new_slices.append(s) 45 | elif s is None: 46 | new_slices.append(None) 47 | elif isinstance(s, int): 48 | new_slices.append(s) 49 | 50 | # fill missing slices at end 51 | while num_slice_types(new_slices) < len(input.shape): 52 | new_slices.append(slice(None, None, None)) 53 | 54 | # Step 2 - Remove batch from slices (TRT from this point) 55 | 56 | slices = tuple(new_slices[1:]) # remove batch 57 | 58 | 59 | # Step 3 - Add slice layer (will currently ignore 'None' slices) 60 | 61 | starts = [] 62 | sizes = [] 63 | strides = [] 64 | 65 | input_dim = 0 66 | for s in slices: 67 | 68 | if input_dim >= len(input_trt.shape): 69 | break 70 | 71 | input_size = int(input_trt.shape[input_dim]) 72 | 73 | if isinstance(s, slice): 74 | start, size, stride = slice_to_trt(input_size, s) 75 | starts.append(start) 76 | sizes.append(size) 77 | strides.append(stride) 78 | input_dim += 1 79 | 80 | elif isinstance(s, int): 81 | starts.append(s) 82 | sizes.append(1) 83 | strides.append(1) 84 | input_dim += 1 85 | 86 | output_trt = ctx.network.add_slice(input_trt, starts, sizes, strides).get_output(0) 87 | 88 | # Step 4 - Add shuffle layer to insert dimensions for 'None' slices and remove dimensions for 'int' slices 89 | 90 | num_non_slice = len([s for s in slices if not isinstance(s, slice)]) 91 | if num_non_slice > 0: 92 | layer = ctx.network.add_shuffle(output_trt) 93 | layer.reshape_dims = tuple(output.shape[1:]) # exclude batch 94 | output_trt = layer.get_output(0) 95 | 96 | output._trt = 
output_trt 97 | 98 | 99 | class LambdaModule(torch.nn.Module): 100 | def __init__(self, fn): 101 | super(LambdaModule, self).__init__() 102 | self.fn = fn 103 | 104 | def forward(self, x): 105 | return self.fn(x) 106 | 107 | 108 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 109 | def test_tensor_getitem_1d_int(): 110 | return LambdaModule(lambda x: x[:, 0]) 111 | 112 | 113 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 4, 3)]) 114 | def test_tensor_getitem_2d_int(): 115 | return LambdaModule(lambda x: x[:, 0]) 116 | 117 | 118 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 4, 3)]) 119 | def test_tensor_getitem_2d_strided(): 120 | return LambdaModule(lambda x: x[:, ::2]) 121 | 122 | 123 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 4, 3)]) 124 | def test_tensor_getitem_2d_strided_offset(): 125 | return LambdaModule(lambda x: x[:, 1::2]) 126 | 127 | 128 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 4, 3)]) 129 | def test_tensor_getitem_2d_strided_range(): 130 | return LambdaModule(lambda x: x[:, 1:3:2]) 131 | 132 | 133 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 4, 3)]) 134 | def test_tensor_getitem_2d_insert_dim(): 135 | return LambdaModule(lambda x: x[:, None]) 136 | 137 | 138 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 4, 3)]) 139 | def test_tensor_getitem_2d_insert_dim_ellipsis(): 140 | return LambdaModule(lambda x: x[:, None, ...]) 141 | 142 | 143 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 4, 3)]) 144 | def test_tensor_getitem_2d_append_dim(): 145 | return LambdaModule(lambda x: x[:, ..., None]) 146 | 147 | 148 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 4, 3)]) 149 | def test_tensor_getitem_2d_append_2dim(): 150 | return LambdaModule(lambda x: x[:, ..., None, None]) 151 | 152 | 153 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 4, 3)]) 154 | def test_tensor_getitem_2d_weird_combo(): 155 | return LambdaModule(lambda x: x[:, 0:3:4, None, None, 1, ...]) -------------------------------------------------------------------------------- /torch2trt/converters/instance_norm.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | def _add_scale_1d2d3d(network, x_trt, mode, offset, scale, power): 6 | ndim = len(x_trt.shape) 7 | 8 | y_trt = x_trt 9 | 10 | # shape to 2D 11 | if ndim != 3: 12 | layer = network.add_shuffle(y_trt) 13 | layer.reshape_dims = (x_trt.shape[0], x_trt.shape[1], -1) # NCH -> NCHW 14 | y_trt = layer.get_output(0) 15 | 16 | y_trt = network.add_scale(y_trt, mode, offset, scale, power).get_output(0) 17 | 18 | # shape to original dimension 19 | if ndim != 3: 20 | layer = network.add_shuffle(layer.get_output(0)) 21 | layer.reshape_dims = tuple(x_trt.shape) 22 | y_trt = layer.get_output(0) 23 | 24 | return y_trt 25 | 26 | @tensorrt_converter('torch.instance_norm') 27 | @tensorrt_converter('torch.nn.functional.instance_norm') 28 | def convert_instance_norm(ctx): 29 | input = get_arg(ctx, 'input', pos=0, default=None) 30 | running_mean = get_arg(ctx, 'running_mean', pos=1, default=None) 31 | running_var = get_arg(ctx, 'running_var', pos=2, default=None) 32 | weight = get_arg(ctx, 'weight', pos=3, default=None) 33 | bias = get_arg(ctx, 'bias', pos=4, default=None) 34 | use_input_stats = get_arg(ctx, 'use_input_stats', pos=5, default=True) 35 | momentum = get_arg(ctx, 
'momentum', pos=6, default=0.1) 36 | eps = get_arg(ctx, 'eps', pos=7, default=1e-05) 37 | output = ctx.method_return 38 | 39 | 40 | # CASE 1 - USING RUNNING STATISTICS 41 | if not use_input_stats: 42 | 43 | # equivalent to batch norm 44 | scale = 1.0 / np.sqrt(running_var.detach().cpu().numpy() + eps) 45 | offset = -running_mean.detach().cpu().numpy() * scale 46 | power = np.ones_like(scale) 47 | 48 | if weight is not None: 49 | scale *= weight.detach().cpu().numpy() 50 | offset += bias.detach().cpu().numpy() 51 | 52 | result_trt = _add_scale_1d2d3d(ctx.network, input._trt, trt.ScaleMode.CHANNEL, offset, scale, power) 53 | 54 | output._trt = result_trt 55 | 56 | # CASE 2 - USING INPUT STATS 57 | else: 58 | 59 | eps_np = np.array([eps], dtype=np.float32) 60 | keep_dims = True 61 | reduce_axes = torch_dim_to_trt_axes(tuple(range(2, input.ndim))) 62 | 63 | # compute mean over spatial 64 | mean_trt = ctx.network.add_reduce(input._trt, trt.ReduceOperation.AVG, reduce_axes, keep_dims).get_output(0) 65 | 66 | # compute variance over spatial (include eps, to reduce layer count) 67 | delta_trt = ctx.network.add_elementwise(input._trt, mean_trt, trt.ElementWiseOperation.SUB).get_output(0) 68 | var_trt = ctx.network.add_scale(delta_trt, trt.ScaleMode.UNIFORM, np.zeros_like(eps_np), np.ones_like(eps_np), 2 * np.ones_like(eps_np)).get_output(0) 69 | var_trt = ctx.network.add_reduce(var_trt, trt.ReduceOperation.AVG, reduce_axes, keep_dims).get_output(0) 70 | 71 | # compute sqrt(var + eps) 72 | var_trt = ctx.network.add_scale(var_trt, trt.ScaleMode.UNIFORM, eps_np, np.ones_like(eps_np), 0.5 * np.ones_like(eps_np)).get_output(0) 73 | 74 | # compute final result 75 | result_trt = ctx.network.add_elementwise(delta_trt, var_trt, trt.ElementWiseOperation.DIV).get_output(0) 76 | 77 | # compute affine (if applicable) 78 | if weight is not None: 79 | 80 | weight_np = weight.detach().cpu().numpy() 81 | bias_np = bias.detach().cpu().numpy() 82 | 83 | result_trt = _add_scale_1d2d3d(ctx.network, result_trt, trt.ScaleMode.CHANNEL, bias_np, weight_np, np.ones_like(bias_np)) 84 | 85 | output._trt = result_trt 86 | 87 | 88 | # STATIC 89 | 90 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3)]) 91 | def test_instance_norm_1d_static(): 92 | return torch.nn.InstanceNorm1d(10, track_running_stats=True) 93 | 94 | 95 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3, 3)]) 96 | def test_instance_norm_2d_static(): 97 | return torch.nn.InstanceNorm2d(10, track_running_stats=True) 98 | 99 | 100 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3, 3, 3)]) 101 | def test_instance_norm_3d_static(): 102 | return torch.nn.InstanceNorm3d(10, track_running_stats=True) 103 | 104 | 105 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3)]) 106 | def test_instance_norm_1d_static_affine(): 107 | return torch.nn.InstanceNorm1d(10, affine=True, track_running_stats=True) 108 | 109 | 110 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3, 3)]) 111 | def test_instance_norm_2d_static_affine(): 112 | return torch.nn.InstanceNorm2d(10, affine=True, track_running_stats=True) 113 | 114 | 115 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3, 3, 3)]) 116 | def test_instance_norm_3d_static_affine(): 117 | return torch.nn.InstanceNorm3d(10, affine=True, track_running_stats=True) 118 | 119 | # DYNAMIC 120 | 121 | # @TODO(jwelsh): 1D dynamic test failing 122 | # @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3)]) 123 | # def 
test_instance_norm_1d_dynamic(): 124 | # return torch.nn.InstanceNorm1d(10, track_running_stats=False) 125 | 126 | 127 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3, 3)]) 128 | def test_instance_norm_2d_dynamic(): 129 | return torch.nn.InstanceNorm2d(10, track_running_stats=False) 130 | 131 | 132 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3, 3, 3)]) 133 | def test_instance_norm_3d_dynamic(): 134 | return torch.nn.InstanceNorm3d(10, track_running_stats=False) 135 | 136 | 137 | # @TODO(jwelsh): 1D dynamic test failing 138 | # @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3)]) 139 | # def test_instance_norm_1d_dynamic_affine(): 140 | # return torch.nn.InstanceNorm1d(10, affine=True, track_running_stats=False) 141 | 142 | 143 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3, 3)]) 144 | def test_instance_norm_2d_dynamic_affine(): 145 | return torch.nn.InstanceNorm2d(10, affine=True, track_running_stats=False) 146 | 147 | 148 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3, 3, 3)]) 149 | def test_instance_norm_3d_dynamic_affine(): 150 | return torch.nn.InstanceNorm3d(10, affine=True, track_running_stats=False) -------------------------------------------------------------------------------- /torch2trt/converters/interpolate/__init__.py: -------------------------------------------------------------------------------- 1 | from .interpolate import * 2 | -------------------------------------------------------------------------------- /torch2trt/converters/interpolate/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LittlePieces/ObjectDetection/55ee1c120818f2f37d5ece5d4623fe38deb547e7/torch2trt/converters/interpolate/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /torch2trt/converters/interpolate/__pycache__/interpolate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LittlePieces/ObjectDetection/55ee1c120818f2f37d5ece5d4623fe38deb547e7/torch2trt/converters/interpolate/__pycache__/interpolate.cpython-36.pyc -------------------------------------------------------------------------------- /torch2trt/converters/interpolate/__pycache__/interpolate_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LittlePieces/ObjectDetection/55ee1c120818f2f37d5ece5d4623fe38deb547e7/torch2trt/converters/interpolate/__pycache__/interpolate_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /torch2trt/converters/interpolate/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <ATen/cuda/CUDAContext.h> 3 | #include <ATen/cuda/CUDAGuard.h> 4 | #include <cuda_runtime_api.h> 5 | #include <NvInfer.h> 6 | #include <vector> 7 | #include "interpolate.pb.h" 8 | 9 | 10 | using namespace nvinfer1; 11 | 12 | 13 | namespace torch2trt 14 | { 15 | 16 | class interpolate_Plugin : public IPluginV2 { 17 | private: 18 | interpolate_Message message; 19 | at::TensorOptions tensor_options; 20 | std::vector<int64_t> input_sizes; 21 | std::vector<int64_t> output_sizes; 22 | 23 | public: 24 | interpolate_Plugin(interpolate_Message message) : message(message) {} 25 | 26 | const char* getPluginType() const override { 27 | return "interpolate"; 28 | }; 29 | 30 | const char* getPluginVersion() const override { 31 | 
return "1"; 32 | } 33 | 34 | int getNbOutputs() const override { 35 | return 1; 36 | } 37 | 38 | Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override { 39 | Dims dims; 40 | dims.nbDims = inputs->nbDims; 41 | 42 | dims.d[0] = inputs->d[0]; 43 | for (int i = 0; i < message.size_size(); i++) { 44 | dims.d[i + 1] = message.size(i); 45 | } 46 | 47 | return dims; 48 | } 49 | 50 | bool supportsFormat(DataType type, PluginFormat format) const override { 51 | if (format != PluginFormat::kNCHW) { 52 | return false; 53 | } 54 | if (type == DataType::kINT32 || type == DataType::kINT8) { 55 | return false; 56 | } 57 | return true; 58 | } 59 | 60 | void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, 61 | int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override { 62 | 63 | // set data type 64 | if (type == DataType::kFLOAT) { 65 | message.set_dtype(DataTypeMessage::kFloat); 66 | } else if (type == DataType::kHALF) { 67 | tensor_options = tensor_options.dtype(c10::kHalf); 68 | message.set_dtype(DataTypeMessage::kHalf); 69 | } 70 | 71 | // set input sizes 72 | for (int i = 0; i < inputDims[0].nbDims; i++) { 73 | message.add_input_size(inputDims[0].d[i]); 74 | } 75 | 76 | // set output sizes 77 | for (int i = 0; i < outputDims[0].nbDims; i++) { 78 | message.add_output_size(outputDims[0].d[i]); 79 | } 80 | } 81 | 82 | int initialize() override { 83 | // set device 84 | tensor_options = tensor_options.device(c10::kCUDA); 85 | 86 | // set data type 87 | if (message.dtype() == DataTypeMessage::kFloat) { 88 | tensor_options = tensor_options.dtype(c10::kFloat); 89 | } else if (message.dtype() == DataTypeMessage::kHalf) { 90 | tensor_options = tensor_options.dtype(c10::kHalf); 91 | } 92 | 93 | input_sizes.resize(message.input_size_size()); 94 | output_sizes.resize(message.output_size_size()); 95 | 96 | for (int i = 0; i < message.input_size_size(); i++) { 97 | input_sizes[i] = message.input_size(i); 98 | } 99 | for (int i = 0; i < message.output_size_size(); i++) { 100 | output_sizes[i] = message.output_size(i); 101 | } 102 | 103 | return 0; 104 | } 105 | 106 | void terminate() override {} 107 | 108 | size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } 109 | 110 | int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override { 111 | // get input / output dimensions 112 | std::vector batch_input_sizes = input_sizes; 113 | std::vector batch_output_sizes = output_sizes; 114 | batch_input_sizes.insert(batch_input_sizes.begin(), batchSize); 115 | batch_output_sizes.insert(batch_output_sizes.begin(), batchSize); 116 | 117 | // create tensor wrappers 118 | at::Tensor input = at::from_blob((void*) inputs[0], batch_input_sizes, [](void*){}, tensor_options); 119 | at::Tensor output = at::from_blob(outputs[0], batch_output_sizes, [](void*){}, tensor_options); 120 | 121 | // create new torch cuda stream 122 | at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool(); 123 | at::cuda::CUDAStreamGuard torch_guard(torch_stream); 124 | 125 | // capture current work on tensorrt cuda stream 126 | cudaEvent_t event; 127 | cudaEventCreate(&event); 128 | cudaEventRecord(event, stream); 129 | 130 | // make torch cuda stream wait on tensorrt work 131 | cudaStreamWaitEvent(torch_stream.stream(), event, 0); 132 | 133 | // enqueue work 134 | if (message.mode() == "bilinear") { 135 | at::upsample_bilinear2d_out(output, input, {message.size(0), message.size(1)}, 
message.align_corners()); 136 | } else if (message.mode() == "nearest") { 137 | at::upsample_nearest2d_out(output, input, {message.size(0), message.size(1)}); 138 | } else if (message.mode() == "area") { 139 | at::adaptive_avg_pool2d_out(output, input, {message.size(0), message.size(1)}); 140 | } else if (message.mode() == "bicubic") { 141 | at::upsample_bicubic2d_out(output, input, {message.size(0), message.size(1)}, message.align_corners()); 142 | } 143 | 144 | // capture event on enqueued stream 145 | cudaEvent_t torch_event; 146 | cudaEventCreate(&torch_event); 147 | cudaEventRecord(torch_event, torch_stream.stream()); 148 | 149 | cudaStreamWaitEvent(stream, torch_event, 0); 150 | 151 | cudaEventDestroy(event); 152 | cudaEventDestroy(torch_event); 153 | 154 | return 0; 155 | } 156 | 157 | size_t getSerializationSize() const override { 158 | return message.SerializeAsString().size(); 159 | } 160 | 161 | void serialize(void* buffer) const override { 162 | message.SerializeToArray(buffer, getSerializationSize()); 163 | } 164 | 165 | void destroy() override {} 166 | 167 | IPluginV2* clone() const override { 168 | return new interpolate_Plugin(message); 169 | } 170 | 171 | void setPluginNamespace(const char* pluginNamespace) override {} 172 | 173 | const char *getPluginNamespace() const override { 174 | return "torch2trt"; 175 | } 176 | 177 | }; 178 | 179 | class interpolate_PluginCreator : public IPluginCreator { 180 | public: 181 | interpolate_PluginCreator() {} 182 | 183 | const char *getPluginNamespace() const override { 184 | return "torch2trt"; 185 | } 186 | 187 | const char *getPluginName() const override { 188 | return "interpolate"; 189 | } 190 | 191 | const char *getPluginVersion() const override { 192 | return "1"; 193 | } 194 | 195 | IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) override { 196 | interpolate_Message message; 197 | message.ParseFromArray(data, length); 198 | return new interpolate_Plugin(message); 199 | } 200 | 201 | void setPluginNamespace(const char *N) override {} 202 | const PluginFieldCollection *getFieldNames() override { return nullptr; } 203 | 204 | IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) override { return nullptr; } 205 | 206 | }; 207 | 208 | REGISTER_TENSORRT_PLUGIN(interpolate_PluginCreator); 209 | 210 | } 211 | -------------------------------------------------------------------------------- /torch2trt/converters/interpolate/interpolate.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | 4 | package torch2trt; 5 | 6 | enum DataTypeMessage { 7 | kFloat = 0; 8 | kHalf = 1; 9 | kInt8 = 2; 10 | kInt32 = 3; 11 | } 12 | 13 | 14 | message interpolate_Message { 15 | repeated int64 size = 1; 16 | string mode = 2; 17 | bool align_corners = 3; 18 | 19 | // below params are configured by TRT and not set by user 20 | DataTypeMessage dtype = 4; 21 | repeated int64 input_size = 5; 22 | repeated int64 output_size = 6; 23 | } 24 | -------------------------------------------------------------------------------- /torch2trt/converters/interpolate/interpolate.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch.nn.functional as F 3 | from torch2trt.torch2trt import * 4 | from torch2trt.module_test import add_module_test 5 | from .interpolate_pb2 import interpolate_Message 6 | import torch.nn as nn 7 | 8 | def get_interpolate_plugin(size, mode, align_corners): 9 | 
PLUGIN_NAME = 'interpolate' 10 | registry = trt.get_plugin_registry() 11 | creator = [c for c in registry.plugin_creator_list if c.name == PLUGIN_NAME and c.plugin_namespace == 'torch2trt'][0] 12 | message = interpolate_Message(size=size, mode=mode, align_corners=align_corners) 13 | return creator.deserialize_plugin(PLUGIN_NAME, message.SerializeToString()) 14 | 15 | 16 | @tensorrt_converter('torch.nn.functional.interpolate') 17 | def convert_interpolate(ctx): 18 | input = ctx.method_args[0] 19 | input_trt = trt_(ctx.network, input) 20 | output = ctx.method_return 21 | 22 | try: 23 | mode = get_arg(ctx, 'mode', pos=3, default='nearest') 24 | except KeyError: 25 | mode = 'nearest' 26 | 27 | try: 28 | align_corners = get_arg(ctx, 'align_corners', pos=4, default=None) 29 | except KeyError: 30 | align_corners = False 31 | 32 | # currently only works for NCHW 33 | size = list(output.shape[2:]) 34 | 35 | plugin = get_interpolate_plugin(size=size, mode=mode, align_corners=align_corners) 36 | 37 | layer = ctx.network.add_plugin_v2([input_trt], plugin) 38 | 39 | output._trt = layer.get_output(0) 40 | 41 | 42 | class Interpolate(torch.nn.Module): 43 | def __init__(self, size, mode, align_corners): 44 | super(Interpolate, self).__init__() 45 | self.size = size 46 | self.mode = mode 47 | self.align_corners = align_corners 48 | 49 | def forward(self, x): 50 | return F.interpolate(x, self.size, mode=self.mode, align_corners=self.align_corners) 51 | 52 | 53 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 112, 112)]) 54 | def test_interpolate_nearest(): 55 | return Interpolate((224, 224), 'nearest', None) 56 | 57 | 58 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 112, 112)]) 59 | def test_interpolate_bilinear(): 60 | return Interpolate((224, 224), 'bilinear', False) 61 | 62 | 63 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 112, 112)]) 64 | def test_interpolate_bicubic(): 65 | return Interpolate((224, 224), 'bicubic', False) 66 | 67 | 68 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 112, 112)]) 69 | def test_interpolate_area(): 70 | return Interpolate((56, 56), 'area', None) 71 | 72 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 112, 112)]) 73 | def test_upsample_scale_factor2(): 74 | return nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False) -------------------------------------------------------------------------------- /torch2trt/converters/interpolate/interpolate_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: torch2trt/converters/interpolate/interpolate.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='torch2trt/converters/interpolate/interpolate.proto', 21 | package='torch2trt', 22 | syntax='proto3', 23 | serialized_pb=_b('\n2torch2trt/converters/interpolate/interpolate.proto\x12\ttorch2trt\"\x9c\x01\n\x13interpolate_Message\x12\x0c\n\x04size\x18\x01 \x03(\x03\x12\x0c\n\x04mode\x18\x02 \x01(\t\x12\x15\n\ralign_corners\x18\x03 \x01(\x08\x12)\n\x05\x64type\x18\x04 \x01(\x0e\x32\x1a.torch2trt.DataTypeMessage\x12\x12\n\ninput_size\x18\x05 \x03(\x03\x12\x13\n\x0boutput_size\x18\x06 \x03(\x03*?\n\x0f\x44\x61taTypeMessage\x12\n\n\x06kFloat\x10\x00\x12\t\n\x05kHalf\x10\x01\x12\t\n\x05kInt8\x10\x02\x12\n\n\x06kInt32\x10\x03\x62\x06proto3') 24 | ) 25 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 26 | 27 | _DATATYPEMESSAGE = _descriptor.EnumDescriptor( 28 | name='DataTypeMessage', 29 | full_name='torch2trt.DataTypeMessage', 30 | filename=None, 31 | file=DESCRIPTOR, 32 | values=[ 33 | _descriptor.EnumValueDescriptor( 34 | name='kFloat', index=0, number=0, 35 | options=None, 36 | type=None), 37 | _descriptor.EnumValueDescriptor( 38 | name='kHalf', index=1, number=1, 39 | options=None, 40 | type=None), 41 | _descriptor.EnumValueDescriptor( 42 | name='kInt8', index=2, number=2, 43 | options=None, 44 | type=None), 45 | _descriptor.EnumValueDescriptor( 46 | name='kInt32', index=3, number=3, 47 | options=None, 48 | type=None), 49 | ], 50 | containing_type=None, 51 | options=None, 52 | serialized_start=224, 53 | serialized_end=287, 54 | ) 55 | _sym_db.RegisterEnumDescriptor(_DATATYPEMESSAGE) 56 | 57 | DataTypeMessage = enum_type_wrapper.EnumTypeWrapper(_DATATYPEMESSAGE) 58 | kFloat = 0 59 | kHalf = 1 60 | kInt8 = 2 61 | kInt32 = 3 62 | 63 | 64 | 65 | _INTERPOLATE_MESSAGE = _descriptor.Descriptor( 66 | name='interpolate_Message', 67 | full_name='torch2trt.interpolate_Message', 68 | filename=None, 69 | file=DESCRIPTOR, 70 | containing_type=None, 71 | fields=[ 72 | _descriptor.FieldDescriptor( 73 | name='size', full_name='torch2trt.interpolate_Message.size', index=0, 74 | number=1, type=3, cpp_type=2, label=3, 75 | has_default_value=False, default_value=[], 76 | message_type=None, enum_type=None, containing_type=None, 77 | is_extension=False, extension_scope=None, 78 | options=None), 79 | _descriptor.FieldDescriptor( 80 | name='mode', full_name='torch2trt.interpolate_Message.mode', index=1, 81 | number=2, type=9, cpp_type=9, label=1, 82 | has_default_value=False, default_value=_b("").decode('utf-8'), 83 | message_type=None, enum_type=None, containing_type=None, 84 | is_extension=False, extension_scope=None, 85 | options=None), 86 | _descriptor.FieldDescriptor( 87 | name='align_corners', full_name='torch2trt.interpolate_Message.align_corners', index=2, 88 | number=3, type=8, cpp_type=7, label=1, 89 | has_default_value=False, default_value=False, 90 | message_type=None, enum_type=None, containing_type=None, 91 | is_extension=False, 
extension_scope=None, 92 | options=None), 93 | _descriptor.FieldDescriptor( 94 | name='dtype', full_name='torch2trt.interpolate_Message.dtype', index=3, 95 | number=4, type=14, cpp_type=8, label=1, 96 | has_default_value=False, default_value=0, 97 | message_type=None, enum_type=None, containing_type=None, 98 | is_extension=False, extension_scope=None, 99 | options=None), 100 | _descriptor.FieldDescriptor( 101 | name='input_size', full_name='torch2trt.interpolate_Message.input_size', index=4, 102 | number=5, type=3, cpp_type=2, label=3, 103 | has_default_value=False, default_value=[], 104 | message_type=None, enum_type=None, containing_type=None, 105 | is_extension=False, extension_scope=None, 106 | options=None), 107 | _descriptor.FieldDescriptor( 108 | name='output_size', full_name='torch2trt.interpolate_Message.output_size', index=5, 109 | number=6, type=3, cpp_type=2, label=3, 110 | has_default_value=False, default_value=[], 111 | message_type=None, enum_type=None, containing_type=None, 112 | is_extension=False, extension_scope=None, 113 | options=None), 114 | ], 115 | extensions=[ 116 | ], 117 | nested_types=[], 118 | enum_types=[ 119 | ], 120 | options=None, 121 | is_extendable=False, 122 | syntax='proto3', 123 | extension_ranges=[], 124 | oneofs=[ 125 | ], 126 | serialized_start=66, 127 | serialized_end=222, 128 | ) 129 | 130 | _INTERPOLATE_MESSAGE.fields_by_name['dtype'].enum_type = _DATATYPEMESSAGE 131 | DESCRIPTOR.message_types_by_name['interpolate_Message'] = _INTERPOLATE_MESSAGE 132 | DESCRIPTOR.enum_types_by_name['DataTypeMessage'] = _DATATYPEMESSAGE 133 | 134 | interpolate_Message = _reflection.GeneratedProtocolMessageType('interpolate_Message', (_message.Message,), dict( 135 | DESCRIPTOR = _INTERPOLATE_MESSAGE, 136 | __module__ = 'torch2trt.converters.interpolate.interpolate_pb2' 137 | # @@protoc_insertion_point(class_scope:torch2trt.interpolate_Message) 138 | )) 139 | _sym_db.RegisterMessage(interpolate_Message) 140 | 141 | 142 | # @@protoc_insertion_point(module_scope) 143 | -------------------------------------------------------------------------------- /torch2trt/converters/max.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | from .unary import UnaryModule 4 | 5 | 6 | def __convert_max_elementwise(ctx): 7 | input_a = ctx.method_args[0] 8 | input_b = ctx.method_args[1] 9 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 10 | output = ctx.method_return 11 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.MAX) 12 | output._trt = layer.get_output(0) 13 | 14 | 15 | def __convert_max_reduce(ctx): 16 | input = ctx.method_args[0] 17 | dim = get_arg(ctx, 'dim', pos=1, default=tuple(range(1, input.ndim))) 18 | keepdim = get_arg(ctx, 'keepdim', pos=2, default=False) 19 | input_trt= trt_(ctx.network, input) 20 | output_val = ctx.method_return[0] 21 | output_idx = ctx.method_return[1] 22 | layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.MAX, torch_dim_to_trt_axes(dim), keepdim) 23 | output_val._trt = layer.get_output(0) 24 | 25 | 26 | @tensorrt_converter('torch.max') 27 | @tensorrt_converter('torch.Tensor.max') 28 | def convert_max(ctx): 29 | if len(ctx.method_args) > 1 and isinstance(ctx.method_args[1], torch.Tensor): 30 | __convert_max_elementwise(ctx) 31 | else: 32 | __convert_max_reduce(ctx) 33 | 34 | 35 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 36 | 
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 37 | def test_max_reduce_dim1(): 38 | return UnaryModule(lambda x: torch.max(x, 1)[0]) 39 | 40 | 41 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 42 | def test_max_reduce_dim22(): 43 | return UnaryModule(lambda x: torch.max(x, 2)[0]) 44 | 45 | 46 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 47 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 48 | def test_max_reduce_dim1_keepdim(): 49 | return UnaryModule(lambda x: torch.max(x, 1, keepdim=True)[0]) 50 | 51 | 52 | class MaxElementwise(torch.nn.Module): 53 | def forward(self, x, y): 54 | return torch.max(x, y) 55 | 56 | 57 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3), (1, 3, 3)]) 58 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3), (1,)]) # broadcast 59 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3), (1, 3, 3)]) # broadcast 60 | def test_max_elementwise(): 61 | return MaxElementwise() -------------------------------------------------------------------------------- /torch2trt/converters/max_pool2d.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.functional.max_pool2d') 6 | def convert_max_pool2d(ctx): 7 | # parse args 8 | input = get_arg(ctx, 'input', pos=0, default=None) 9 | kernel_size = get_arg(ctx, 'kernel_size', pos=1, default=None) 10 | stride = get_arg(ctx, 'stride', pos=2, default=None) 11 | padding = get_arg(ctx, 'padding', pos=3, default=0) 12 | dilation = get_arg(ctx, 'dilation', pos=4, default=1) 13 | ceil_mode = get_arg(ctx, 'ceil_mode', pos=5, default=False) 14 | 15 | # get input trt tensor (or create constant if it doesn't exist) 16 | input_trt = trt_(ctx.network, input) 17 | 18 | output = ctx.method_return 19 | 20 | # get kernel size 21 | if not isinstance(kernel_size, tuple): 22 | kernel_size = (kernel_size, ) * 2 23 | 24 | # get stride 25 | if not isinstance(stride, tuple): 26 | stride = (stride, ) * 2 27 | 28 | # get padding 29 | if not isinstance(padding, tuple): 30 | padding = (padding, ) * 2 31 | 32 | layer = ctx.network.add_pooling( 33 | input=input_trt, type=trt.PoolingType.MAX, window_size=kernel_size) 34 | 35 | layer.stride = stride 36 | layer.padding = padding 37 | 38 | if ceil_mode: 39 | layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP 40 | 41 | output._trt = layer.get_output(0) 42 | 43 | 44 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 6)]) 45 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 5, 7)]) 46 | def test_MaxPool2d_without_ceil_mode(): 47 | return torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False) 48 | 49 | 50 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 6)]) 51 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 5, 7)]) 52 | def test_MaxPool2d_with_ceil_mode(): 53 | return torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) -------------------------------------------------------------------------------- /torch2trt/converters/mean.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.mean') 6 | @tensorrt_converter('torch.Tensor.mean') 7 | def convert_mean(ctx): 8 | input = 
ctx.method_args[0] 9 | input_trt = trt_(ctx.network, input) 10 | output = ctx.method_return 11 | 12 | # get dims from args or kwargs 13 | if 'dim' in ctx.method_kwargs: 14 | dim = ctx.method_kwargs['dim'] 15 | else: # fall back to the positional arg, or reduce all non-batch dims as torch.mean(x) does 16 | dim = ctx.method_args[1] if len(ctx.method_args) >= 2 else tuple(range(1, input.ndim)) 17 | 18 | # convert list to tuple 19 | if isinstance(dim, list): 20 | dim = tuple(dim) 21 | 22 | if not isinstance(dim, tuple): 23 | dim = (dim, ) 24 | 25 | # create axes bitmask for reduce layer 26 | axes = 0 27 | for d in dim: 28 | axes |= 1 << (d - 1) # -1 to remove batch dimension 29 | 30 | # get whether to keep dimensions 31 | if 'keepdim' in ctx.method_kwargs: 32 | keep_dims = ctx.method_kwargs['keepdim'] 33 | elif len(ctx.method_args) == 3: 34 | keep_dims = ctx.method_args[2] 35 | else: 36 | keep_dims = False 37 | 38 | layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.AVG, axes, keep_dims) 39 | output._trt = layer.get_output(0) 40 | 41 | 42 | class Mean(torch.nn.Module): 43 | def __init__(self, dim, keepdim): 44 | super(Mean, self).__init__() 45 | self.dim = dim 46 | self.keepdim = keepdim 47 | def forward(self, x): 48 | return x.mean(self.dim, self.keepdim) 49 | 50 | 51 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 52 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 53 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 54 | def test_mean_channel(): 55 | return Mean(1, False) 56 | 57 | 58 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 59 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 60 | def test_mean_tuple(): 61 | return Mean((1, 2), False) 62 | 63 | 64 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 65 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 66 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 67 | def test_mean_keepdim(): 68 | return Mean(1, True) -------------------------------------------------------------------------------- /torch2trt/converters/min.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | from .unary import UnaryModule 4 | 5 | 6 | def __convert_min_elementwise(ctx): 7 | input_a = ctx.method_args[0] 8 | input_b = ctx.method_args[1] 9 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 10 | output = ctx.method_return 11 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.MIN) 12 | output._trt = layer.get_output(0) 13 | 14 | 15 | def __convert_min_reduce(ctx): 16 | input = ctx.method_args[0] 17 | dim = get_arg(ctx, 'dim', pos=1, default=tuple(range(1, input.ndim))) 18 | keepdim = get_arg(ctx, 'keepdim', pos=2, default=False) 19 | input_trt = trt_(ctx.network, input) 20 | output_val = ctx.method_return[0] 21 | output_idx = ctx.method_return[1] 22 | layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.MIN, torch_dim_to_trt_axes(dim), keepdim) 23 | output_val._trt = layer.get_output(0) 24 | 25 | 26 | @tensorrt_converter('torch.min') 27 | @tensorrt_converter('torch.Tensor.min') 28 | def convert_min(ctx): 29 | if len(ctx.method_args) > 1 and isinstance(ctx.method_args[1], torch.Tensor): 30 | __convert_min_elementwise(ctx) 31 | else: 32 | __convert_min_reduce(ctx) 33 | 34 | 35 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 36 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 37 | def 
test_min_reduce_dim1(): 38 | return UnaryModule(lambda x: torch.min(x, 1)[0]) 39 | 40 | 41 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 42 | def test_min_reduce_dim22(): 43 | return UnaryModule(lambda x: torch.min(x, 2)[0]) 44 | 45 | 46 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 47 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 48 | def test_min_reduce_dim1_keepdim(): 49 | return UnaryModule(lambda x: torch.min(x, 1, keepdim=True)[0]) 50 | 51 | 52 | class MinElementwise(torch.nn.Module): 53 | def forward(self, x, y): 54 | return torch.min(x, y) 55 | 56 | 57 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3), (1, 3, 3)]) 58 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3), (1,)]) # broadcast 59 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3), (1, 3, 3)]) # broadcast 60 | def test_min_elementwise(): 61 | return MinElementwise() -------------------------------------------------------------------------------- /torch2trt/converters/mul.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.mul') 6 | @tensorrt_converter('torch.Tensor.__imul__') 7 | @tensorrt_converter('torch.Tensor.__mul__') 8 | @tensorrt_converter('torch.Tensor.__rmul__') 9 | def convert_mul(ctx): 10 | input_a = ctx.method_args[0] 11 | input_b = ctx.method_args[1] 12 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 13 | output = ctx.method_return 14 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.PROD) 15 | output._trt = layer.get_output(0) 16 | 17 | 18 | class Mul(torch.nn.Module): 19 | def __init__(self): 20 | super(Mul, self).__init__() 21 | 22 | def forward(self, x, y): 23 | return x * y 24 | 25 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 26 | def test_mul_basic(): 27 | return Mul() 28 | 29 | 30 | class IMul(torch.nn.Module): 31 | def __init__(self): 32 | super(IMul, self).__init__() 33 | 34 | def forward(self, x, y): 35 | x *= y 36 | return x 37 | 38 | 39 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 40 | def test_mul_imul(): 41 | return IMul() 42 | 43 | 44 | class TorchMul(torch.nn.Module): 45 | def __init__(self): 46 | super(TorchMul, self).__init__() 47 | 48 | def forward(self, x, y): 49 | return torch.mul(x, y) 50 | 51 | 52 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 53 | def test_mul_torchmul(): 54 | return TorchMul() 55 | 56 | 57 | class RMulInt(torch.nn.Module): 58 | def __init__(self): 59 | super(RMulInt, self).__init__() 60 | 61 | def forward(self, x): 62 | return 10 * x 63 | 64 | 65 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 66 | def test_rmul_int(): 67 | return RMulInt() 68 | 69 | 70 | class RMulFloat(torch.nn.Module): 71 | def __init__(self): 72 | super(RMulFloat, self).__init__() 73 | 74 | def forward(self, x): 75 | return 10.0 * x 76 | 77 | 78 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 79 | def test_rmul_float(): 80 | return RMulFloat() -------------------------------------------------------------------------------- /torch2trt/converters/normalize.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 
| from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.functional.normalize') 6 | def convert_normalize(ctx): 7 | # get args 8 | input = get_arg(ctx, 'input', pos=0, default=None) 9 | p = get_arg(ctx, 'p', pos=1, default=2) 10 | dim = get_arg(ctx, 'dim', pos=2, default=1) 11 | eps = get_arg(ctx, 'eps', pos=3, default=1e-12) 12 | 13 | # input_trt = input._trt 14 | output = ctx.method_return 15 | 16 | # add broadcastable scalar constants to network 17 | input_trt, eps_trt, p_trt, p_inv_trt = trt_(ctx.network, input, eps, p, 1.0 / p) 18 | 19 | # compute norm = sum(abs(x)**p, dim=dim)**(1./p) 20 | norm = ctx.network.add_unary(input_trt, trt.UnaryOperation.ABS).get_output(0) 21 | norm = ctx.network.add_elementwise(norm, p_trt, trt.ElementWiseOperation.POW).get_output(0) 22 | norm = ctx.network.add_reduce(norm, trt.ReduceOperation.SUM, torch_dim_to_trt_axes(dim), keep_dims=True).get_output(0) 23 | norm = ctx.network.add_elementwise(norm, p_inv_trt, trt.ElementWiseOperation.POW).get_output(0) 24 | 25 | # clamp norm = max(norm, eps) 26 | norm = ctx.network.add_elementwise(norm, eps_trt, trt.ElementWiseOperation.MAX).get_output(0) 27 | 28 | # divide input by norm 29 | output._trt = ctx.network.add_elementwise(input_trt, norm, trt.ElementWiseOperation.DIV).get_output(0) 30 | 31 | 32 | class Normalize(torch.nn.Module): 33 | def __init__(self, *args, **kwargs): 34 | super(Normalize, self).__init__() 35 | self.args = args 36 | self.kwargs = kwargs 37 | 38 | def forward(self, x): 39 | return torch.nn.functional.normalize(x, *self.args, **self.kwargs) 40 | 41 | 42 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 43 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 44 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 45 | def test_normalize_basic(): 46 | return Normalize() 47 | 48 | 49 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 50 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 51 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 52 | def test_normalize_l1_basic(): 53 | return Normalize(p=1.0) 54 | 55 | 56 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 57 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 58 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 59 | def test_normalize_l1p5_basic(): 60 | return Normalize(p=1.5) 61 | 62 | 63 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 64 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 65 | def test_normalize_l2_height(): 66 | return Normalize(p=2.0, dim=2) -------------------------------------------------------------------------------- /torch2trt/converters/pad.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.functional.pad') 6 | def convert_pad(ctx): 7 | input = ctx.method_args[0] 8 | input_trt = trt_(ctx.network, input) 9 | output = ctx.method_return 10 | 11 | pad = ctx.method_args[1] 12 | pre_padding = (pad[2], pad[0]) 13 | post_padding = (pad[3], pad[1]) 14 | 15 | # mode / value are ignored since not supported by TensorRT 16 | 17 | layer = ctx.network.add_padding(input_trt, pre_padding, post_padding) 18 | output._trt = layer.get_output(0) 19 | 20 | 21 | class Pad(torch.nn.Module): 22 | 23 | def __init__(self, pad): 24 | 
super(Pad, self).__init__() 25 | self.pad = pad 26 | 27 | def forward(self, x): 28 | return torch.nn.functional.pad(x, self.pad) 29 | 30 | 31 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 32 | def test_pad_basic(): 33 | return Pad((1, 2, 3, 4)) -------------------------------------------------------------------------------- /torch2trt/converters/permute.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.Tensor.permute') 6 | def convert_permute(ctx): 7 | input = ctx.method_args[0] 8 | input_trt = trt_(ctx.network, input) 9 | output = ctx.method_return 10 | 11 | # permutation -1 because TRT does not include batch dim 12 | if isinstance(ctx.method_args[1], int): 13 | permutation = tuple(ctx.method_args[1:]) # handle permute(a, b, c) 14 | else: 15 | permutation = tuple(ctx.method_args[1]) # handle permute([a, b, c]) 16 | 17 | assert(permutation[0] == 0) # cannot move batch dim 18 | 19 | trt_permutation = tuple([p - 1 for p in permutation])[1:] 20 | 21 | layer = ctx.network.add_shuffle(input_trt) 22 | layer.second_transpose = tuple(trt_permutation) 23 | 24 | output._trt = layer.get_output(0) 25 | 26 | 27 | class Permute(torch.nn.Module): 28 | def __init__(self, *args): 29 | super(Permute, self).__init__() 30 | self.args = args 31 | def forward(self, x): 32 | return x.permute(*self.args).contiguous() 33 | 34 | 35 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)]) 36 | def test_permute_2d_0123(): 37 | return Permute(0, 1, 2, 3) 38 | 39 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)]) 40 | def test_permute_2d_0312(): 41 | return Permute(0, 3, 1, 2) 42 | 43 | 44 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5, 6)]) 45 | def test_permute_3d_01234(): 46 | return Permute(0, 1, 2, 3, 4) 47 | 48 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5, 6)]) 49 | def test_permute_3d_04132(): 50 | return Permute(0, 4, 1, 3, 2) 51 | 52 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5, 6)]) 53 | def test_permute_list(): 54 | return Permute([0, 4, 1, 3, 2]) 55 | 56 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5, 6)]) 57 | def test_permute_tuple(): 58 | return Permute((0, 4, 1, 3, 2)) -------------------------------------------------------------------------------- /torch2trt/converters/pow.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.pow') 6 | @tensorrt_converter('torch.Tensor.__ipow__') 7 | @tensorrt_converter('torch.Tensor.__pow__') 8 | def convert_pow(ctx): 9 | input_a = ctx.method_args[0] 10 | input_b = ctx.method_args[1] 11 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 12 | output = ctx.method_return 13 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.POW) 14 | output._trt = layer.get_output(0) 15 | 16 | 17 | @tensorrt_converter('torch.Tensor.__rpow__') 18 | def convert_pow(ctx): 19 | input_a = ctx.method_args[1] 20 | input_b = ctx.method_args[0] # flipped for rpow 21 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 22 | output = ctx.method_return 23 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.POW) 24 | 
output._trt = layer.get_output(0) 25 | 26 | 27 | class Pow(torch.nn.Module): 28 | def __init__(self): 29 | super(Pow, self).__init__() 30 | 31 | def forward(self, x, y): 32 | return x ** y 33 | 34 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 35 | def test_pow_basic(): 36 | return Pow() 37 | 38 | 39 | # __ipow__ not yet impl in torch 40 | # class IPow(torch.nn.Module): 41 | # def __init__(self): 42 | # super(IPow, self).__init__() 43 | 44 | # def forward(self, x, y): 45 | # x **= y 46 | # return x 47 | 48 | 49 | # @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 50 | # def test_pow_ipow(): 51 | # return IPow() 52 | 53 | 54 | class TorchPow(torch.nn.Module): 55 | def __init__(self): 56 | super(TorchPow, self).__init__() 57 | 58 | def forward(self, x, y): 59 | return torch.pow(x, y) 60 | 61 | 62 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 63 | def test_torch_pow(): 64 | return TorchPow() 65 | 66 | 67 | class RpowInt(torch.nn.Module): 68 | def __init__(self): 69 | super(RpowInt, self).__init__() 70 | 71 | def forward(self, x): 72 | return 2 ** x 73 | 74 | 75 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 76 | def test_rpow_int(): 77 | return RpowInt() 78 | 79 | 80 | class RpowFloat(torch.nn.Module): 81 | def __init__(self): 82 | super(RpowFloat, self).__init__() 83 | 84 | def forward(self, x): 85 | return 2.0 ** x 86 | 87 | 88 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 89 | def test_rpow_float(): 90 | return RpowFloat() -------------------------------------------------------------------------------- /torch2trt/converters/prelu.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.functional.prelu') 6 | def convert_prelu(ctx): 7 | input = get_arg(ctx, 'input', pos=0, default=None) 8 | weight = get_arg(ctx, 'weight', pos=1, default=None) 9 | output = ctx.method_return 10 | 11 | weight_shape = [1] * (len(input.shape) - 1) 12 | weight_shape[0] = weight.numel() 13 | 14 | input_trt = trt_(ctx.network, input) 15 | 16 | 17 | # y = prelu(x) = relu(x) - alpha * relu(-x) 18 | weight_trt = ctx.network.add_constant(weight_shape, -weight.detach().view(weight_shape).cpu().numpy()).get_output(0) # detach so considered leaf 19 | 20 | # x >= 0 21 | a = ctx.network.add_activation(input_trt, trt.ActivationType.RELU).get_output(0) 22 | 23 | # x <= 0 24 | b = ctx.network.add_unary(input_trt, trt.UnaryOperation.NEG).get_output(0) 25 | b = ctx.network.add_activation(b, trt.ActivationType.RELU).get_output(0) 26 | b = ctx.network.add_elementwise(b, weight_trt, trt.ElementWiseOperation.PROD).get_output(0) 27 | 28 | # y = a + b 29 | y = ctx.network.add_elementwise(a, b, trt.ElementWiseOperation.SUM) 30 | 31 | output._trt = y.get_output(0) 32 | 33 | 34 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5)]) 35 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 36 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3, 3)]) 37 | def test_prelu_scalar(): 38 | return torch.nn.PReLU() 39 | 40 | 41 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5)]) 42 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 43 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3, 3)]) 44 | def 
test_prelu_vector(): 45 | m = torch.nn.PReLU(5) 46 | m.weight = torch.nn.Parameter(torch.randn(5)) # randn so each channel different 47 | return m -------------------------------------------------------------------------------- /torch2trt/converters/prod.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | from .unary import UnaryModule 4 | 5 | 6 | @tensorrt_converter('torch.prod') 7 | @tensorrt_converter('torch.Tensor.prod') 8 | def convert_prod(ctx): 9 | input = ctx.method_args[0] 10 | dim = get_arg(ctx, 'dim', pos=1, default=tuple(range(1, input.ndim))) 11 | keepdim = get_arg(ctx, 'keepdim', pos=2, default=False) 12 | input_trt= trt_(ctx.network, input) 13 | output = ctx.method_return 14 | layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.PROD, torch_dim_to_trt_axes(dim), keepdim) 15 | output._trt = layer.get_output(0) 16 | 17 | 18 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 19 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 20 | def test_prod_reduce_all(): 21 | return UnaryModule(lambda x: torch.prod(x)) 22 | 23 | 24 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 25 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 26 | def test_prod_reduce_dim1(): 27 | return UnaryModule(lambda x: torch.prod(x, 1)) 28 | 29 | 30 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 31 | def test_prod_reduce_dim22(): 32 | return UnaryModule(lambda x: torch.prod(x, 2)) 33 | 34 | 35 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 36 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 37 | def test_prod_reduce_dim1_keepdim(): 38 | return UnaryModule(lambda x: torch.prod(x, 1, keepdim=True)) -------------------------------------------------------------------------------- /torch2trt/converters/sigmoid.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.functional.sigmoid') 6 | @tensorrt_converter('torch.sigmoid') 7 | def convert_sigmoid(ctx): 8 | input = ctx.method_args[0] 9 | input_trt = trt_(ctx.network, input) 10 | output = ctx.method_return 11 | 12 | layer = ctx.network.add_activation(input_trt, trt.ActivationType.SIGMOID) 13 | output._trt = layer.get_output(0) 14 | 15 | 16 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 17 | def test_sigmoid_basic(): 18 | return torch.nn.Sigmoid() -------------------------------------------------------------------------------- /torch2trt/converters/softmax.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.functional.softmax') 6 | def convert_softmax(ctx): 7 | input = ctx.method_args[0] 8 | input_trt = trt_(ctx.network, input) 9 | output = ctx.method_return 10 | 11 | # get dims from args or kwargs 12 | if 'dim' in ctx.method_kwargs: 13 | dim = ctx.method_kwargs['dim'] 14 | elif len(ctx.method_args) >= 2: 15 | dim = ctx.method_args[1] 16 | 17 | axes = 1 << (dim - 1) 18 | 19 | layer = ctx.network.add_softmax(input=input_trt) 20 | layer.axes = axes 21 | 22 | output._trt = layer.get_output(0) 23 | 24 | 25 | @add_module_test(torch.float32, 
torch.device('cuda'), [(1, 3)]) 26 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 27 | def test_softmax_module(): 28 | return torch.nn.Softmax(1) 29 | 30 | 31 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 32 | def test_softmax_module_dim2(): 33 | return torch.nn.Softmax(2) 34 | -------------------------------------------------------------------------------- /torch2trt/converters/split.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.split') 6 | @tensorrt_converter('torch.Tensor.split') 7 | def convert_split(ctx): 8 | input = get_arg(ctx, 'input', 0, None) 9 | input_trt = trt_(ctx.network, input) 10 | # we don't need to parse split/chunk (arg 1) 11 | # since we infer size from output tensors 12 | dim = get_arg(ctx, 'dim', 2, 0) 13 | 14 | outputs = ctx.method_return 15 | 16 | assert(dim >= 1) 17 | 18 | start = [0] * len(input.shape[1:]) # exclude batch 19 | stride = [1] * len(start) 20 | offset = 0 21 | trt_dim = dim - 1 22 | 23 | # add slice layers 24 | for i, output in enumerate(outputs): 25 | shape = list(output.shape[1:]) # exclude batch dim 26 | start[trt_dim] = offset 27 | layer = ctx.network.add_slice(input_trt, start=start, shape=shape, stride=stride) 28 | output._trt = layer.get_output(0) 29 | offset = offset + shape[trt_dim] 30 | 31 | 32 | class TorchSplit(torch.nn.Module): 33 | 34 | def __init__(self, *args, **kwargs): 35 | super(TorchSplit, self).__init__() 36 | self.args = args 37 | self.kwargs = kwargs 38 | 39 | def forward(self, x): 40 | return torch.split(x, *self.args, **self.kwargs) 41 | 42 | 43 | class TensorSplit(torch.nn.Module): 44 | 45 | def __init__(self, *args, **kwargs): 46 | super(TensorSplit, self).__init__() 47 | self.args = args 48 | self.kwargs = kwargs 49 | 50 | def forward(self, x): 51 | return x.split(*self.args, **self.kwargs) 52 | 53 | 54 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 55 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 56 | def test_torch_split_1_1(): 57 | return TorchSplit(1, 1) 58 | 59 | 60 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 61 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 62 | def test_torch_split_2_1(): 63 | return TorchSplit(2, 1) 64 | 65 | 66 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 67 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 68 | def test_torch_split_3_1(): 69 | return TorchSplit(3, 1) 70 | 71 | 72 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 73 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 74 | def test_torch_split_3_2(): 75 | return TorchSplit(3, 2) 76 | 77 | 78 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 79 | def test_tensor_split_3_2(): 80 | return TensorSplit(3, 2) -------------------------------------------------------------------------------- /torch2trt/converters/sub.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.sub') 6 | @tensorrt_converter('torch.Tensor.__isub__') 7 | @tensorrt_converter('torch.Tensor.__sub__') 8 | def convert_sub(ctx): 9 | input_a = ctx.method_args[0] 10 | input_b = ctx.method_args[1] 11 
| input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 12 | output = ctx.method_return 13 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.SUB) 14 | output._trt = layer.get_output(0) 15 | 16 | 17 | @tensorrt_converter('torch.Tensor.__rsub__') 18 | def convert_sub(ctx): 19 | input_a = ctx.method_args[1] 20 | input_b = ctx.method_args[0] # flipped for rsub 21 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 22 | output = ctx.method_return 23 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.SUB) 24 | output._trt = layer.get_output(0) 25 | 26 | 27 | class Sub(torch.nn.Module): 28 | def __init__(self): 29 | super(Sub, self).__init__() 30 | 31 | def forward(self, x, y): 32 | return x - y 33 | 34 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 35 | def test_sub_basic(): 36 | return Sub() 37 | 38 | 39 | class ISub(torch.nn.Module): 40 | def __init__(self): 41 | super(ISub, self).__init__() 42 | 43 | def forward(self, x, y): 44 | x -= y 45 | return x 46 | 47 | 48 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 49 | def test_sub_isub(): 50 | return ISub() 51 | 52 | 53 | class TorchSub(torch.nn.Module): 54 | def __init__(self): 55 | super(TorchSub, self).__init__() 56 | 57 | def forward(self, x, y): 58 | return torch.sub(x, y) 59 | 60 | 61 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) 62 | def test_torch_sub(): 63 | return TorchSub() 64 | 65 | 66 | class RSubInt(torch.nn.Module): 67 | def __init__(self): 68 | super(RSubInt, self).__init__() 69 | 70 | def forward(self, x): 71 | return 1 - x 72 | 73 | 74 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 75 | def test_rsub_int(): 76 | return RSubInt() 77 | 78 | 79 | class RSubFloat(torch.nn.Module): 80 | def __init__(self): 81 | super(RSubFloat, self).__init__() 82 | 83 | def forward(self, x): 84 | return 1.0 - x 85 | 86 | 87 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 88 | def test_rsub_float(): 89 | return RSubFloat() -------------------------------------------------------------------------------- /torch2trt/converters/sum.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | from .unary import UnaryModule 4 | 5 | 6 | @tensorrt_converter('torch.sum') 7 | @tensorrt_converter('torch.Tensor.sum') 8 | def convert_sum(ctx): 9 | input = ctx.method_args[0] 10 | dim = get_arg(ctx, 'dim', pos=1, default=tuple(range(1, input.ndim))) 11 | keepdim = get_arg(ctx, 'keepdim', pos=2, default=False) 12 | input_trt= trt_(ctx.network, input) 13 | output = ctx.method_return 14 | layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.SUM, torch_dim_to_trt_axes(dim), keepdim) 15 | output._trt = layer.get_output(0) 16 | 17 | 18 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 19 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 20 | def test_sum_reduce_all(): 21 | return UnaryModule(lambda x: torch.sum(x)) 22 | 23 | 24 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 25 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 26 | def test_sum_reduce_dim1(): 27 | return UnaryModule(lambda x: torch.sum(x, 1)) 28 | 29 | 30 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 
3, 3)]) 31 | def test_sum_reduce_dim22(): 32 | return UnaryModule(lambda x: torch.sum(x, 2)) 33 | 34 | 35 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 36 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 37 | def test_sum_reduce_dim1_keepdim(): 38 | return UnaryModule(lambda x: torch.sum(x, 1, keepdim=True)) -------------------------------------------------------------------------------- /torch2trt/converters/tanh.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.nn.functional.tanh') 6 | @tensorrt_converter('torch.tanh') 7 | def convert_tanh(ctx): 8 | input = ctx.method_args[0] 9 | input_trt = trt_(ctx.network, input) 10 | output = ctx.method_return 11 | 12 | layer = ctx.network.add_activation(input_trt, trt.ActivationType.TANH) 13 | output._trt = layer.get_output(0) 14 | 15 | 16 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 17 | def test_tanh_basic(): 18 | return torch.nn.Tanh() -------------------------------------------------------------------------------- /torch2trt/converters/transpose.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | @tensorrt_converter('torch.transpose') 6 | def convert_transpose(ctx): 7 | input = ctx.method_args[0] 8 | input_trt = trt_(ctx.network, input) 9 | output = ctx.method_return 10 | # permutation -1 because TRT does not include batch dim 11 | permutation = list(range(len(input.shape) - 1)) 12 | dim0 = ctx.method_args[1] - 1 13 | dim1 = ctx.method_args[2] - 1 14 | permutation[dim0] = dim1 15 | permutation[dim1] = dim0 16 | layer = ctx.network.add_shuffle(input_trt) 17 | layer.second_transpose = tuple(permutation) 18 | output._trt = layer.get_output(0) 19 | 20 | 21 | class Transpose(torch.nn.Module): 22 | def __init__(self, dim0, dim1): 23 | super(Transpose, self).__init__() 24 | self.dim0 = dim0 25 | self.dim1 = dim1 26 | def forward(self, x): 27 | return torch.transpose(x, self.dim0, self.dim1).contiguous() 28 | 29 | 30 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 31 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 32 | def test_transpose_12(): 33 | return Transpose(1, 2) 34 | -------------------------------------------------------------------------------- /torch2trt/converters/unary.py: -------------------------------------------------------------------------------- 1 | from torch2trt.torch2trt import * 2 | from torch2trt.module_test import add_module_test 3 | 4 | 5 | def __convert_unary(ctx, op): 6 | input = get_arg(ctx, 'input', pos=0, default=None) 7 | input_trt = trt_(ctx.network, input) 8 | output = ctx.method_return 9 | layer = ctx.network.add_unary(input_trt, op) 10 | output._trt = layer.get_output(0) 11 | 12 | 13 | class UnaryModule(torch.nn.Module): 14 | def __init__(self, fn): 15 | super(UnaryModule, self).__init__() 16 | self.fn = fn 17 | 18 | def forward(self, x): 19 | return self.fn(x) 20 | 21 | # EXP : Exponentiation 22 | 23 | 24 | @tensorrt_converter('torch.exp') 25 | @tensorrt_converter('torch.exp_') 26 | @tensorrt_converter('torch.Tensor.exp') 27 | @tensorrt_converter('torch.Tensor.exp_') 28 | def convert_exp(ctx): 29 | __convert_unary(ctx, trt.UnaryOperation.EXP) 30 | 31 | 32 | @add_module_test(torch.float32, 
--------------------------------------------------------------------------------
/torch2trt/converters/unary.py:
--------------------------------------------------------------------------------

from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


def __convert_unary(ctx, op):
    input = get_arg(ctx, 'input', pos=0, default=None)
    input_trt = trt_(ctx.network, input)
    output = ctx.method_return
    layer = ctx.network.add_unary(input_trt, op)
    output._trt = layer.get_output(0)


class UnaryModule(torch.nn.Module):
    def __init__(self, fn):
        super(UnaryModule, self).__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


# EXP : Exponentiation


@tensorrt_converter('torch.exp')
@tensorrt_converter('torch.exp_')
@tensorrt_converter('torch.Tensor.exp')
@tensorrt_converter('torch.Tensor.exp_')
def convert_exp(ctx):
    __convert_unary(ctx, trt.UnaryOperation.EXP)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_exp():
    return UnaryModule(lambda x: torch.exp(x))


# LOG : Log (base e)


@tensorrt_converter('torch.log')
@tensorrt_converter('torch.log_')
@tensorrt_converter('torch.Tensor.log')
@tensorrt_converter('torch.Tensor.log_')
def convert_log(ctx):
    __convert_unary(ctx, trt.UnaryOperation.LOG)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_log():
    return UnaryModule(lambda x: torch.log(x))


# SQRT : Square root


@tensorrt_converter('torch.sqrt')
@tensorrt_converter('torch.sqrt_')
@tensorrt_converter('torch.Tensor.sqrt')
@tensorrt_converter('torch.Tensor.sqrt_')
def convert_sqrt(ctx):
    __convert_unary(ctx, trt.UnaryOperation.SQRT)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_sqrt():
    return UnaryModule(lambda x: torch.sqrt(x))


# RECIP : Reciprocal


@tensorrt_converter('torch.reciprocal')
@tensorrt_converter('torch.reciprocal_')
@tensorrt_converter('torch.Tensor.reciprocal')
@tensorrt_converter('torch.Tensor.reciprocal_')
def convert_reciprocal(ctx):
    __convert_unary(ctx, trt.UnaryOperation.RECIP)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_reciprocal():
    return UnaryModule(lambda x: torch.reciprocal(x))


# ABS : Absolute value


@tensorrt_converter('torch.abs')
@tensorrt_converter('torch.abs_')
@tensorrt_converter('torch.Tensor.abs')
@tensorrt_converter('torch.Tensor.abs_')
def convert_abs(ctx):
    __convert_unary(ctx, trt.UnaryOperation.ABS)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_abs():
    return UnaryModule(lambda x: torch.abs(x))


# NEG : Negation


@tensorrt_converter('torch.neg')
@tensorrt_converter('torch.neg_')
@tensorrt_converter('torch.Tensor.neg')
@tensorrt_converter('torch.Tensor.neg_')
def convert_neg(ctx):
    __convert_unary(ctx, trt.UnaryOperation.NEG)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_neg():
    return UnaryModule(lambda x: torch.neg(x))


# SIN : Sine


@tensorrt_converter('torch.sin')
@tensorrt_converter('torch.sin_')
@tensorrt_converter('torch.Tensor.sin')
@tensorrt_converter('torch.Tensor.sin_')
def convert_sin(ctx):
    __convert_unary(ctx, trt.UnaryOperation.SIN)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_sin():
    return UnaryModule(lambda x: torch.sin(x))


# COS : Cosine


@tensorrt_converter('torch.cos')
@tensorrt_converter('torch.cos_')
@tensorrt_converter('torch.Tensor.cos')
@tensorrt_converter('torch.Tensor.cos_')
def convert_cos(ctx):
    __convert_unary(ctx, trt.UnaryOperation.COS)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_cos():
    return UnaryModule(lambda x: torch.cos(x))
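
# Extension sketch (illustrative comment, not in the original source): every
# section in this file follows the same recipe, so supporting another unary op
# only needs a converter that forwards the matching trt.UnaryOperation, e.g.:
#
#   @tensorrt_converter('torch.erf')
#   def convert_erf(ctx):
#       __convert_unary(ctx, trt.UnaryOperation.ERF)  # ERF requires a newer TensorRT
#
# torch.erf and trt.UnaryOperation.ERF are real names, but this converter is
# hypothetical and not registered by this package.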

# TAN : Tangent


@tensorrt_converter('torch.tan')
@tensorrt_converter('torch.tan_')
@tensorrt_converter('torch.Tensor.tan')
@tensorrt_converter('torch.Tensor.tan_')
def convert_tan(ctx):
    __convert_unary(ctx, trt.UnaryOperation.TAN)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_tan():
    return UnaryModule(lambda x: torch.tan(x))


# SINH : Hyperbolic sine


@tensorrt_converter('torch.sinh')
@tensorrt_converter('torch.sinh_')
@tensorrt_converter('torch.Tensor.sinh')
@tensorrt_converter('torch.Tensor.sinh_')
def convert_sinh(ctx):
    __convert_unary(ctx, trt.UnaryOperation.SINH)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_sinh():
    return UnaryModule(lambda x: torch.sinh(x))


# COSH : Hyperbolic cosine


@tensorrt_converter('torch.cosh')
@tensorrt_converter('torch.cosh_')
@tensorrt_converter('torch.Tensor.cosh')
@tensorrt_converter('torch.Tensor.cosh_')
def convert_cosh(ctx):
    __convert_unary(ctx, trt.UnaryOperation.COSH)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_cosh():
    return UnaryModule(lambda x: torch.cosh(x))


# ASIN : Inverse sine


@tensorrt_converter('torch.asin')
@tensorrt_converter('torch.asin_')
@tensorrt_converter('torch.Tensor.asin')
@tensorrt_converter('torch.Tensor.asin_')
def convert_asin(ctx):
    __convert_unary(ctx, trt.UnaryOperation.ASIN)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_asin():
    return UnaryModule(lambda x: torch.asin(x))


# ACOS : Inverse cosine


@tensorrt_converter('torch.acos')
@tensorrt_converter('torch.acos_')
@tensorrt_converter('torch.Tensor.acos')
@tensorrt_converter('torch.Tensor.acos_')
def convert_acos(ctx):
    __convert_unary(ctx, trt.UnaryOperation.ACOS)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_acos():
    return UnaryModule(lambda x: torch.acos(x))


# ATAN : Inverse tangent


@tensorrt_converter('torch.atan')
@tensorrt_converter('torch.atan_')
@tensorrt_converter('torch.Tensor.atan')
@tensorrt_converter('torch.Tensor.atan_')
def convert_atan(ctx):
    __convert_unary(ctx, trt.UnaryOperation.ATAN)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_atan():
    return UnaryModule(lambda x: torch.atan(x))


# ASINH : Inverse hyperbolic sine
# ACOSH : Inverse hyperbolic cosine
# ATANH : Inverse hyperbolic tangent
# (no converters registered for these)


# CEIL : Ceiling


@tensorrt_converter('torch.ceil')
@tensorrt_converter('torch.ceil_')
@tensorrt_converter('torch.Tensor.ceil')
@tensorrt_converter('torch.Tensor.ceil_')
def convert_ceil(ctx):
    __convert_unary(ctx, trt.UnaryOperation.CEIL)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_ceil():
    return UnaryModule(lambda x: torch.ceil(x))


# FLOOR : Floor


@tensorrt_converter('torch.floor')
@tensorrt_converter('torch.floor_')
@tensorrt_converter('torch.Tensor.floor')
@tensorrt_converter('torch.Tensor.floor_')
def convert_floor(ctx):
    __convert_unary(ctx, trt.UnaryOperation.FLOOR)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
def test_floor():
    return UnaryModule(lambda x: torch.floor(x))
--------------------------------------------------------------------------------
/torch2trt/converters/view.py:
--------------------------------------------------------------------------------

from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.flatten')
@tensorrt_converter('torch.Tensor.reshape')
@tensorrt_converter('torch.Tensor.view')
def convert_view(ctx):
    input = ctx.method_args[0]
    input_trt = trt_(ctx.network, input)
    output = ctx.method_return
    layer = ctx.network.add_shuffle(input_trt)
    layer.reshape_dims = tuple(output.shape[1:])
    output._trt = layer.get_output(0)


class View(torch.nn.Module):
    def __init__(self, *dims):
        super(View, self).__init__()
        self.dims = dims

    def forward(self, x):
        return x.view(*self.dims)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)])
def test_view_1d():
    return View(1, -1)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)])
def test_view_2d():
    return View(1, 1, -1)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)])
def test_view_3d():
    return View(1, 1, 1, -1)
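
# Worked example (illustrative comment, not in the original source):
# View(1, -1) applied to a (1, 3, 3) input returns shape (1, 9); the
# converter above reads that output shape and sets reshape_dims = (9,),
# again dropping the batch dimension from what TRT sees.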
--------------------------------------------------------------------------------
/torch2trt/init.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/torch2trt/libtorch2trt.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LittlePieces/ObjectDetection/55ee1c120818f2f37d5ece5d4623fe38deb547e7/torch2trt/libtorch2trt.so
--------------------------------------------------------------------------------
/torch2trt/module_test.py:
--------------------------------------------------------------------------------

import torch
import torchvision


class ModuleTest(object):
    def __init__(self, module_fn, dtype, device, input_shapes, **torch2trt_kwargs):
        self.module_fn = module_fn
        self.dtype = dtype
        self.device = device
        self.input_shapes = input_shapes
        self.torch2trt_kwargs = torch2trt_kwargs

    def module_name(self):
        return self.module_fn.__module__ + '.' + self.module_fn.__name__


MODULE_TESTS = [
]


def add_module_test(dtype, device, input_shapes, **torch2trt_kwargs):
    def register_module_test(module):
        global MODULE_TESTS
        MODULE_TESTS += [ModuleTest(module, dtype, device, input_shapes, **torch2trt_kwargs)]
        return module
    return register_module_test
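
# Usage sketch (illustrative comment, not in the original source): tests are
# registered by decorating a factory that returns the module under test, e.g.
#
#   @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
#   def test_my_relu():
#       return torch.nn.ReLU()
#
# The decorator appends a ModuleTest entry to MODULE_TESTS; test.py then
# iterates that registry, converts each module with the stored
# torch2trt_kwargs, and benchmarks it against plain PyTorch.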
--------------------------------------------------------------------------------
/torch2trt/test.py:
--------------------------------------------------------------------------------

from torch2trt import *
from .module_test import ModuleTest, MODULE_TESTS
import time
import argparse
import re
import runpy
from termcolor import colored


def run(test):
    # create module
    module = test.module_fn()
    module = module.to(test.device)
    module = module.type(test.dtype)
    module = module.eval()

    # create inputs for conversion
    inputs_conversion = ()
    for shape in test.input_shapes:
        inputs_conversion += (torch.zeros(shape).to(test.device).type(test.dtype), )

    # convert module
    module_trt = torch2trt(module, inputs_conversion, **test.torch2trt_kwargs)

    # create inputs for torch/trt.. copy of inputs to handle inplace ops
    inputs = ()
    for shape in test.input_shapes:
        inputs += (torch.randn(shape).to(test.device).type(test.dtype), )
    inputs_trt = tuple([tensor.clone() for tensor in inputs])

    # test output against original
    outputs = module(*inputs)
    outputs_trt = module_trt(*inputs_trt)

    if not isinstance(outputs, tuple):
        outputs = (outputs, )

    # compute max error
    max_error = 0
    for i in range(len(outputs)):
        max_error_i = torch.max(torch.abs(outputs[i] - outputs_trt[i]))
        if max_error_i > max_error:
            max_error = max_error_i

    # benchmark pytorch throughput
    torch.cuda.current_stream().synchronize()
    t0 = time.time()
    for i in range(50):
        outputs = module(*inputs)
    torch.cuda.current_stream().synchronize()
    t1 = time.time()

    fps = 50.0 / (t1 - t0)

    # benchmark tensorrt throughput
    torch.cuda.current_stream().synchronize()
    t0 = time.time()
    for i in range(50):
        outputs = module_trt(*inputs)
    torch.cuda.current_stream().synchronize()
    t1 = time.time()

    fps_trt = 50.0 / (t1 - t0)

    # benchmark pytorch latency
    torch.cuda.current_stream().synchronize()
    t0 = time.time()
    for i in range(50):
        outputs = module(*inputs)
    torch.cuda.current_stream().synchronize()
    t1 = time.time()

    ms = 1000.0 * (t1 - t0) / 50.0

    # benchmark tensorrt latency
    torch.cuda.current_stream().synchronize()
    t0 = time.time()
    for i in range(50):
        outputs = module_trt(*inputs)
    torch.cuda.current_stream().synchronize()
    t1 = time.time()

    ms_trt = 1000.0 * (t1 - t0) / 50.0

    return max_error, fps, fps_trt, ms, ms_trt


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--output', '-o', help='Test output file path', type=str, default='torch2trt_test.md')
    parser.add_argument('--name', help='Regular expression to filter modules to test by name', type=str, default='.*')
    parser.add_argument('--tolerance', help='Maximum error to print warning for entry', type=float, default=-1.0)
    parser.add_argument('--include', help='Additional Python file to include defining additional tests', action='append', default=[])
    args = parser.parse_args()

    for include in args.include:
        runpy.run_module(include)

    for test in MODULE_TESTS:

        # filter by module name
        name = test.module_name()
        if not re.search(args.name, name):
            continue

        # run test
        max_error, fps, fps_trt, ms, ms_trt = run(test)

        # write entry
        line = '| %s | %s | %s | %s | %.2E | %.3g | %.3g | %.3g | %.3g |' % (name, test.dtype.__repr__().split('.')[-1], str(test.input_shapes), str(test.torch2trt_kwargs), max_error, fps, fps_trt, ms, ms_trt)

        if args.tolerance >= 0 and max_error > args.tolerance:
            print(colored(line, 'yellow'))
        else:
            print(line)

        with open(args.output, 'a+') as f:
            f.write(line + '\n')

--------------------------------------------------------------------------------
/torch2trt/torch2trt.py:
--------------------------------------------------------------------------------

import torch
import tensorrt as trt
from copy import copy
import numpy as np
from .calibration import TensorBatchDataset, DatasetCalibrator, DEFAULT_CALIBRATION_ALGORITHM


# UTILITY FUNCTIONS


def torch_dtype_to_trt(dtype):
    if dtype == torch.int8:
        return trt.int8
    elif dtype == torch.int32:
        return trt.int32
    elif dtype == torch.float16:
        return trt.float16
    elif dtype == torch.float32:
        return trt.float32
    else:
        raise TypeError('%s is not supported by tensorrt' % dtype)


def torch_dtype_from_trt(dtype):
    if dtype == trt.int8:
        return torch.int8
    elif dtype == trt.int32:
        return torch.int32
    elif dtype == trt.float16:
        return torch.float16
    elif dtype == trt.float32:
        return torch.float32
    else:
        raise TypeError('%s is not supported by torch' % dtype)


def torch_device_to_trt(device):
    if device.type == torch.device('cuda').type:
        return trt.TensorLocation.DEVICE
    elif device.type == torch.device('cpu').type:
        return trt.TensorLocation.HOST
    else:
        raise TypeError('%s is not supported by tensorrt' % device)


def torch_device_from_trt(device):
    if device == trt.TensorLocation.DEVICE:
        return torch.device('cuda')
    elif device == trt.TensorLocation.HOST:
        return torch.device('cpu')
    else:
        raise TypeError('%s is not supported by torch' % device)


def trt_num_inputs(engine):
    count = 0
    for i in range(engine.num_bindings):
        if engine.binding_is_input(i):
            count += 1
    return count


def trt_num_outputs(engine):
    count = 0
    for i in range(engine.num_bindings):
        if not engine.binding_is_input(i):
            count += 1
    return count


def torch_dim_to_trt_axes(dim):
    """Converts a torch dim, or tuple of dims, to a TensorRT axes bitmask"""
    if not isinstance(dim, tuple):
        dim = (dim, )

    # create axes bitmask for reduce layer
    axes = 0
    for d in dim:
        axes |= 1 << (d - 1)  # -1 to remove batch dimension

    return axes
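
# Worked example (illustrative comment, not in the original source): for an
# NCHW tensor, reducing over torch dims (1, 2) maps to the TRT bitmask
#
#   torch_dim_to_trt_axes((1, 2)) == (1 << 0) | (1 << 1) == 3
#
# i.e. bit 0 selects the first non-batch axis (C) and bit 1 the second (H),
# since the network built here excludes the batch dimension.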

def add_trt_constant(network, tensor):
    shape = tuple(tensor.shape[1:])
    array = tensor[0].detach().cpu().numpy()
    layer = network.add_constant(shape, array)
    return layer.get_output(0)


def check_torch_dtype(*tensors):
    dtype = None
    for t in tensors:
        if isinstance(t, torch.Tensor):
            if dtype is None:
                dtype = t.dtype
            else:
                assert dtype == t.dtype, 'Tensor data types must match'
    assert dtype is not None, 'Data type could not be inferred from any item in list'
    return dtype


def trt_(network, *tensors):
    """Creates missing TensorRT tensors and adds shuffle layers to make tensors broadcastable"""
    trt_tensors = [None] * len(tensors)

    dtype = check_torch_dtype(*tensors)

    # get broadcast dimension
    broadcast_num_dim = 0
    for t in tensors:
        if isinstance(t, torch.Tensor):
            if not hasattr(t, '_trt'):
                num_dim = len(t.shape)  # don't exclude batch for constants
            else:
                num_dim = len(t._trt.shape)  # non-leaf tensors must already have _trt, get shape from that
            if num_dim > broadcast_num_dim:
                broadcast_num_dim = num_dim

    for i, t in enumerate(tensors):
        trt_tensor = None

        # GET TRT TENSOR (OR CREATE TRT CONSTANT)

        # get tensor w/ _trt
        if isinstance(t, torch.Tensor) and hasattr(t, '_trt'):
            trt_tensor = t._trt

        # or... add constant for leaf tensor w/o _trt
        elif isinstance(t, torch.Tensor) and not hasattr(t, '_trt'):
            # add leaf tensor
            shape = tuple(t.shape)  # don't exclude batch when adding constants...?
            weight = t.detach().cpu().numpy()
            t._trt = network.add_constant(shape, weight).get_output(0)
            trt_tensor = t._trt

        # or... add constant for scalar primitive
        elif isinstance(t, float) or isinstance(t, int):
            shape = (1,) * broadcast_num_dim
            scalar = t * torch.ones(shape, dtype=dtype).cpu().numpy()
            trt_tensor = network.add_constant(shape, scalar).get_output(0)

        assert trt_tensor is not None

        # MAKE TRT TENSOR BROADCASTABLE IF IT IS NOT ALREADY

        if len(trt_tensor.shape) < broadcast_num_dim:
            # prepend size-1 dims so the shapes broadcast
            diff = broadcast_num_dim - len(trt_tensor.shape)
            shape = tuple([1] * diff + list(trt_tensor.shape))
            layer = network.add_shuffle(trt_tensor)
            layer.reshape_dims = shape
            trt_tensor = layer.get_output(0)

        trt_tensors[i] = trt_tensor

    if len(trt_tensors) == 1:
        return trt_tensors[0]
    else:
        return tuple(trt_tensors)


# CONVERSION REGISTRY AND HOOKS


CONVERTERS = {}


def get_arg(ctx, name, pos, default):
    if name in ctx.method_kwargs:
        return ctx.method_kwargs[name]
    elif len(ctx.method_args) > pos:
        return ctx.method_args[pos]
    else:
        return default


def attach_converter(ctx, method, converter, method_str):
    """Returns a wrapper that executes the PyTorch method and its TensorRT converter"""

    def wrapper(*args, **kwargs):
        skip = True

        # check if another (parent) converter has the lock
        if not ctx.lock:
            if converter['is_real']:
                ctx.lock = True  # only real converters can acquire the lock
            skip = False

        # run original method
        outputs = method(*args, **kwargs)

        if not skip:
            ctx.method_args = args
            ctx.method_kwargs = kwargs
            ctx.method_return = outputs
            ctx.method_str = method_str

            converter['converter'](ctx)

            # convert to None so conversion will fail for unsupported layers
            ctx.method_args = None
            ctx.method_kwargs = None
            ctx.method_return =
None 208 | ctx.lock = False 209 | 210 | return outputs 211 | 212 | return wrapper 213 | 214 | 215 | class ConversionHook(object): 216 | """Attaches TensorRT converter to PyTorch method call""" 217 | 218 | def __init__(self, ctx, method, converter): 219 | self.ctx = ctx 220 | self.method_str = method 221 | self.converter = converter 222 | 223 | def _set_method(self, method): 224 | exec('%s = method' % self.method_str) 225 | 226 | def __enter__(self): 227 | try: 228 | self.method_impl = eval(self.method_str) 229 | except AttributeError: 230 | self.method_impl = None 231 | 232 | if self.method_impl: 233 | self._set_method(attach_converter(self.ctx, self.method_impl, self.converter, self.method_str)) 234 | 235 | def __exit__(self, type, val, tb): 236 | if self.method_impl: 237 | self._set_method(self.method_impl) 238 | 239 | 240 | class ConversionContext(object): 241 | def __init__(self, network, converters=CONVERTERS): 242 | self.network = network 243 | self.lock = False 244 | self.method_args = None 245 | self.method_kwargs = None 246 | self.method_return = None 247 | self.hooks = [ 248 | ConversionHook(self, method, converter) 249 | for method, converter in converters.items() 250 | ] 251 | 252 | def __enter__(self): 253 | for hook in self.hooks: 254 | hook.__enter__() 255 | return self 256 | 257 | def __exit__(self, type, val, tb): 258 | for hook in self.hooks: 259 | hook.__exit__(type, val, tb) 260 | 261 | def add_inputs(self, torch_inputs, names=None): 262 | if names is None: 263 | names = ['input_%d' % i for i in range(len(torch_inputs))] 264 | self.input_names = names 265 | 266 | for i, torch_input in enumerate(torch_inputs): 267 | if not hasattr(torch_input, '_trt'): 268 | trt_tensor = self.network.add_input( 269 | name=names[i], 270 | shape=tuple(torch_input.shape)[1:], 271 | dtype=torch_dtype_to_trt(torch_input.dtype), 272 | ) 273 | trt_tensor.location = torch_device_to_trt(torch_input.device) 274 | torch_input._trt = trt_tensor 275 | 276 | def mark_outputs(self, torch_outputs, names=None): 277 | if names is None: 278 | names = ['output_%d' % i for i in range(len(torch_outputs))] 279 | self.output_names = names 280 | 281 | for i, torch_output in enumerate(torch_outputs): 282 | trt_tensor = torch_output._trt 283 | trt_tensor.name = names[i] 284 | trt_tensor.location = torch_device_to_trt(torch_output.device) 285 | trt_tensor.dtype = torch_dtype_to_trt(torch_output.dtype) 286 | self.network.mark_output(trt_tensor) 287 | 288 | 289 | class TRTModule(torch.nn.Module): 290 | def __init__(self, engine=None, input_names=None, output_names=None): 291 | super(TRTModule, self).__init__() 292 | self._register_state_dict_hook(TRTModule._on_state_dict) 293 | self.engine = engine 294 | if self.engine is not None: 295 | self.context = self.engine.create_execution_context() 296 | self.input_names = input_names 297 | self.output_names = output_names 298 | 299 | def _on_state_dict(self, state_dict, prefix, local_metadata): 300 | state_dict[prefix + 'engine'] = bytearray(self.engine.serialize()) 301 | state_dict[prefix + 'input_names'] = self.input_names 302 | state_dict[prefix + 'output_names'] = self.output_names 303 | 304 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 305 | engine_bytes = state_dict[prefix + 'engine'] 306 | 307 | with trt.Logger() as logger, trt.Runtime(logger) as runtime: 308 | self.engine = runtime.deserialize_cuda_engine(engine_bytes) 309 | self.context = self.engine.create_execution_context() 310 | 311 | 
self.input_names = state_dict[prefix + 'input_names'] 312 | self.output_names = state_dict[prefix + 'output_names'] 313 | 314 | def forward(self, *inputs): 315 | batch_size = inputs[0].shape[0] 316 | bindings = [None] * (len(self.input_names) + len(self.output_names)) 317 | 318 | # create output tensors 319 | outputs = [None] * len(self.output_names) 320 | for i, output_name in enumerate(self.output_names): 321 | idx = self.engine.get_binding_index(output_name) 322 | dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx)) 323 | shape = (batch_size, ) + tuple(self.engine.get_binding_shape(idx)) 324 | device = torch_device_from_trt(self.engine.get_location(idx)) 325 | output = torch.empty(size=shape, dtype=dtype, device=device) 326 | outputs[i] = output 327 | bindings[idx] = output.data_ptr() 328 | 329 | for i, input_name in enumerate(self.input_names): 330 | idx = self.engine.get_binding_index(input_name) 331 | bindings[idx] = inputs[i].data_ptr() 332 | 333 | self.context.execute_async(batch_size, bindings, torch.cuda.current_stream().cuda_stream) 334 | 335 | outputs = tuple(outputs) 336 | if len(outputs) == 1: 337 | outputs = outputs[0] 338 | 339 | return outputs 340 | 341 | def enable_profiling(self): 342 | if not self.context.profiler: 343 | self.context.profiler = trt.Profiler() 344 | 345 | 346 | def torch2trt(module, 347 | inputs, 348 | input_names=None, 349 | output_names=None, 350 | log_level=trt.Logger.ERROR, 351 | max_batch_size=1, 352 | fp16_mode=False, 353 | max_workspace_size=0, 354 | strict_type_constraints=False, 355 | keep_network=True, 356 | int8_mode=False, 357 | int8_calib_dataset=None, 358 | int8_calib_algorithm=DEFAULT_CALIBRATION_ALGORITHM): 359 | 360 | inputs_in = inputs 361 | 362 | # copy inputs to avoid modifications to source data 363 | inputs = [tensor.clone()[0:1] for tensor in inputs] # only run single entry 364 | 365 | logger = trt.Logger(log_level) 366 | builder = trt.Builder(logger) 367 | network = builder.create_network() 368 | 369 | with ConversionContext(network) as ctx: 370 | 371 | if isinstance(inputs, list): 372 | inputs = tuple(inputs) 373 | if not isinstance(inputs, tuple): 374 | inputs = (inputs, ) 375 | ctx.add_inputs(inputs, input_names) 376 | 377 | outputs = module(*inputs) 378 | 379 | if not isinstance(outputs, tuple) and not isinstance(outputs, list): 380 | outputs = (outputs, ) 381 | ctx.mark_outputs(outputs, output_names) 382 | 383 | builder.max_workspace_size = max_workspace_size 384 | builder.fp16_mode = fp16_mode 385 | builder.max_batch_size = max_batch_size 386 | builder.strict_type_constraints = strict_type_constraints 387 | 388 | if int8_mode: 389 | 390 | # default to use input tensors for calibration 391 | if int8_calib_dataset is None: 392 | int8_calib_dataset = TensorBatchDataset(inputs_in) 393 | 394 | builder.int8_mode = True 395 | 396 | # @TODO(jwelsh): Should we set batch_size=max_batch_size? 
Need to investigate memory consumption 397 | builder.int8_calibrator = DatasetCalibrator(inputs, int8_calib_dataset, batch_size=1, algorithm=int8_calib_algorithm) 398 | 399 | engine = builder.build_cuda_engine(network) 400 | 401 | module_trt = TRTModule(engine, ctx.input_names, ctx.output_names) 402 | 403 | if keep_network: 404 | module_trt.network = network 405 | 406 | return module_trt 407 | 408 | 409 | # DEFINE ALL CONVERSION FUNCTIONS 410 | 411 | 412 | def tensorrt_converter(method, is_real=True): 413 | def register_converter(converter): 414 | CONVERTERS[method] = {'converter': converter, 'is_real': is_real} 415 | return converter 416 | return register_converter 417 | -------------------------------------------------------------------------------- /torch2trt/utils.py: -------------------------------------------------------------------------------- 1 | import graphviz 2 | 3 | 4 | def trt_network_to_dot_graph(network): 5 | dot = graphviz.Digraph(comment='Network') 6 | 7 | # add nodes (layers) 8 | for i in range(network.num_layers): 9 | layer = network.get_layer(i) 10 | dot.node(layer.name) 11 | 12 | # add nodes (inputs) 13 | for i in range(network.num_inputs): 14 | dot.node(network.get_input(i).name) 15 | 16 | # add nodes (outputs) 17 | for i in range(network.num_outputs): 18 | dot.node(network.get_output(i).name) 19 | 20 | # add layer->layer edges 21 | for a in range(network.num_layers): 22 | layer_a = network.get_layer(a) 23 | 24 | for b in range(network.num_layers): 25 | layer_b = network.get_layer(b) 26 | 27 | for i in range(layer_a.num_outputs): 28 | output_i = layer_a.get_output(i) 29 | 30 | for j in range(layer_b.num_inputs): 31 | input_j = layer_b.get_input(j) 32 | 33 | if output_i == input_j: 34 | dot.edge(layer_a.name, layer_b.name, label=str(input_j.shape)) 35 | 36 | # add input->layer edges 37 | for i in range(network.num_inputs): 38 | input_i = network.get_input(i) 39 | 40 | for b in range(network.num_layers): 41 | layer_b = network.get_layer(b) 42 | 43 | for j in range(layer_b.num_inputs): 44 | input_j = layer_b.get_input(j) 45 | 46 | if input_i == input_j: 47 | dot.edge(input_i.name, layer_b.name, label=str(input_j.shape)) 48 | 49 | # add layer->output edges 50 | for i in range(network.num_outputs): 51 | input_i = network.get_output(i) 52 | 53 | for b in range(network.num_layers): 54 | layer_b = network.get_layer(b) 55 | 56 | for j in range(layer_b.num_outputs): 57 | input_j = layer_b.get_output(j) 58 | 59 | if input_i == input_j: 60 | dot.edge(layer_b.name, input_i.name, label=str(input_j.shape)) 61 | 62 | return dot -------------------------------------------------------------------------------- /torchtrt2trt: -------------------------------------------------------------------------------- 1 | 2 | --------------------------------------------------------------------------------
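
Usage sketch (illustrative, not a file from this repository): the typical
end-to-end flow through the torch2trt API above, assuming CUDA, TensorRT and
torchvision are available; TRTModule is defined in torch2trt/torch2trt.py.

    import torch
    import torchvision
    from torch2trt import torch2trt
    from torch2trt.torch2trt import TRTModule

    model = torchvision.models.resnet18(pretrained=True).eval().cuda()
    x = torch.ones(1, 3, 224, 224).cuda()

    # trace the model once with example data and build a TensorRT engine
    model_trt = torch2trt(model, [x], fp16_mode=True)

    # engines round-trip through an ordinary state_dict (see TRTModule)
    torch.save(model_trt.state_dict(), 'resnet18_trt.pth')
    model_trt2 = TRTModule()
    model_trt2.load_state_dict(torch.load('resnet18_trt.pth'))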