├── README.md ├── __init__.py ├── analyzer.py ├── compare_fixed_point.py ├── dataset.py ├── linear_approximation.py ├── model.py ├── plot_trajectories.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # Fixed Point Analysis 2 | This is the implementation of fixed point analysis for Recurrent Neural Network by PyTorch. 3 | 4 | Sussillo, D., & Barak, O. (2013). [Opening the Black Box: Low-Dimensional Dynamics in High-Dimensional Recurrent Neural Networks.](https://doi.org/10.1162/NECO_a_00409) 5 | 6 | Niru Maheswaranathan. et al. (2019) [Universality and individuality in neural dynamics across large populations of recurrent networks.](https://papers.nips.cc/paper/9694-universality-and-individuality-in-neural-dynamics-across-large-populations-of-recurrent-networks) 7 | 8 | This repository contains the code for the analysis on the canonical task **Frequency-cued sine wave**, which is studied on these papers. 9 | 10 | 11 | # Experiments 12 | 13 | First, train your model by `train.py`. 14 | 15 | ## Trajectories and topology of fixed points 16 | 17 | - Plot trajectories and fixed points 18 | ```bash 19 | $ python plot_trajectories.py --activation relu 20 | ``` 21 | 22 | ![trajectory_relu](https://user-images.githubusercontent.com/24406002/71605599-7164f600-2bad-11ea-8fb1-5ffccb8b3f42.png) 23 | 24 | 25 | - Different points in the same trajectory correspond to one fixed point, 26 | and different trajectories correspond to different fixed point. 27 | 28 | ```bash 29 | $ python compare_fixed_point.py --activation relu 30 | ``` 31 | 32 | 33 | ``` 34 | distance between 2 fixed point start from different IC; different time of same trajectory. 35 | 2.2076301320339553e-07 36 | distance between 2 fixed point start from different IC; same time of different trajectories. 37 | 0.13503964245319366 38 | ``` 39 | 40 | ## Eigenvalue decomposition of Jacobian around fixed points. 41 | 42 | ```bash 43 | $ python linear_approximation.py --activation relu 44 | ``` 45 | 46 | - Distribution of eigenvalues 47 | 48 | ![relu_eigenvalues](https://user-images.githubusercontent.com/24406002/71605806-48ddfb80-2baf-11ea-8f33-62a9c10355eb.png) 49 | 50 | 51 | - There is the correlation with the frequencies of trajectories and the values of 52 | the imaginary part of the maximum eigenvalue of Jacobians. 53 | 54 | ![freq_relu](https://user-images.githubusercontent.com/24406002/71605816-54c9bd80-2baf-11ea-8310-fd92b3aff1eb.png) 55 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tripdancer0916/pytorch-fixed-point-analysis/00cec6f42b0f0b6f88c6dd7a37842c8fb8678807/__init__.py -------------------------------------------------------------------------------- /analyzer.py: -------------------------------------------------------------------------------- 1 | """class for fixed point analysis""" 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | 8 | class FixedPoint(object): 9 | def __init__(self, model, device, gamma=0.01, speed_tor=1e-06, max_epochs=200000, 10 | lr_decay_epoch=10000): 11 | self.model = model 12 | self.device = device 13 | self.gamma = gamma 14 | self.speed_tor = speed_tor 15 | self.max_epochs = max_epochs 16 | self.lr_decay_epoch = lr_decay_epoch 17 | 18 | self.model.eval() 19 | 20 | def calc_speed(self, hidden_activated, const_signal): 21 | input_signal = const_signal.permute(1, 0, 2) 22 | pre_activates = self.model.w_in(input_signal[0]) + self.model.w_hh(hidden_activated) 23 | 24 | if self.model.activation == 'relu': 25 | activated = F.relu(pre_activates) 26 | else: 27 | activated = torch.tanh(pre_activates) 28 | 29 | speed = torch.norm(activated - hidden_activated) 30 | 31 | return speed 32 | 33 | def find_fixed_point(self, init_hidden, const_signal, view=False): 34 | new_hidden = init_hidden.clone() 35 | gamma = self.gamma 36 | result_ok = True 37 | i = 0 38 | while True: 39 | hidden_activated = Variable(new_hidden).to(self.device) 40 | hidden_activated.requires_grad = True 41 | speed = self.calc_speed(hidden_activated, const_signal) 42 | if view and i % 1000 == 0: 43 | print(f'epoch: {i}, speed={speed.item()}') 44 | if speed.item() < self.speed_tor: 45 | print(f'epoch: {i}, speed={speed.item()}') 46 | break 47 | speed.backward() 48 | if i % self.lr_decay_epoch == 0 and i > 0: 49 | gamma *= 0.5 50 | if i == self.max_epochs: 51 | print(f'forcibly finished. speed={speed.item()}') 52 | result_ok = False 53 | break 54 | i += 1 55 | 56 | new_hidden = hidden_activated - gamma * hidden_activated.grad 57 | 58 | fixed_point = new_hidden[0, 0] 59 | return fixed_point, result_ok 60 | 61 | def calc_jacobian(self, fixed_point, const_signal_tensor): 62 | fixed_point = torch.unsqueeze(fixed_point, dim=1) 63 | fixed_point = Variable(fixed_point).to(self.device) 64 | fixed_point.requires_grad = True 65 | input_signal = const_signal_tensor.permute(1, 0, 2) 66 | w_hh = self.model.w_hh.weight 67 | w_hh.requires_grad = False 68 | w_hh = w_hh.to(self.device) 69 | pre_activates = torch.unsqueeze(self.model.w_in(input_signal[0])[0], dim=1) + \ 70 | w_hh @ fixed_point + torch.unsqueeze(self.model.w_hh.bias, dim=1) 71 | 72 | if self.model.activation == 'relu': 73 | activated = F.relu(pre_activates) 74 | else: 75 | activated = torch.tanh(pre_activates) 76 | 77 | jacobian = torch.zeros(self.model.n_hid, self.model.n_hid) 78 | for i in range(self.model.n_hid): 79 | output = torch.zeros(self.model.n_hid, 1).to(self.device) 80 | output[i] = 1. 81 | jacobian[:, i:i + 1] = torch.autograd.grad(activated, fixed_point, grad_outputs=output, retain_graph=True)[ 82 | 0] 83 | 84 | jacobian = jacobian.numpy().T 85 | 86 | return jacobian 87 | -------------------------------------------------------------------------------- /compare_fixed_point.py: -------------------------------------------------------------------------------- 1 | """compare fixed points""" 2 | 3 | import argparse 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from analyzer import FixedPoint 10 | from model import RecurrentNeuralNetwork 11 | 12 | 13 | def main(activation): 14 | os.makedirs('figures', exist_ok=True) 15 | freq_range = 51 16 | time_length = 40 17 | 18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | model = RecurrentNeuralNetwork(n_in=1, n_out=1, n_hid=200, device=device, 20 | activation=activation, sigma=0, use_bias=True).to(device) 21 | 22 | model_path = f'trained_model/{activation}/epoch_1000.pth' 23 | model.load_state_dict(torch.load(model_path, map_location=device)) 24 | 25 | model.eval() 26 | 27 | analyzer = FixedPoint(model=model, device=device) 28 | 29 | # compare fixed points 30 | freq = 17 31 | const_signal1 = np.repeat(freq / freq_range + 0.25, time_length) 32 | const_signal1 = np.expand_dims(const_signal1, axis=1) 33 | const_signal_tensor1 = torch.from_numpy(np.array([const_signal1])) 34 | 35 | print(const_signal_tensor1.shape) 36 | 37 | hidden = torch.zeros(1, 200) 38 | hidden = hidden.to(device) 39 | const_signal_tensor1 = const_signal_tensor1.float().to(device) 40 | with torch.no_grad(): 41 | hidden_list, _, _ = model(const_signal_tensor1, hidden) 42 | 43 | # different time of same trajectory. 44 | fixed_point1, _ = analyzer.find_fixed_point(torch.unsqueeze(hidden_list[:, 20, :], dim=0).to(device), 45 | const_signal_tensor1, view=True) 46 | fixed_point2, _ = analyzer.find_fixed_point(torch.unsqueeze(hidden_list[:, 15, :], dim=0).to(device), 47 | const_signal_tensor1) 48 | 49 | print('distance between 2 fixed point start from different IC; different time of same trajectory.') 50 | print(torch.norm(fixed_point1 - fixed_point2).item()) 51 | 52 | # same time of different trajectories. 53 | freq = 18 54 | const_signal2 = np.repeat(freq / freq_range + 0.25, time_length) 55 | const_signal2 = np.expand_dims(const_signal2, axis=1) 56 | const_signal_tensor2 = torch.from_numpy(np.array([const_signal2])) 57 | 58 | hidden = torch.zeros(1, 200) 59 | hidden = hidden.to(device) 60 | const_signal_tensor2 = const_signal_tensor2.float().to(device) 61 | with torch.no_grad(): 62 | hidden_list, _, _ = model(const_signal_tensor2, hidden) 63 | 64 | fixed_point3, _ = analyzer.find_fixed_point(torch.unsqueeze(hidden_list[:, 20, :], dim=0).to(device), 65 | const_signal_tensor2) 66 | print('distance between 2 fixed point start from different IC; same time of different trajectories.') 67 | print(torch.norm(fixed_point1 - fixed_point3).item()) 68 | 69 | 70 | if __name__ == '__main__': 71 | parser = argparse.ArgumentParser(description='PyTorch RNN training') 72 | parser.add_argument('--activation', type=str, default='tanh') 73 | args = parser.parse_args() 74 | # print(args) 75 | main(args.activation) 76 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | """generating input and target""" 2 | 3 | import numpy as np 4 | import torch.utils.data as data 5 | 6 | 7 | class SineWave(data.Dataset): 8 | def __init__(self, time_length=50, freq_range=10): 9 | self.time_length = time_length 10 | self.freq_range = freq_range 11 | 12 | def __len__(self): 13 | return 200 14 | 15 | def __getitem__(self, item): 16 | freq = np.random.randint(1, self.freq_range + 1) 17 | const_signal = np.repeat(freq / self.freq_range + 0.25, self.time_length) 18 | const_signal = np.expand_dims(const_signal, axis=1) 19 | t = np.arange(0, self.time_length*0.025, 0.025) 20 | target = np.sin(freq * t) 21 | target = np.expand_dims(target, axis=1) 22 | return const_signal, target 23 | -------------------------------------------------------------------------------- /linear_approximation.py: -------------------------------------------------------------------------------- 1 | """fixed point analysis""" 2 | 3 | import argparse 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | 9 | import matplotlib.pyplot as plt 10 | 11 | from analyzer import FixedPoint 12 | from model import RecurrentNeuralNetwork 13 | 14 | 15 | def main(activation): 16 | os.makedirs('figures', exist_ok=True) 17 | freq_range = 51 18 | time_length = 40 19 | 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | model = RecurrentNeuralNetwork(n_in=1, n_out=1, n_hid=200, device=device, 22 | activation=activation, sigma=0, use_bias=True).to(device) 23 | 24 | model_path = f'trained_model/{activation}/epoch_1000.pth' 25 | model.load_state_dict(torch.load(model_path, map_location=device)) 26 | 27 | model.eval() 28 | 29 | freq = 17 30 | const_signal = np.repeat(freq / freq_range + 0.25, time_length) 31 | const_signal = np.expand_dims(const_signal, axis=1) 32 | const_signal_tensor = torch.from_numpy(np.array([const_signal])) 33 | 34 | analyzer = FixedPoint(model=model, device=device) 35 | 36 | hidden = torch.zeros(1, 200) 37 | hidden = hidden.to(device) 38 | const_signal_tensor = const_signal_tensor.float().to(device) 39 | with torch.no_grad(): 40 | hidden_list, _, _ = model(const_signal_tensor, hidden) 41 | 42 | fixed_point, _ = analyzer.find_fixed_point(torch.unsqueeze(hidden_list[:, 20, :], dim=0).to(device), 43 | const_signal_tensor, view=True) 44 | 45 | # linear approximation around fixed point 46 | jacobian = analyzer.calc_jacobian(fixed_point, const_signal_tensor) 47 | 48 | # eigenvalue decomposition 49 | w, v = np.linalg.eig(jacobian) 50 | w_real = list() 51 | w_im = list() 52 | for eig in w: 53 | w_real.append(eig.real) 54 | w_im.append(eig.imag) 55 | plt.scatter(w_real, w_im) 56 | plt.xlabel(r'$Re(\lambda)$') 57 | plt.ylabel(r'$Im(\lambda)$') 58 | plt.savefig(f'figures/{activation}_eigenvalues.png', dpi=100) 59 | 60 | eig_freq = list() 61 | dynamics_freq = list() 62 | for i in range(20): 63 | freq = np.random.randint(1, freq_range + 1) 64 | const_signal = np.repeat(freq / freq_range + 0.25, time_length) 65 | const_signal = np.expand_dims(const_signal, axis=1) 66 | const_signal_tensor = torch.from_numpy(np.array([const_signal])) 67 | 68 | hidden = torch.zeros(1, 200) 69 | hidden = hidden.to(device) 70 | const_signal_tensor = const_signal_tensor.float().to(device) 71 | with torch.no_grad(): 72 | hidden_list, _, _ = model(const_signal_tensor, hidden) 73 | 74 | fixed_point, result_ok = analyzer.find_fixed_point(torch.unsqueeze(hidden_list[:, 20, :], dim=0).to(device), 75 | const_signal_tensor) 76 | if not result_ok: 77 | continue 78 | 79 | jacobian = analyzer.calc_jacobian(fixed_point, const_signal_tensor) 80 | w, v = np.linalg.eig(jacobian) 81 | max_index = np.argmax(abs(w)) 82 | eig_freq.append(abs(w[max_index].imag)) 83 | dynamics_freq.append(freq) 84 | 85 | plt.figure() 86 | plt.scatter(eig_freq, dynamics_freq) 87 | plt.xlabel(r'$|Im(\lambda_{max})|$') 88 | plt.ylabel(r'$\omega$') 89 | plt.title('relationship of frequency') 90 | plt.savefig(f'figures/freq_{activation}.png', dpi=100) 91 | 92 | 93 | if __name__ == '__main__': 94 | parser = argparse.ArgumentParser(description='PyTorch RNN training') 95 | parser.add_argument('--activation', type=str, default='tanh') 96 | args = parser.parse_args() 97 | # print(args) 98 | main(args.activation) 99 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """define recurrent neural networks""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class RecurrentNeuralNetwork(nn.Module): 9 | def __init__(self, n_in, n_out, n_hid, device, 10 | activation='relu', sigma=0.05, use_bias=True): 11 | super(RecurrentNeuralNetwork, self).__init__() 12 | self.n_in = n_in 13 | self.n_hid = n_hid 14 | self.n_out = n_out 15 | self.w_in = nn.Linear(n_in, n_hid, bias=use_bias) 16 | self.w_hh = nn.Linear(n_hid, n_hid, bias=use_bias) 17 | self.w_out = nn.Linear(n_hid, n_out, bias=use_bias) 18 | 19 | self.activation = activation 20 | self.sigma = sigma 21 | self.device = device 22 | 23 | def forward(self, input_signal, hidden): 24 | num_batch = input_signal.size(0) 25 | length = input_signal.size(1) 26 | hidden_list = torch.zeros(length, num_batch, self.n_hid).type_as(input_signal.data) 27 | output_list = torch.zeros(length, num_batch, self.n_out).type_as(input_signal.data) 28 | 29 | input_signal = input_signal.permute(1, 0, 2) 30 | 31 | for t in range(length): 32 | 33 | pre_activates = self.w_in(input_signal[t]) + self.w_hh(hidden) 34 | 35 | if self.activation == 'relu': 36 | hidden = F.relu(pre_activates) 37 | else: 38 | hidden = torch.tanh(pre_activates) 39 | 40 | output = self.w_out(hidden) 41 | hidden_list[t] = hidden 42 | output_list[t] = output 43 | hidden_list = hidden_list.permute(1, 0, 2) 44 | output_list = output_list.permute(1, 0, 2) 45 | return hidden_list, output_list, hidden 46 | -------------------------------------------------------------------------------- /plot_trajectories.py: -------------------------------------------------------------------------------- 1 | """plot trajectories and fixed points in the PCA space.""" 2 | 3 | import argparse 4 | import os 5 | 6 | import matplotlib 7 | import numpy as np 8 | import torch 9 | 10 | matplotlib.use('Agg') 11 | 12 | import matplotlib.pyplot as plt 13 | from mpl_toolkits.mplot3d import Axes3D 14 | 15 | from sklearn.decomposition import PCA 16 | 17 | from analyzer import FixedPoint 18 | from model import RecurrentNeuralNetwork 19 | 20 | 21 | def main(activation): 22 | os.makedirs('figures', exist_ok=True) 23 | freq_range = 51 24 | time_length = 40 25 | 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | model = RecurrentNeuralNetwork(n_in=1, n_out=1, n_hid=200, device=device, 28 | activation=activation, sigma=0, use_bias=True).to(device) 29 | 30 | model_path = f'trained_model/{activation}/epoch_1000.pth' 31 | model.load_state_dict(torch.load(model_path, map_location=device)) 32 | 33 | model.eval() 34 | 35 | analyzer = FixedPoint(model=model, device=device, max_epochs=200000) 36 | 37 | hidden_list_list = np.zeros([30 * time_length, model.n_hid]) 38 | fixed_point_list = np.zeros([15, model.n_hid]) 39 | i = 0 40 | while i < 15: 41 | freq = np.random.randint(10, freq_range + 1) 42 | const_signal = np.repeat(freq / freq_range + 0.25, time_length) 43 | const_signal = np.expand_dims(const_signal, axis=1) 44 | const_signal_tensor = torch.from_numpy(np.array([const_signal])) 45 | 46 | hidden = torch.zeros(1, 200) 47 | hidden = hidden.to(device) 48 | const_signal_tensor = const_signal_tensor.float().to(device) 49 | with torch.no_grad(): 50 | hidden_list, _, _ = model(const_signal_tensor, hidden) 51 | 52 | fixed_point, result_ok = analyzer.find_fixed_point(torch.unsqueeze(hidden_list[:, 20, :], dim=0).to(device), 53 | const_signal_tensor) 54 | if not result_ok: 55 | continue 56 | 57 | hidden_list_list[i * time_length:(i + 1) * time_length, ...] = hidden_list.cpu().numpy()[:, ...] 58 | fixed_point_list[i] = fixed_point.detach().cpu().numpy() 59 | i += 1 60 | 61 | pca = PCA(n_components=3) 62 | pca.fit(hidden_list_list) 63 | 64 | fig = plt.figure() 65 | ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=45, azim=134) 66 | 67 | ax.set_xlabel('PC1') 68 | ax.set_ylabel('PC2') 69 | ax.set_zlabel('PC3') 70 | 71 | print(hidden_list_list.shape) 72 | print(fixed_point_list.shape) 73 | pc_trajectory = pca.transform(hidden_list_list) 74 | pc_fixed_point = pca.transform(fixed_point_list) 75 | 76 | for i in range(15): 77 | ax.plot(pc_trajectory.T[0, i * time_length:(i + 1) * time_length], 78 | pc_trajectory.T[1, i * time_length:(i + 1) * time_length], 79 | pc_trajectory.T[2, i * time_length:(i + 1) * time_length], color='royalblue') 80 | ax.scatter(pc_fixed_point.T[0], pc_fixed_point.T[1], pc_fixed_point.T[2], color='red', marker='x') 81 | plt.title('trajectory') 82 | plt.savefig(f'figures/trajectory_{activation}.png', dpi=100) 83 | 84 | 85 | if __name__ == '__main__': 86 | parser = argparse.ArgumentParser(description='PyTorch RNN training') 87 | parser.add_argument('--activation', type=str, default='tanh') 88 | args = parser.parse_args() 89 | # print(args) 90 | main(args.activation) 91 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """training models""" 2 | 3 | import argparse 4 | import os 5 | import sys 6 | 7 | import numpy as np 8 | import torch 9 | import torch.optim as optim 10 | 11 | sys.path.append('../') 12 | 13 | from torch.autograd import Variable 14 | 15 | from dataset import SineWave 16 | from model import RecurrentNeuralNetwork 17 | 18 | 19 | def main(activation): 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | print(device) 22 | 23 | os.makedirs('trained_model', exist_ok=True) 24 | save_path = f'trained_model/{activation}' 25 | os.makedirs(save_path, exist_ok=True) 26 | 27 | model = RecurrentNeuralNetwork(n_in=1, n_out=1, n_hid=200, device=device, 28 | activation=activation, sigma=0, use_bias=True).to(device) 29 | 30 | train_dataset = SineWave(freq_range=51, time_length=40) 31 | 32 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=50, 33 | num_workers=2, shuffle=True, 34 | worker_init_fn=lambda x: np.random.seed()) 35 | 36 | print(model) 37 | 38 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 39 | lr=0.001, weight_decay=0.0001) 40 | 41 | for epoch in range(2001): 42 | model.train() 43 | for i, data in enumerate(train_dataloader): 44 | inputs, target, = data 45 | inputs, target, = inputs.float(), target.float() 46 | inputs, target = Variable(inputs).to(device), Variable(target).to(device) 47 | 48 | hidden = torch.zeros(50, 200) 49 | hidden = hidden.to(device) 50 | 51 | optimizer.zero_grad() 52 | hidden = hidden.detach() 53 | hidden_list, output, hidden = model(inputs, hidden) 54 | 55 | loss = torch.nn.MSELoss()(output, target) 56 | loss.backward() 57 | optimizer.step() 58 | 59 | if epoch > 0 and epoch % 200 == 0: 60 | print(f'Train Epoch: {epoch}, Loss: {loss.item():.6f}') 61 | print('output', output[0, :, 0].cpu().detach().numpy()) 62 | print('target', target[0, :, 0].cpu().detach().numpy()) 63 | torch.save(model.state_dict(), os.path.join(save_path, f'epoch_{epoch}.pth')) 64 | 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser(description='PyTorch RNN training') 68 | parser.add_argument('--activation', type=str, default='tanh') 69 | args = parser.parse_args() 70 | # print(args) 71 | main(args.activation) 72 | --------------------------------------------------------------------------------