├── README.md
├── TISTA.py
└── TISTA_p_alpha.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Trainable ISTA (TISTA) for sparse signal recovery

This code is an implementation of Trainable ISTA (TISTA) for sparse signal recovery in PyTorch.
The details of the algorithm can be found in the paper:
Daisuke Ito, Satoshi Takabe, Tadashi Wadayama,
"Trainable ISTA for Sparse Signal Recovery", arXiv:1801.01978.
(The computer experiments in the paper were performed with a separate TensorFlow implementation.)
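For reference, one TISTA iteration, read directly off the forward pass in TISTA.py below (column-vector notation; see the paper for the derivation of the closed-form error-variance estimate tau_t^2):

    r_t     = s_t + gamma_t * W * (y - A * s_t)   # linear estimation step
    s_{t+1} = eta_MMSE(r_t; tau_t^2)              # MMSE shrinkage step

where W = A^T (A A^T)^{-1} is the pseudo-inverse of A and eta_MMSE is the MMSE estimator for the Bernoulli-Gaussian prior.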
--------------------------------------------------------------------------------
/TISTA.py:
--------------------------------------------------------------------------------
# Trainable ISTA (TISTA)
#
# This code is an implementation of Trainable ISTA (TISTA) for sparse signal recovery in PyTorch.
# The details of the algorithm can be found in the paper:
# Daisuke Ito, Satoshi Takabe, Tadashi Wadayama,
# "Trainable ISTA for Sparse Signal Recovery", arXiv:1801.01978.
# (The computer experiments in the paper were performed with a separate TensorFlow implementation.)
#
# A GPU is used by default. If you do not have a GPU,
# change "device = torch.device('cuda')" to torch.device('cpu').
#
# This basic TISTA trains only $\gamma_t$.
#
# Last update 11/21/2018

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import time

# device
device = torch.device('cuda')  # choose 'cpu' or 'cuda'

# global variables

N = 500   # length of a source signal vector
M = 250   # length of an observation vector
p = 0.1   # probability of occurrence of non-zero components

batch_size = 1000     # mini-batch size
num_batch = 200       # number of mini-batches in a generation
num_generations = 12  # number of generations
snr = 40.0            # SNR of the system in dB

alpha2 = 1.0          # variance of a non-zero component
alpha_std = math.sqrt(alpha2)
max_layers = 12       # maximum number of layers
adam_lr = 0.04        # initial learning rate for Adam

# random seed of torch
torch.manual_seed(5)

### setting of the sensing matrix
# sensing matrix with small variance
A = torch.normal(0.0, std=math.sqrt(1.0/M) * torch.ones(M, N))

# sensing matrix with large variance
#A = torch.normal(0.0, std=math.sqrt(1.0) * torch.ones(M, N))

# \pm 1 sensing matrix
#A = 1.0 - 2.0*torch.bernoulli(0.5 * torch.ones(M, N))
### end of setting of the sensing matrix

At = A.t()
W = At.mm((A.mm(At)).inverse())  # pseudo-inverse matrix of A
Wt = W.t()

taa = (At.mm(A)).trace().to(device)  # trace(A^T A)
tww = (W.mm(Wt)).trace().to(device)  # trace(W W^T)

Wt = Wt.to(device)
At = At.to(device)

print("sensing matrix A\n", A.numpy())


# detection of NaN (x != x holds exactly for NaN entries)
def isnan(x):
    return x != x


# mini-batch generator: Bernoulli-Gaussian signals with sparsity p
def generate_batch():
    support = torch.bernoulli(p * torch.ones(batch_size, N))
    nonzero = torch.normal(0.0, alpha_std * torch.ones(batch_size, N))
    return torch.mul(nonzero, support)


# definition of the TISTA network
class TISTA_NET(nn.Module):
    def __init__(self):
        super(TISTA_NET, self).__init__()
        self.gamma = nn.Parameter(torch.ones(max_layers))
        # alternative random initialization:
        #self.gamma = nn.Parameter(torch.normal(1.0, 0.1*torch.ones(max_layers)))
        print("TISTA initialized...")

    def gauss(self, x, var):  # zero-mean Gaussian pdf with variance var
        return torch.exp(-torch.mul(x, x)/(2.0*var))/pow(2.0*math.pi*var, 0.5)

    def MMSE_shrinkage(self, y, tau2):  # MMSE shrinkage function
        return (y*alpha2/(alpha2+tau2))*p*self.gauss(y, alpha2+tau2)/((1-p)*self.gauss(y, tau2) + p*self.gauss(y, alpha2+tau2))

    def eval_tau2(self, t, i):  # error variance estimator
        v2 = (t.norm(2, 1).pow(2.0) - M*sigma2)/taa
        v2 = v2.clamp(min=1e-9)  # clamp() is not in-place; the result must be reassigned
        tau2 = (v2/N)*(N+(self.gamma[i]*self.gamma[i]-2.0*self.gamma[i])*M)+self.gamma[i]*self.gamma[i]*tww*sigma2/N
        tau2 = (tau2.expand(N, batch_size)).t()
        return tau2

    def forward(self, x, s, max_itr):  # TISTA network
        y = x.mm(At) + torch.normal(0.0, sigma_std*torch.ones(batch_size, M)).to(device)
        for i in range(max_itr):
            t = y - s.mm(At)                  # residual
            tau2 = self.eval_tau2(t, i)       # estimated error variance
            r = s + t.mm(Wt)*self.gamma[i]    # linear estimation step
            s = self.MMSE_shrinkage(r, tau2)  # shrinkage step
        return s


network = TISTA_NET().to(device)  # generate an instance of the TISTA network
s_zero = torch.zeros(batch_size, N).to(device)  # initial value
opt = optim.Adam(network.parameters(), lr=adam_lr)  # optimizer (Adam)

# noise variance calibration from the target SNR
signal_power = 0.0
for i in range(100):
    x = generate_batch().to(device)
    y = x.mm(At)
    signal_power += (y.norm(2, 1).pow(2.0)).sum().item()
ave = signal_power/(100.0 * batch_size)
sigma2 = ave/(M*math.pow(10.0, snr/10.0))
sigma_std = math.sqrt(sigma2)
xi = alpha2 + sigma2


# incremental training loop
start = time.time()

for gen in range(num_generations):
    # training process
    if gen > 10:  # lower the learning rate of Adam for the last generation
        opt = optim.Adam(network.parameters(), lr=adam_lr/50.0)
    for i in range(num_batch):
        x = generate_batch().to(device)
        opt.zero_grad()
        x_hat = network(x, s_zero, gen+1)
        loss = F.mse_loss(x_hat, x)
        loss.backward()

        grads = torch.stack([param.grad for param in network.parameters()])
        if isnan(grads).any():  # skip the update if any gradient is NaN
            continue

        opt.step()
    # end of training process

    # accuracy check after the (gen+1)-th incremental training
    nmse_sum = 0.0
    tot = 1  # number of mini-batches for the accuracy check
    for i in range(tot):
        x = generate_batch().to(device)
        x_hat = network(x, s_zero, gen+1)
        num = (x - x_hat).norm(2, 1).pow(2.0)
        denom = x.norm(2, 1).pow(2.0)
        nmse = num/denom
        nmse_sum += torch.sum(nmse).item()

    nmse = 10.0*math.log10(nmse_sum/(tot * batch_size))  # NMSE [dB]

    print('({0}) NMSE= {1:6.3f}'.format(gen + 1, nmse))
    # end of accuracy check

elapsed_time = time.time() - start
print("elapsed_time:{0}".format(elapsed_time) + "[sec]")
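Before the second script, a quick standalone sanity check (a sketch, not part of the original repository) that W as constructed above is a right pseudo-inverse of A, i.e. A W = I_M, which is what the linear estimation step relies on:

    import math
    import torch

    M, N = 250, 500
    A = torch.normal(0.0, std=math.sqrt(1.0/M) * torch.ones(M, N))
    W = A.t().mm((A.mm(A.t())).inverse())  # right pseudo-inverse of A
    print(torch.allclose(A.mm(W), torch.eye(M), atol=1e-4))  # expected: True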
--------------------------------------------------------------------------------
/TISTA_p_alpha.py:
--------------------------------------------------------------------------------
# Trainable ISTA (TISTA)
#
# This code is an implementation of Trainable ISTA (TISTA) for sparse signal recovery in PyTorch.
# The details of the algorithm can be found in the paper:
# Daisuke Ito, Satoshi Takabe, Tadashi Wadayama,
# "Trainable ISTA for Sparse Signal Recovery", arXiv:1801.01978.
# (The computer experiments in the paper were performed with a separate TensorFlow implementation.)
#
# A GPU is used by default. If you do not have a GPU,
# change "device = torch.device('cuda')" to torch.device('cpu').
#
# This variant of TISTA trains the parameters $\alpha^2$ and $p$ of the
# Bernoulli-Gaussian prior distribution in addition to $\gamma_t$.
#
# Last update 11/21/2018

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import time

# device
device = torch.device('cuda')  # choose 'cpu' or 'cuda'

# global variables

N = 500   # length of a source signal vector
M = 250   # length of an observation vector
p = 0.1   # probability of occurrence of non-zero components

batch_size = 1000     # mini-batch size
num_batch = 200       # number of mini-batches in a generation
num_generations = 12  # number of generations
snr = 40.0            # SNR of the system in dB

alpha2 = 1.0          # variance of a non-zero component
alpha_std = math.sqrt(alpha2)
max_layers = 12       # maximum number of layers
adam_lr = 0.04        # initial learning rate for Adam

# random seed of torch
torch.manual_seed(5)

### setting of the sensing matrix
# sensing matrix with small variance
A = torch.normal(0.0, std=math.sqrt(1.0/M) * torch.ones(M, N))

# sensing matrix with large variance
#A = torch.normal(0.0, std=math.sqrt(1.0) * torch.ones(M, N))

# \pm 1 sensing matrix
#A = 1.0 - 2.0*torch.bernoulli(0.5 * torch.ones(M, N))
### end of setting of the sensing matrix

At = A.t()
W = At.mm((A.mm(At)).inverse())  # pseudo-inverse matrix of A
Wt = W.t()

taa = (At.mm(A)).trace().to(device)  # trace(A^T A)
tww = (W.mm(Wt)).trace().to(device)  # trace(W W^T)

Wt = Wt.to(device)
At = At.to(device)

print("sensing matrix A\n", A.numpy())


# detection of NaN (x != x holds exactly for NaN entries)
def isnan(x):
    return x != x


# mini-batch generator: Bernoulli-Gaussian signals with sparsity p
def generate_batch():
    support = torch.bernoulli(p * torch.ones(batch_size, N))
    nonzero = torch.normal(0.0, alpha_std * torch.ones(batch_size, N))
    return torch.mul(nonzero, support)


# definition of the TISTA network
class TISTA_NET(nn.Module):
    def __init__(self):
        super(TISTA_NET, self).__init__()
        self.gamma = nn.Parameter(torch.ones(max_layers))
        self.p = nn.Parameter(p*torch.ones(max_layers))
        self.alpha2 = nn.Parameter(alpha2*torch.ones(max_layers))
        print("TISTA initialized...")

    def gauss(self, x, var):  # zero-mean Gaussian pdf with variance var
        return torch.exp(-torch.mul(x, x)/(2.0*var))/pow(2.0*math.pi*var, 0.5)

    def MMSE_shrinkage(self, y, tau2):  # MMSE shrinkage function
        # note: only index 0 is used, i.e., the trained p and alpha2 are shared by all layers
        return (y*self.alpha2[0]/(self.alpha2[0]+tau2))*self.p[0]*self.gauss(y, self.alpha2[0]+tau2)/((1-self.p[0])*self.gauss(y, tau2) + self.p[0]*self.gauss(y, self.alpha2[0]+tau2))

    def eval_tau2(self, t, i):  # error variance estimator
        v2 = (t.norm(2, 1).pow(2.0) - M*sigma2)/taa
        v2 = v2.clamp(min=1e-9)  # clamp() is not in-place; the result must be reassigned
        tau2 = (v2/N)*(N+(self.gamma[i]*self.gamma[i]-2.0*self.gamma[i])*M)+self.gamma[i]*self.gamma[i]*tww*sigma2/N
        tau2 = (tau2.expand(N, batch_size)).t()
        return tau2

    def forward(self, x, s, max_itr):  # TISTA network
        y = x.mm(At) + torch.normal(0.0, sigma_std*torch.ones(batch_size, M)).to(device)
        for i in range(max_itr):
            t = y - s.mm(At)                  # residual
            tau2 = self.eval_tau2(t, i)       # estimated error variance
            r = s + t.mm(Wt)*self.gamma[i]    # linear estimation step
            s = self.MMSE_shrinkage(r, tau2)  # shrinkage step
        return s
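
# For reference, eval_tau2 above computes the error-variance estimate used by TISTA,
# rewritten here in plain notation (v2 is first clamped to stay positive):
#   v_t^2   = (||t||^2 - M*sigma^2) / trace(A^T A)
#   tau_t^2 = (v_t^2 / N) * (N + (gamma_t^2 - 2*gamma_t) * M)
#             + gamma_t^2 * trace(W W^T) * sigma^2 / N
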
network = TISTA_NET().to(device)  # generate an instance of the TISTA network
s_zero = torch.zeros(batch_size, N).to(device)  # initial value
opt = optim.Adam(network.parameters(), lr=adam_lr)  # optimizer (Adam)

# noise variance calibration from the target SNR
signal_power = 0.0
for i in range(100):
    x = generate_batch().to(device)
    y = x.mm(At)
    signal_power += (y.norm(2, 1).pow(2.0)).sum().item()
ave = signal_power/(100.0 * batch_size)
sigma2 = ave/(M*math.pow(10.0, snr/10.0))
sigma_std = math.sqrt(sigma2)
xi = alpha2 + sigma2


# incremental training loop
start = time.time()

for gen in range(num_generations):
    # training process
    if gen > 10:  # lower the learning rate of Adam for the last generation
        opt = optim.Adam(network.parameters(), lr=adam_lr/50.0)
    for i in range(num_batch):
        x = generate_batch().to(device)
        opt.zero_grad()
        x_hat = network(x, s_zero, gen+1)
        loss = F.mse_loss(x_hat, x)
        loss.backward()

        grads = torch.stack([param.grad for param in network.parameters()])
        if isnan(grads).any():  # skip the update if any gradient is NaN
            continue

        opt.step()
    # end of training process

    # accuracy check after the (gen+1)-th incremental training
    nmse_sum = 0.0
    tot = 1  # number of mini-batches for the accuracy check
    for i in range(tot):
        x = generate_batch().to(device)
        x_hat = network(x, s_zero, gen+1)
        num = (x - x_hat).norm(2, 1).pow(2.0)
        denom = x.norm(2, 1).pow(2.0)
        nmse = num/denom
        nmse_sum += torch.sum(nmse).item()

    nmse = 10.0*math.log10(nmse_sum/(tot * batch_size))  # NMSE [dB]

    print('({0}) NMSE= {1:6.3f}'.format(gen + 1, nmse))
    # end of accuracy check

elapsed_time = time.time() - start
print("elapsed_time:{0}".format(elapsed_time) + "[sec]")
--------------------------------------------------------------------------------
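Neither script saves the learned parameters after training; if that is desired, a one-line addition at the end of either script would do (a sketch; the filename is arbitrary):

    torch.save(network.state_dict(), 'tista_trained.pt')  # persists gamma (and p, alpha2 in the variant)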