├── imgs
    ├── result.png
    ├── weights.png
    ├── result_mnist.png
    └── weights_mnist.png
├── LICENSE
├── .gitignore
├── README.md
├── train_mnist.py
├── train.py
└── network.py


/imgs/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takyamamoto/Hopfield-Network/HEAD/imgs/result.png
--------------------------------------------------------------------------------
/imgs/weights.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takyamamoto/Hopfield-Network/HEAD/imgs/weights.png
--------------------------------------------------------------------------------
/imgs/result_mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takyamamoto/Hopfield-Network/HEAD/imgs/result_mnist.png
--------------------------------------------------------------------------------
/imgs/weights_mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takyamamoto/Hopfield-Network/HEAD/imgs/weights_mnist.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 takyamamoto

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Hopfield Network
A Hopfield network (Amari-Hopfield network) implemented in Python. Two update rules are implemented: **Asynchronous** and **Synchronous**.

## Requirements
- Python >= 3.5
- numpy
- matplotlib
- scikit-image (skimage)
- tqdm
- keras (to load the MNIST dataset)

## Usage
Run `train.py` or `train_mnist.py`.

## Demo

### train.py
The following is the result of using the **Synchronous** update rule.
```
Start to data preprocessing...
Start to train weights...
100%|██████████| 4/4 [00:06<00:00, 1.67s/it]
Start to predict...
100%|██████████| 4/4 [00:02<00:00, 1.80it/s]
Show prediction results...
```


```
Show network weights matrix...
```


### train_mnist.py
The following is the result of using the **Asynchronous** update rule.
```
Start to data preprocessing...
Start to train weights...
100%|██████████| 3/3 [00:00<00:00, 274.99it/s]
Start to predict...
100%|██████████| 3/3 [00:00<00:00, 32.52it/s]
Show prediction results...
```


```
Show network weights matrix...
```


## References
- S.-I. Amari, "Neural theory of association and concept-formation", Biological Cybernetics, vol. 26, p. 175, 1977. https://doi.org/10.1007/BF00365229
- J. J. Hopfield, "Neural networks and physical systems with emergent collective computational abilities", Proceedings of the National Academy of Sciences of the USA, vol. 79, no. 8, pp. 2554–2558, April 1982.

--------------------------------------------------------------------------------
/train_mnist.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 29 08:40:49 2018

@author: user
"""

import numpy as np
from matplotlib import pyplot as plt
from skimage.filters import threshold_mean
import network
from keras.datasets import mnist

# Utils
def reshape(data):
    dim = int(np.sqrt(len(data)))
    data = np.reshape(data, (dim, dim))
    return data

def plot(data, test, predicted, figsize=(3, 3)):
    data = [reshape(d) for d in data]
    test = [reshape(d) for d in test]
    predicted = [reshape(d) for d in predicted]

    fig, axarr = plt.subplots(len(data), 3, figsize=figsize)
    for i in range(len(data)):
        if i == 0:
            axarr[i, 0].set_title('Train data')
            axarr[i, 1].set_title("Input data")
            axarr[i, 2].set_title('Output data')

        axarr[i, 0].imshow(data[i])
        axarr[i, 0].axis('off')
        axarr[i, 1].imshow(test[i])
        axarr[i, 1].axis('off')
        axarr[i, 2].imshow(predicted[i])
        axarr[i, 2].axis('off')

    plt.tight_layout()
    plt.savefig("result_mnist.png")
    plt.show()

def preprocessing(img):
    w, h = img.shape
    # Thresholding
    thresh = threshold_mean(img)
    binary = img > thresh
    shift = 2*(binary*1)-1 # Boolean to {-1, +1}

    # Reshape
    flatten = np.reshape(shift, (w*h))
    return flatten

def main():
    # Load data: the first training image of each digit 0, 1, 2
    (x_train, y_train), (_, _) = mnist.load_data()
    data = []
    for i in range(3):
        xi = x_train[y_train==i]
        data.append(xi[0])

    # Preprocessing
    print("Start to data preprocessing...")
    data = [preprocessing(d) for d in data]

    # Create Hopfield Network Model
    model = network.HopfieldNetwork()
    model.train_weights(data)

    # Make test data list: the second training image of each digit
    test = []
    for i in range(3):
        xi = x_train[y_train==i]
        test.append(xi[1])
    test = [preprocessing(d) for d in test]

    predicted = model.predict(test, threshold=50, asyn=True)
    print("Show prediction results...")
    plot(data, test, predicted, figsize=(5, 5))
    print("Show network weights matrix...")
    model.plot_weights()

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 29 08:40:49 2018

@author: user
"""

import numpy as np
np.random.seed(1)
from matplotlib import pyplot as plt
import skimage.data
from skimage.color import rgb2gray
from skimage.filters import threshold_mean
from skimage.transform import resize
import network

# Utils
def get_corrupted_input(pattern, corruption_level):
    # Flip the sign of each element independently with probability corruption_level
    corrupted = np.copy(pattern)
    inv = np.random.binomial(n=1, p=corruption_level, size=len(pattern))
    corrupted[inv == 1] *= -1
    return corrupted

def reshape(data):
    dim = int(np.sqrt(len(data)))
    data = np.reshape(data, (dim, dim))
    return data

def plot(data, test, predicted, figsize=(5, 6)):
    data = [reshape(d) for d in data]
    test = [reshape(d) for d in test]
    predicted = [reshape(d) for d in predicted]

    fig, axarr = plt.subplots(len(data), 3, figsize=figsize)
    for i in range(len(data)):
        if i == 0:
            axarr[i, 0].set_title('Train data')
            axarr[i, 1].set_title("Input data")
            axarr[i, 2].set_title('Output data')

        axarr[i, 0].imshow(data[i])
        axarr[i, 0].axis('off')
        axarr[i, 1].imshow(test[i])
        axarr[i, 1].axis('off')
        axarr[i, 2].imshow(predicted[i])
        axarr[i, 2].axis('off')

    plt.tight_layout()
    plt.savefig("result.png")
    plt.show()

def preprocessing(img, w=128, h=128):
    # Resize image
    img = resize(img, (w, h), mode='reflect')

    # Thresholding
    thresh = threshold_mean(img)
    binary = img > thresh
    shift = 2*(binary*1)-1 # Boolean to {-1, +1}

    # Reshape
    flatten = np.reshape(shift, (w*h))
    return flatten

def main():
    # Load data
    camera = skimage.data.camera()
    astronaut = rgb2gray(skimage.data.astronaut())
    horse = skimage.data.horse()
    coffee = rgb2gray(skimage.data.coffee())

    # Merge data
    data = [camera, astronaut, horse, coffee]

    # Preprocessing
    print("Start to data preprocessing...")
    data = [preprocessing(d) for d in data]

    # Create Hopfield Network Model
    model = network.HopfieldNetwork()
    model.train_weights(data)

    # Generate test set: flip 30% of the pixels of each training image
    test = [get_corrupted_input(d, 0.3) for d in data]

    predicted = model.predict(test, threshold=0, asyn=False)
    print("Show prediction results...")
    plot(data, test, predicted)
    print("Show network weights matrix...")
    #model.plot_weights()

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 29 08:40:49 2018

@author: user
"""

import numpy as np
from matplotlib import pyplot as plt
import matplotlib.cm as cm
from tqdm import tqdm

class HopfieldNetwork(object):
    def train_weights(self, train_data):
        print("Start to train weights...")
        num_data = len(train_data)
        self.num_neuron = train_data[0].shape[0]

        # Initialize weights
        W = np.zeros((self.num_neuron, self.num_neuron))
        # Mean activity over all training patterns, subtracted before the Hebb rule
        rho = np.sum([np.sum(t) for t in train_data]) / (num_data*self.num_neuron)

        # Hebb rule
        for i in tqdm(range(num_data)):
            t = train_data[i] - rho
            W += np.outer(t, t)

        # Set the diagonal of W to 0 in place (no self-connections)
        np.fill_diagonal(W, 0)
        W /= num_data

        self.W = W

    def predict(self, data, num_iter=20, threshold=0, asyn=False):
        print("Start to predict...")
        self.num_iter = num_iter
        self.threshold = threshold
        self.asyn = asyn

        # Copy the inputs: the asynchronous update modifies the state in place
        copied_data = np.copy(data)

        # Run the network on each input pattern
        predicted = []
        for i in tqdm(range(len(data))):
            predicted.append(self._run(copied_data[i]))
        return predicted

    def _run(self, init_s):
        if not self.asyn:
            # Synchronous update: all neurons are updated at once
            # Compute initial state energy
            s = init_s
            e = self.energy(s)

            # Iteration
            for i in range(self.num_iter):
                # Update s
                s = np.sign(self.W @ s - self.threshold)
                # Compute new state energy
                e_new = self.energy(s)

                # s has converged
                if e == e_new:
                    return s
                # Update energy
                e = e_new
            return s
        else:
            # Asynchronous update: one randomly chosen neuron at a time
            # Compute initial state energy
            s = init_s
            e = self.energy(s)

            # Iteration
            for i in range(self.num_iter):
                for j in range(100):
                    # Select random neuron
                    idx = np.random.randint(0, self.num_neuron)
                    # Update s
                    s[idx] = np.sign(self.W[idx] @ s - self.threshold)

                # Compute new state energy after 100 single-neuron updates
                e_new = self.energy(s)

                # s has converged
                if e == e_new:
                    return s
                # Update energy
                e = e_new
            return s

    def energy(self, s):
        # E(s) = -1/2 * s^T W s + threshold * sum(s)
        return -0.5 * s @ self.W @ s + np.sum(s * self.threshold)

    def plot_weights(self):
        plt.figure(figsize=(6, 5))
        w_mat = plt.imshow(self.W, cmap=cm.coolwarm)
        plt.colorbar(w_mat)
        plt.title("Network Weights")
        plt.tight_layout()
        plt.savefig("weights.png")
        plt.show()
--------------------------------------------------------------------------------
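The Hebb rule and the two update rules in `network.py` can be checked on a case where the outcome is known in advance. The following is a minimal NumPy sketch, not part of the repository, under these assumptions: the function names `hebbian_weights`, `recall_sync` and `recall_async` are hypothetical; ties in the step function go to +1 (`network.py` uses `np.sign`, which returns 0 on a tie); and the asynchronous rule sweeps the neurons in a fixed order instead of picking them at random. It stores three rows of a 16×16 Sylvester Hadamard matrix. These ±1 patterns are mutually orthogonal and sum to zero, so `rho` is exactly 0 and each stored pattern `p` satisfies `W @ p = (16/3 - 1) * p`. Flipping one element should therefore be undone by either rule. The energy `-0.5 * s @ W @ s + threshold * sum(s)` should never increase under single-neuron updates, because `W` is symmetric with a zero diagonal.

```python
import numpy as np


def hebbian_weights(patterns):
    """Hebb rule as in network.py: subtract the mean activity rho,
    sum outer products, zero the diagonal, divide by the pattern count."""
    patterns = np.asarray(patterns, dtype=float)
    num_data, num_neuron = patterns.shape
    rho = patterns.sum() / (num_data * num_neuron)
    centered = patterns - rho
    W = centered.T @ centered  # equal to the sum of np.outer(t, t) over patterns
    np.fill_diagonal(W, 0.0)
    return W / num_data


def energy(W, s, threshold=0.0):
    # Same energy function as HopfieldNetwork.energy
    return -0.5 * s @ W @ s + np.sum(s * threshold)


def step(field, threshold):
    # Assumption: ties go to +1 (network.py uses np.sign, which gives 0)
    return np.where(field >= threshold, 1.0, -1.0)


def recall_sync(W, s, threshold=0.0, num_iter=20):
    """Update all neurons at once until the state stops changing
    (network.py instead stops when the energy stops changing)."""
    s = np.array(s, dtype=float)
    for _ in range(num_iter):
        s_new = step(W @ s, threshold)
        if np.array_equal(s_new, s):
            break
        s = s_new
    return s


def recall_async(W, s, threshold=0.0, num_sweeps=20):
    """Update one neuron at a time in a fixed sweep order (network.py picks
    neurons at random) and record the energy after every single update."""
    s = np.array(s, dtype=float)
    energies = [energy(W, s, threshold)]
    for _ in range(num_sweeps):
        changed = False
        for i in range(len(s)):
            new_value = step(W[i] @ s, threshold)
            if new_value != s[i]:
                s[i] = new_value
                changed = True
            energies.append(energy(W, s, threshold))
        if not changed:  # a full sweep changed nothing: fixed point
            break
    return s, energies


# Rows 1..3 of a 16x16 Sylvester Hadamard matrix: orthogonal, zero-sum patterns
H = np.array([[1.0]])
for _ in range(4):
    H = np.kron(H, np.array([[1.0, 1.0], [1.0, -1.0]]))
patterns = H[1:4]

W = hebbian_weights(patterns)
probe = patterns[0].copy()
probe[5] *= -1  # corrupt one element

print(np.array_equal(recall_sync(W, probe), patterns[0]))  # True
s, energies = recall_async(W, probe)
print(np.array_equal(s, patterns[0]))  # True
print(all(b <= a + 1e-9 for a, b in zip(energies, energies[1:])))  # True
```

Orthogonal patterns are used only so that the expected result can be derived by hand. With random patterns, such as the images in `train.py`, the stored patterns interfere with one another through the crosstalk terms in `W @ s`, and recall of a corrupted input is no longer guaranteed.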