"""Total_Variation_Loss: a PyTorch implementation of total variation (TV) loss.

TV loss is also implemented in TensorFlow; see
https://github.com/jxgu1016/Total_Variation_Loss.pytorch/blob/master/images/TVLoss_tf.jpeg
"""
import torch
import torch.nn as nn
from torch.autograd import Variable  # noqa: F401  (no longer used here; kept for existing importers)


class TVLoss(nn.Module):
    """Total variation loss over the last two dimensions of a 4-D tensor.

    For an input ``x`` the loss is::

        TVLoss_weight * 2 * (sum_h / count_h + sum_w / count_w) / x.size(0)

    where ``sum_h`` (``sum_w``) is the sum, over the whole batch, of squared
    differences between neighbouring elements along dim 2 (dim 3), and
    ``count_h`` (``count_w``) is the number of such differences in ONE sample.
    In other words: the per-sample mean squared neighbour difference, averaged
    over dim 0 (treated as the batch dimension), times 2. The constant factor 2
    is kept unchanged from the original implementation.

    Args:
        TVLoss_weight: scalar multiplier applied to the final loss.
    """

    def __init__(self, TVLoss_weight=1):
        super(TVLoss, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        """Return the weighted TV loss of ``x`` as a 0-dim tensor.

        ``x`` must be 4-D; dim 0 is treated as the batch dimension and dims 2
        and 3 as height and width (presumably NCHW images).
        """
        batch_size = x.size(0)
        # Differences between vertically / horizontally adjacent elements.
        h_diff = x[:, :, 1:, :] - x[:, :, :-1, :]
        w_diff = x[:, :, :, 1:] - x[:, :, :, :-1]
        count_h = self._tensor_size(h_diff)
        count_w = self._tensor_size(w_diff)
        h_tv = h_diff.pow(2).sum()
        w_tv = w_diff.pow(2).sum()
        # A spatial dim of size 1 has no neighbour pairs: the count is 0 and the
        # sum is an (empty) 0. Use that 0 directly instead of 0/0 = NaN, which
        # would also poison the gradient.
        h_term = h_tv / count_h if count_h > 0 else h_tv
        w_term = w_tv / count_w if count_w > 0 else w_tv
        return self.TVLoss_weight * 2 * (h_term + w_term) / batch_size

    def _tensor_size(self, t):
        """Number of elements in one sample of ``t`` (product of dims 1..3)."""
        return t.size(1) * t.size(2) * t.size(3)


def main():
    """Run TVLoss on a small hand-checkable input and print loss and gradient."""
    # Two identical 3x3 channels whose neighbours differ by exactly 1 along both
    # spatial axes, so each mean squared difference is 1 and the loss is
    # 2 * (1 + 1) / 1 = 4.
    channel = [[1., 2., 3.], [2., 3., 4.], [3., 4., 5.]]
    x = torch.tensor([[channel, channel]], requires_grad=True)  # shape (1, 2, 3, 3)
    criterion = TVLoss()
    loss = criterion(x)
    print(x)
    print(loss.detach())
    loss.backward()
    print(x.grad)


if __name__ == '__main__':
    main()