├── imgs
│   ├── result.png
│   ├── weights.png
│   ├── result_mnist.png
│   └── weights_mnist.png
├── LICENSE
├── .gitignore
├── README.md
├── train_mnist.py
├── train.py
└── network.py
/imgs/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takyamamoto/Hopfield-Network/HEAD/imgs/result.png
--------------------------------------------------------------------------------
/imgs/weights.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takyamamoto/Hopfield-Network/HEAD/imgs/weights.png
--------------------------------------------------------------------------------
/imgs/result_mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takyamamoto/Hopfield-Network/HEAD/imgs/result_mnist.png
--------------------------------------------------------------------------------
/imgs/weights_mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takyamamoto/Hopfield-Network/HEAD/imgs/weights_mnist.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 takyamamoto
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hopfield Network
2 | Hopfield network (Amari-Hopfield network) implemented with Python. Two update rules are implemented: **Asynchronous** & **Synchronous**.
3 |
4 | ## Requirement
5 | - Python >= 3.5
6 | - numpy
7 | - matplotlib
8 | - scikit-image (`skimage`)
9 | - tqdm
10 | - keras (to load the MNIST dataset)
11 |
12 | ## Usage
13 | Run `train.py` or `train_mnist.py`.
14 |
15 | ## Demo
16 |
17 | ### train.py
18 | The following is the result of using **Synchronous** update.
19 | ```
20 | Start to data preprocessing...
21 | Start to train weights...
22 | 100%|██████████| 4/4 [00:06<00:00, 1.67s/it]
23 | Start to predict...
24 | 100%|██████████| 4/4 [00:02<00:00, 1.80it/s]
25 | Show prediction results...
26 | ```
27 | ![](imgs/result.png)
28 |
29 | ```
30 | Show network weights matrix...
31 | ```
32 | ![](imgs/weights.png)
33 |
34 | ### train_mnist.py
35 | The following is the result of using **Asynchronous** update.
36 | ```
37 | Start to data preprocessing...
38 | Start to train weights...
39 | 100%|██████████| 3/3 [00:00<00:00, 274.99it/s]
40 | Start to predict...
41 | 100%|██████████| 3/3 [00:00<00:00, 32.52it/s]
42 | Show prediction results...
43 | ```
44 | ![](imgs/result_mnist.png)
45 |
46 | ```
47 | Show network weights matrix...
48 | ```
49 | ![](imgs/weights_mnist.png)
50 |
51 | ## Reference
52 | - S.-I. Amari, "Neural theory of association and concept-formation", Biological Cybernetics, vol. 26, p. 175, 1977. https://doi.org/10.1007/BF00365229
53 | - J. J. Hopfield, "Neural networks and physical systems with emergent collective computational abilities", Proceedings of the National Academy of Sciences of the USA, vol. 79, no. 8, pp. 2554–2558, April 1982.
54 |
--------------------------------------------------------------------------------
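The README names two update rules. The following NumPy-only sketch contrasts them on a small random network. It uses the same energy form as `HopfieldNetwork.energy` in `network.py`, E(s) = -½ sᵀWs + θ Σᵢ sᵢ, but it is an illustration, not the repository's code. The names `energy`, `sync_step` and `async_sweep` are hypothetical. Ties (local field exactly equal to θ) are broken toward +1 here, whereas `network.py` uses `np.sign`, which returns 0 on ties. With symmetric, zero-diagonal weights, an asynchronous single-neuron update can never raise E, so the sketch records E after every update to make this checkable.

```python
import numpy as np

def energy(W, s, theta=0.0):
    # Same form as HopfieldNetwork.energy: -1/2 s^T W s + theta * sum(s)
    return -0.5 * s @ W @ s + theta * np.sum(s)

def sync_step(W, s, theta=0.0):
    # Synchronous rule: every neuron reads the *old* state and all update at once.
    return np.where(W @ s - theta >= 0, 1, -1)

def async_sweep(W, s, rng, theta=0.0):
    # Asynchronous rule: neurons are updated one at a time in random order,
    # each one seeing the changes already made earlier in the sweep.
    s = s.copy()
    energies = [energy(W, s, theta)]
    for i in rng.permutation(len(s)):
        s[i] = 1 if W[i] @ s - theta >= 0 else -1
        energies.append(energy(W, s, theta))
    return s, energies

rng = np.random.default_rng(0)
n = 64
patterns = rng.choice([-1, 1], size=(3, n))
# Hebbian weights with zero diagonal (no mean-activity correction here).
W = patterns.T @ patterns / len(patterns)
np.fill_diagonal(W, 0)

# Flip 8 randomly chosen bits of the first stored pattern.
noisy = patterns[0].copy()
flip = rng.choice(n, size=8, replace=False)
noisy[flip] *= -1

s_async, energies = async_sweep(W, noisy, rng)
s_sync = sync_step(W, noisy)
print("async overlap with stored pattern:", s_async @ patterns[0] / n)
print("sync overlap with stored pattern:", s_sync @ patterns[0] / n)
```

The synchronous rule updates everything from the previous state, which is why it can fall into two-state oscillations. The asynchronous rule always uses the latest state, which is what makes the energy argument work.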
/train_mnist.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sun Jul 29 08:40:49 2018
4 |
5 | @author: user
6 | """
7 |
8 | import numpy as np
9 | from matplotlib import pyplot as plt
10 | from skimage.filters import threshold_mean
11 | import network
12 | from keras.datasets import mnist
13 |
14 | # Utils
15 | def reshape(data):
16 |     dim = int(np.sqrt(len(data)))
17 |     data = np.reshape(data, (dim, dim))
18 |     return data
19 |
20 | def plot(data, test, predicted, figsize=(3, 3)):
21 |     data = [reshape(d) for d in data]
22 |     test = [reshape(d) for d in test]
23 |     predicted = [reshape(d) for d in predicted]
24 |
25 |     fig, axarr = plt.subplots(len(data), 3, figsize=figsize)
26 |     for i in range(len(data)):
27 |         if i == 0:
28 |             axarr[i, 0].set_title('Train data')
29 |             axarr[i, 1].set_title("Input data")
30 |             axarr[i, 2].set_title('Output data')
31 |
32 |         axarr[i, 0].imshow(data[i])
33 |         axarr[i, 0].axis('off')
34 |         axarr[i, 1].imshow(test[i])
35 |         axarr[i, 1].axis('off')
36 |         axarr[i, 2].imshow(predicted[i])
37 |         axarr[i, 2].axis('off')
38 |
39 |     plt.tight_layout()
40 |     plt.savefig("result_mnist.png")
41 |     plt.show()
42 |
43 | def preprocessing(img):
44 |     w, h = img.shape
45 |     # Thresholding
46 |     thresh = threshold_mean(img)
47 |     binary = img > thresh
48 |     shift = 2*(binary*1)-1 # Boolean to int: {False, True} -> {-1, +1}
49 |
50 |     # Flatten into a 1-D state vector
51 |     flatten = np.reshape(shift, (w*h))
52 |     return flatten
53 |
54 | def main():
55 |     # Load data
56 |     (x_train, y_train), (_, _) = mnist.load_data()
57 |     data = []
58 |     for i in range(3):
59 |         xi = x_train[y_train==i]
60 |         data.append(xi[0])
61 |
62 |     # Preprocessing
63 |     print("Start to data preprocessing...")
64 |     data = [preprocessing(d) for d in data]
65 |
66 |     # Create Hopfield Network Model
67 |     model = network.HopfieldNetwork()
68 |     model.train_weights(data)
69 |
70 |     # Make test data list (a different sample of each digit)
71 |     test = []
72 |     for i in range(3):
73 |         xi = x_train[y_train==i]
74 |         test.append(xi[1])
75 |     test = [preprocessing(d) for d in test]
76 |
77 |     predicted = model.predict(test, threshold=50, asyn=True)
78 |     print("Show prediction results...")
79 |     plot(data, test, predicted, figsize=(5, 5))
80 |     print("Show network weights matrix...")
81 |     model.plot_weights()
82 |
83 | if __name__ == '__main__':
84 |     main()
--------------------------------------------------------------------------------
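`preprocessing` in `train_mnist.py` binarizes each digit with `threshold_mean` and maps pixels to ±1. Because `skimage.filters.threshold_mean` returns the mean intensity of the image, the same mapping can be sketched with NumPy alone, together with the inverse reshape that `reshape` performs before plotting. `to_bipolar` and `to_image` are hypothetical names, and a random 28×28 integer array stands in for an MNIST digit.

```python
import numpy as np

def to_bipolar(img):
    # Pixels brighter than the image mean become +1, the rest -1,
    # then the 2-D image is flattened into a 1-D state vector.
    binary = img > img.mean()
    return np.where(binary, 1, -1).ravel()

def to_image(state):
    # Inverse of the flattening step, assuming a square image.
    dim = int(np.sqrt(state.size))
    return state.reshape(dim, dim)

rng = np.random.default_rng(0)
digit = rng.integers(0, 256, size=(28, 28))  # stand-in for an MNIST image
state = to_bipolar(digit)
print(state.shape, np.unique(state))
```

Each 28×28 digit therefore becomes a 784-neuron state, so the MNIST network's weight matrix is 784×784.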
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sun Jul 29 08:40:49 2018
4 |
5 | @author: user
6 | """
7 |
8 | import numpy as np
9 | np.random.seed(1)
10 | from matplotlib import pyplot as plt
11 | import skimage.data
12 | from skimage.color import rgb2gray
13 | from skimage.filters import threshold_mean
14 | from skimage.transform import resize
15 | import network
16 |
17 | # Utils
18 | def get_corrupted_input(input, corruption_level):
19 |     corrupted = np.copy(input)
20 |     inv = np.random.binomial(n=1, p=corruption_level, size=len(input))
21 |     for i, v in enumerate(input):
22 |         if inv[i]:
23 |             corrupted[i] = -1 * v
24 |     return corrupted
25 |
26 | def reshape(data):
27 |     dim = int(np.sqrt(len(data)))
28 |     data = np.reshape(data, (dim, dim))
29 |     return data
30 |
31 | def plot(data, test, predicted, figsize=(5, 6)):
32 |     data = [reshape(d) for d in data]
33 |     test = [reshape(d) for d in test]
34 |     predicted = [reshape(d) for d in predicted]
35 |
36 |     fig, axarr = plt.subplots(len(data), 3, figsize=figsize)
37 |     for i in range(len(data)):
38 |         if i == 0:
39 |             axarr[i, 0].set_title('Train data')
40 |             axarr[i, 1].set_title("Input data")
41 |             axarr[i, 2].set_title('Output data')
42 |
43 |         axarr[i, 0].imshow(data[i])
44 |         axarr[i, 0].axis('off')
45 |         axarr[i, 1].imshow(test[i])
46 |         axarr[i, 1].axis('off')
47 |         axarr[i, 2].imshow(predicted[i])
48 |         axarr[i, 2].axis('off')
49 |
50 |     plt.tight_layout()
51 |     plt.savefig("result.png")
52 |     plt.show()
53 |
54 | def preprocessing(img, w=128, h=128):
55 |     # Resize image
56 |     img = resize(img, (w,h), mode='reflect')
57 |
58 |     # Thresholding
59 |     thresh = threshold_mean(img)
60 |     binary = img > thresh
61 |     shift = 2*(binary*1)-1 # Boolean to int: {False, True} -> {-1, +1}
62 |
63 |     # Flatten into a 1-D state vector
64 |     flatten = np.reshape(shift, (w*h))
65 |     return flatten
66 |
67 | def main():
68 |     # Load data
69 |     camera = skimage.data.camera()
70 |     astronaut = rgb2gray(skimage.data.astronaut())
71 |     horse = skimage.data.horse()
72 |     coffee = rgb2gray(skimage.data.coffee())
73 |
74 |     # Merge data
75 |     data = [camera, astronaut, horse, coffee]
76 |
77 |     # Preprocessing
78 |     print("Start to data preprocessing...")
79 |     data = [preprocessing(d) for d in data]
80 |
81 |     # Create Hopfield Network Model
82 |     model = network.HopfieldNetwork()
83 |     model.train_weights(data)
84 |
85 |     # Generate test set (flip 30% of the pixels of each image)
86 |     test = [get_corrupted_input(d, 0.3) for d in data]
87 |
88 |     predicted = model.predict(test, threshold=0, asyn=False)
89 |     print("Show prediction results...")
90 |     plot(data, test, predicted)
91 |     print("Show network weights matrix...")
92 |     #model.plot_weights()
93 |
94 | if __name__ == '__main__':
95 |     main()
96 |
--------------------------------------------------------------------------------
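`get_corrupted_input` in `train.py` flips each entry independently with probability `corruption_level`: it draws a Bernoulli mask and then loops over the input. The same noise model can be written as a single masked selection. The sketch below uses NumPy's `Generator` API instead of the global `np.random.seed`; `corrupt` is a hypothetical name, not part of `train.py`.

```python
import numpy as np

def corrupt(state, corruption_level, rng):
    # Bernoulli(corruption_level) mask: True where the entry should be flipped.
    flip = rng.random(state.shape) < corruption_level
    # Flipped entries take -state, the rest keep state; the input is not modified.
    return np.where(flip, -state, state)

rng = np.random.default_rng(1)
pattern = rng.choice([-1, 1], size=128 * 128)  # same size as a preprocessed image
noisy = corrupt(pattern, 0.3, rng)
print("fraction flipped:", np.mean(noisy != pattern))
```

For the 128×128 images used here the loop runs 16,384 Python iterations per image, so the vectorized form mainly saves time. The resulting distribution of flips is the same.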
/network.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sun Jul 29 08:40:49 2018
4 |
5 | @author: user
6 | """
7 |
8 | import numpy as np
9 | from matplotlib import pyplot as plt
10 | import matplotlib.cm as cm
11 | from tqdm import tqdm
12 |
13 | class HopfieldNetwork(object):
14 |     def train_weights(self, train_data):
15 |         print("Start to train weights...")
16 |         num_data = len(train_data)
17 |         self.num_neuron = train_data[0].shape[0]
18 |
19 |         # Initialize weights and compute the mean activity rho
20 |         W = np.zeros((self.num_neuron, self.num_neuron))
21 |         rho = np.sum([np.sum(t) for t in train_data]) / (num_data*self.num_neuron)
22 |
23 |         # Hebb rule on patterns centered by the mean activity rho
24 |         for i in tqdm(range(num_data)):
25 |             t = train_data[i] - rho
26 |             W += np.outer(t, t)
27 |
28 |         # Set the diagonal of W to 0 (no self-connections)
29 |         diagW = np.diag(np.diag(W))
30 |         W = W - diagW
31 |         W /= num_data
32 |
33 |         self.W = W
34 |
35 |     def predict(self, data, num_iter=20, threshold=0, asyn=False):
36 |         print("Start to predict...")
37 |         self.num_iter = num_iter
38 |         self.threshold = threshold
39 |         self.asyn = asyn
40 |
41 |         # Copy so that in-place updates do not modify the caller's arrays
42 |         copied_data = np.copy(data)
43 |
44 |         # Define predict list
45 |         predicted = []
46 |         for i in tqdm(range(len(data))):
47 |             predicted.append(self._run(copied_data[i]))
48 |         return predicted
49 |
50 |     def _run(self, init_s):
51 |         if not self.asyn:
52 |             # Synchronous update: all neurons are updated at once
53 |             # from the previous state.
54 |
55 |             # Compute initial state energy
56 |             s = init_s
57 |
58 |             e = self.energy(s)
59 |
60 |             # Iteration
61 |             for i in range(self.num_iter):
62 |                 # Update s
63 |                 s = np.sign(self.W @ s - self.threshold)
64 |                 # Compute new state energy
65 |                 e_new = self.energy(s)
66 |
67 |                 # Energy unchanged: treat s as converged
68 |                 if e == e_new:
69 |                     return s
70 |                 # Update energy
71 |                 e = e_new
72 |             return s
73 |         else:
74 |             # Asynchronous update: one randomly chosen neuron at a time;
75 |             # each outer iteration performs 100 single-neuron updates.
76 |
77 |             # Compute initial state energy
78 |             s = init_s
79 |             e = self.energy(s)
80 |
81 |             # Iteration
82 |             for i in range(self.num_iter):
83 |                 for j in range(100):
84 |                     # Select random neuron
85 |                     idx = np.random.randint(0, self.num_neuron)
86 |                     # Update s in place
87 |                     s[idx] = np.sign(self.W[idx] @ s - self.threshold)
88 |
89 |                 # Compute new state energy
90 |                 e_new = self.energy(s)
91 |
92 |                 # Energy unchanged: treat s as converged
93 |                 if e == e_new:
94 |                     return s
95 |                 # Update energy
96 |                 e = e_new
97 |             return s
98 |
99 |
100 |     def energy(self, s):
101 |         return -0.5 * s @ self.W @ s + np.sum(s * self.threshold)
102 |
103 |     def plot_weights(self):
104 |         plt.figure(figsize=(6, 5))
105 |         w_mat = plt.imshow(self.W, cmap=cm.coolwarm)
106 |         plt.colorbar(w_mat)
107 |         plt.title("Network Weights")
108 |         plt.tight_layout()
109 |         plt.savefig("weights.png")
110 |         plt.show()
--------------------------------------------------------------------------------
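`train_weights` subtracts the mean activity ρ (taken over all patterns and neurons) from each pattern, accumulates outer products, zeros the diagonal, and divides by the number of patterns P. If the centered patterns are stacked as the rows of a matrix X, this is W = XᵀX / P with the diagonal set to 0, computed in one matrix product. The sketch below checks that equivalence against a loop written in the style of `network.py`. `hebbian_weights` and `loop_weights` are hypothetical helpers, not part of the repository.

```python
import numpy as np

def hebbian_weights(patterns):
    # patterns: (P, N) array of +/-1 values, one stored pattern per row.
    X = np.asarray(patterns, dtype=float)
    rho = X.mean()  # mean activity over all patterns and neurons
    Xc = X - rho
    W = Xc.T @ Xc / len(X)  # (1/P) * sum of outer products of centered patterns
    np.fill_diagonal(W, 0.0)  # no self-connections
    return W

def loop_weights(patterns):
    # Loop form following train_weights in network.py.
    num_data, n = len(patterns), patterns[0].shape[0]
    rho = np.sum([np.sum(t) for t in patterns]) / (num_data * n)
    W = np.zeros((n, n))
    for t in patterns:
        W += np.outer(t - rho, t - rho)
    W -= np.diag(np.diag(W))
    return W / num_data

rng = np.random.default_rng(2)
patterns = rng.choice([-1, 1], size=(4, 100))
W = hebbian_weights(patterns)
print(np.allclose(W, loop_weights(patterns)), np.allclose(W, W.T))
```

The single product avoids a Python-level loop over patterns. The memory cost is unchanged, though: for the 128×128 images in `train.py`, W still has 16,384² entries.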