├── How to implement a neural network for handwritten digit recognition in Java.pdf
├── README.md
├── imageResource
├── image-20241217173028697.png
├── image-20241217184752242.png
├── image-20241217185102809.png
├── image-20241217185807044.png
├── image-20241217185809888.png
├── image-20241217185844494.png
├── image-20241217190426681.png
└── image-20241217190518353.png
├── resources
├── TrainImg
│ └── txt.txt
├── VerifyImg
│ └── txt.txt
├── database
│ ├── t10k-images.idx3-ubyte
│ ├── t10k-labels.idx1-ubyte
│ ├── train-images-idx3-ubyte
│ └── train-labels-idx1-ubyte
└── model
│ └── Model_Data.txt
├── src
├── Application.java
└── functions
│ ├── DevelopTool.java
│ ├── GUI.java
│ ├── LanguageSetting.java
│ ├── MNISTReader.java
│ └── Processing.java
└── 如何在Java中实现手写数字识别神经网络.pdf
/How to implement a neural network for handwritten digit recognition in Java.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/How to implement a neural network for handwritten digit recognition in Java.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Java Hand Writing Digit Recognition
2 |
3 | > # Java手写数字识别
4 |
5 | Read [How to implement a neural network for handwritten digit recognition in Java.pdf](https://github.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/blob/main/How%20to%20implement%20a%20neural%20network%20for%20handwritten%20digit%20recognition%20in%20Java.pdf) for more information!
6 |
7 | > 阅读 [如何在Java中实现手写数字识别神经网络.pdf](https://github.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/blob/main/如何在Java中实现手写数字识别神经网络.pdf) 文件以获得更多信息!
8 |
9 |
10 |
11 | ## Setup
12 |
13 | > ## 初始化
14 |
15 | Import the project, then run the `Application` class to start the program.
16 |
17 | > 导入后运行`Application` 类以启动程序
18 |
19 |
20 |
21 | **STEP 1:**
22 |
23 | $\color{red}VERY~~IMPORTANT$
24 |
25 | $\color{red}VERY~~IMPORTANT$
26 |
27 | $\color{red}VERY~~IMPORTANT$
28 |
29 | When you start the program for the first time, click `Settings` -> `Refresh Resource Files Again` in the upper-left corner. (This is only needed for the first-time setup.)
30 |
31 | > 初次启动程序请单击左上角 `设置` -> `重新刷新资源文件` (仅仅针对第一次启动)
32 |
33 |
34 |
35 | **STEP 2:**
36 |
37 | Then, click `Read the Previous Model` to load the already trained model. (The program does not load a model automatically at startup; each time it starts, the neural network is initialized with random values.) :)
38 |
39 | > 然后,点击`读取先前模型`来获取已经训练好的模型(每次启动程序,程序都不会读取任何模型,而是把神经网络做一个随机):)
40 |
41 |
42 |
43 | **STEP 3:**
44 |
45 | Draw a digit on the canvas, then click the `Predict This Image` button.
46 |
47 | > 尝试写一个数字,然后点击`预测此图像`
48 |
49 | 
50 |
51 | ## Verify Accuracy Of The Current Model
52 |
53 | > ## 验证当前模型的准确度
54 |
55 |
56 |
57 | Click the `Verify the Model Accuracy with the Validation Set` button.
58 |
59 | > 按下`以验证集验证模型准确度`按钮
60 |
61 | 
62 |
--------------------------------------------------------------------------------
/imageResource/image-20241217173028697.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217173028697.png
--------------------------------------------------------------------------------
/imageResource/image-20241217184752242.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217184752242.png
--------------------------------------------------------------------------------
/imageResource/image-20241217185102809.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217185102809.png
--------------------------------------------------------------------------------
/imageResource/image-20241217185807044.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217185807044.png
--------------------------------------------------------------------------------
/imageResource/image-20241217185809888.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217185809888.png
--------------------------------------------------------------------------------
/imageResource/image-20241217185844494.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217185844494.png
--------------------------------------------------------------------------------
/imageResource/image-20241217190426681.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217190426681.png
--------------------------------------------------------------------------------
/imageResource/image-20241217190518353.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217190518353.png
--------------------------------------------------------------------------------
/resources/TrainImg/txt.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/resources/TrainImg/txt.txt
--------------------------------------------------------------------------------
/resources/VerifyImg/txt.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/resources/VerifyImg/txt.txt
--------------------------------------------------------------------------------
/resources/database/t10k-images.idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/resources/database/t10k-images.idx3-ubyte
--------------------------------------------------------------------------------
/resources/database/t10k-labels.idx1-ubyte:
--------------------------------------------------------------------------------
1 | '
--------------------------------------------------------------------------------
/resources/database/train-images-idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/resources/database/train-images-idx3-ubyte
--------------------------------------------------------------------------------
/resources/database/train-labels-idx1-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/resources/database/train-labels-idx1-ubyte
--------------------------------------------------------------------------------
/src/Application.java:
--------------------------------------------------------------------------------
1 | import functions.GUI;
2 | import java.io.IOException;
3 |
4 | public class Application {
5 | public static void main(String[] args) throws IOException {
6 | new GUI();
7 | }
8 |
9 | }
10 |
--------------------------------------------------------------------------------
/src/functions/DevelopTool.java:
--------------------------------------------------------------------------------
1 | package functions;
2 |
3 | import javax.imageio.ImageIO;
4 | import java.awt.image.BufferedImage;
5 | import java.io.*;
6 | import java.util.ArrayList;
7 | import java.util.stream.Stream;
8 |
9 | public class DevelopTool {
10 | static final String PATH_TRAIN_IMG_FOLDER = "resources/TrainImg";
11 | static final String PATH_VERIFY_IMG_FOLDER = "resources/VerifyImg";
12 | static final String PATH_TRAIN_TXT = "resources/database/MNIST_TRAIN_DATA.txt";
13 | static final String PATH_VERIFY_TXT = "resources/database/MNIST_VERIFY_DATA.txt";
14 | private DevelopTool(){}
15 |
16 | public static void clearCache() {
17 | // 清空图像文件
18 | System.out.println("[Deleting img folder...]");
19 | deleteAllFilesInFolder(PATH_TRAIN_IMG_FOLDER);
20 | deleteAllFilesInFolder(PATH_VERIFY_IMG_FOLDER);
21 |
22 | // 删除txt文件
23 | System.out.println("[Deleting txt file...]");
24 | deleteTxtFile(PATH_TRAIN_TXT);
25 | deleteTxtFile(PATH_VERIFY_TXT);
26 |
27 | }
28 |
29 | // 对外初始化
30 | public static void regenerateResourcesFile() throws IOException {
31 | System.out.println("[Initializing...]");
32 |
33 | // 清空图像文件
34 | System.out.println("[Deleting img folder...]");
35 | deleteAllFilesInFolder(PATH_TRAIN_IMG_FOLDER);
36 | deleteAllFilesInFolder(PATH_VERIFY_IMG_FOLDER);
37 |
38 |
39 | // 删除txt文件
40 | System.out.println("[Deleting txt file...]");
41 | deleteTxtFile(PATH_TRAIN_TXT);
42 | deleteTxtFile(PATH_VERIFY_TXT);
43 |
44 |
45 | // 写入txt
46 | writeTxtFile(PATH_TRAIN_TXT);
47 | writeTxtFile(PATH_VERIFY_TXT);
48 |
49 | // 生成train img文件
50 | genImg(PATH_TRAIN_IMG_FOLDER);
51 | genImg(PATH_VERIFY_IMG_FOLDER);
52 |
53 | System.out.println("[Finish]");
54 |
55 |
56 | }
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 | private static void genImg(String PATH) throws IOException {
66 | BufferedReader reader;
67 | if(PATH.equals(PATH_TRAIN_IMG_FOLDER)) {
68 | reader = new BufferedReader(new FileReader(PATH_TRAIN_TXT));
69 | }
70 | else
71 | {
72 | reader = new BufferedReader(new FileReader(PATH_VERIFY_TXT));
73 | }
74 | String line ;
75 |
76 | int countProcess = 1;
77 |
78 | while((line = reader.readLine()) != null){
79 | if(line.charAt(0) == 'L') {
80 | ++countProcess;
81 | continue;
82 | }
83 |
84 | int length = stringToDoubleArray(line).length;
85 |
86 | double[][] array = new double[length][length];
87 | double[] temp = stringToDoubleArray(line);
88 | for(int i = 0; i < length; ++i){
89 | array[0][i] = temp[i];
90 | }
91 | for(int i = 1; i <= length-1; ++i){
92 | line = reader.readLine();
93 | temp = stringToDoubleArray(line);
94 | for(int j = 0; j < length; ++j){
95 | array[i][j] = temp[j];
96 | }
97 | }
98 |
99 |
100 | if(PATH.equals(PATH_TRAIN_IMG_FOLDER)) {
101 | if(countProcess%1000==0)
102 | {
103 | System.out.println("[Generating image. Process] \t"+String.format("%.2f",((double)countProcess/60000.00*100.00))+"%");
104 | }
105 | }
106 | else
107 | {
108 | if(countProcess%1000==0)
109 | {
110 | System.out.println("[Generating image. Process] \t"+String.format("%.2f",((double)countProcess/10000.00*100.00))+"%");
111 | }
112 | }
113 |
114 |
115 | if(PATH.equals(PATH_TRAIN_IMG_FOLDER))
116 | arrayToImage(array, PATH_TRAIN_IMG_FOLDER, ""+countProcess);
117 | else
118 | arrayToImage(array, PATH_VERIFY_IMG_FOLDER, ""+countProcess);
119 | }
120 | reader.close();
121 | }
122 | private static void writeTxtFile(String path) throws IOException {
123 | ArrayList mnist;
124 | if(path.equals(PATH_TRAIN_TXT)) {
125 | mnist = MNISTReader.getMNISTTrainDataFromFile();
126 | }
127 | else
128 | {
129 | mnist = MNISTReader.getMNISTVerifyDataFromFile();
130 | }
131 | BufferedWriter writer = new BufferedWriter(new FileWriter(path,true));
132 |
133 | System.out.println("[Writing txt file...]");
134 | for(int i = 0; i < mnist.size(); i++){
135 | double[][] arr = mnist.get(i);
136 | for(int j = 0; j < arr.length-1; j++){
137 | for(int k = 0; k < arr[j].length; k++){
138 | writer.write(arr[j][k] + " ");
139 | }
140 | writer.newLine();
141 | }
142 | writer.write("Label ↑:"+mnist.get(i)[mnist.get(i).length-1][0]);
143 | writer.newLine();
144 | }
145 | writer.close();
146 | }
147 | // 给一个double[][] 输出图片
148 | private static void arrayToImage(double[][] array, String path, String name) throws IOException {
149 | int height = array.length;
150 | int width = array[0].length;
151 |
152 | // Create a BufferedImage
153 | BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
154 |
155 | // Populate the image with the array values
156 | for (int y = 0; y < height; y++) {
157 | for (int x = 0; x < width; x++) {
158 | // Clamp and scale the value from [0, 1] to [0, 255]
159 | int value = (int) Math.round(Math.min(1.0, Math.max(0.0, array[y][x])) * 255);
160 | int gray = (value << 16) | (value << 8) | value; // Grayscale value
161 | image.setRGB(x, y, gray);
162 | }
163 | }
164 |
165 | // Create the output file
166 | File outputFile = new File(path, name + ".png");
167 |
168 | // Ensure the directory exists
169 | if (!outputFile.getParentFile().exists()) {
170 | outputFile.getParentFile().mkdirs();
171 | }
172 |
173 | // Write the image to file
174 | ImageIO.write(image, "png", outputFile);
175 | }
176 | // 给一个图片 输出double[][]
177 | // 静态方法,传入图片路径,返回double[][]数组
178 |
179 | public static double[][] convertToGrayscale(String path) throws IOException {
180 | // 读取图片
181 | BufferedImage image = ImageIO.read(new File(path));
182 |
183 | // 获取图片的宽度和高度
184 | int width = image.getWidth();
185 | int height = image.getHeight();
186 |
187 | // 创建二维数组存储灰度值
188 | double[][] grayscale = new double[height][width];
189 |
190 | // 遍历每个像素,获取精确的灰度值
191 | for (int y = 0; y < height; y++) {
192 | for (int x = 0; x < width; x++) {
193 | // 读取像素的ARGB值
194 | int argb = image.getRGB(x, y);
195 |
196 | // 获取灰度值(假设图片为灰度图,RGB通道值相同)
197 | int gray = argb & 0xFF; // 直接取低8位(灰度图中,R=G=B)
198 |
199 | // 将灰度值归一化到 0.0~1.0 范围
200 | grayscale[y][x] = gray / 255.0;
201 | }
202 | }
203 |
204 | return grayscale;
205 | }
206 | // 删除 img文件夹所有子文件
207 | private static void deleteAllFilesInFolder(String path) {
208 |
209 | File folder = new File(path);
210 |
211 | // Check if the path is a valid directory
212 | if (!folder.exists() || !folder.isDirectory()) {
213 | System.err.println("[Warning, Class: DevelopTool] Invalid directory: " + path);
214 | return;
215 | }
216 |
217 | // Get all files in the directory
218 | File[] files = folder.listFiles();
219 | if (files == null || files.length == 0) {
220 | System.err.println("[Warning, Class: DevelopTool] No files to delete in directory: " + path);
221 | return;
222 | }
223 |
224 | System.err.println("[Deleting data. Please wait]");
225 | // Iterate through the files and delete each one
226 | for (File file : files) {
227 | if (file.isFile()) {
228 | if (!file.delete()) {
229 | System.err.println("[Warning, Class:DevelopTool] Failed to delete file: " + file.getName());
230 | return;
231 | }
232 | } else if (file.isDirectory()) {
233 | System.err.println("[Warning, Class:DevelopTool] Skipping subdirectory: " + file.getName());
234 | }
235 | }
236 |
237 | }
238 | // 删除 txt文件
239 | private static void deleteTxtFile(String path) {
240 | File file = new File(path);
241 | if(file.exists()) {
242 | file.delete();
243 | }
244 |
245 | }
246 | // 获取double
247 | private static double[] stringToDoubleArray(String input) {
248 | return Stream.of(input.split(" "))
249 | .mapToDouble(Double::parseDouble)
250 | .toArray();
251 | }
252 |
253 |
254 | // return 28*29 *n Batch
255 | public static ArrayList getTrainBatch() throws IOException {
256 | ArrayList batch = new ArrayList<>();
257 |
258 | BufferedReader reader = new BufferedReader(new FileReader(PATH_TRAIN_TXT));
259 | String line = null;
260 | int batchSize = GUI.BATCH_SET_SIZE;
261 |
262 | for(int i = 0;i getVerifyBatch() throws IOException {
279 | ArrayList batch = new ArrayList<>();
280 |
281 | BufferedReader reader = new BufferedReader(new FileReader(PATH_VERIFY_TXT));
282 | String line = null;
283 | int batchSize = 10000; // 验证集的大小是10000
284 |
285 | System.out.println(batchSize);
286 | for(int i = 0;i> 24) & 0xff;
317 | int red = (argb >> 16) & 0xff;
318 | int green = (argb >> 8) & 0xff;
319 | int blue = argb & 0xff;
320 | // 灰度值的计算(这里简单平均RGB三个分量,对于灰度图本身RGB相等,也可以直接取其一)
321 | int gray = (red + green + blue) / 3;
322 | // 将灰度值归一化到0到1区间
323 | result[y][x] = (double) gray / 255.0;
324 | }
325 | }
326 | return result;
327 | }
328 | }
329 |
--------------------------------------------------------------------------------
/src/functions/GUI.java:
--------------------------------------------------------------------------------
1 | package functions;
2 |
3 | import javax.swing.*;
4 | import java.awt.*;
5 | import java.awt.event.*;
6 | import java.io.IOException;
7 | import java.util.Random;
8 |
9 |
10 | public class GUI extends JFrame {
11 | int CONSTANT_K = 300;
12 | static boolean stopButtonPressed = false;
13 | protected static int BATCH_SET_SIZE = 10000;
/**
 * Creates the main window. All work is queued on the Swing event thread: a
 * modal language-selection dialog is shown first, and only after it has been
 * closed is the main window initialized and made visible.
 */
public GUI() {
    SwingUtilities.invokeLater(() -> {
        // Owner-less modal dialog asking for the UI language.
        JDialog languageDialog = new JDialog((Frame) null, "Language:", true);
        languageDialog.setSize(300, 150);
        languageDialog.setResizable(false);
        languageDialog.setLocationRelativeTo(null);
        // The title-bar close button does nothing; only the two language
        // buttons below dispose the dialog.
        languageDialog.setDefaultCloseOperation(JDialog.DO_NOTHING_ON_CLOSE);

        JPanel content = new JPanel(new FlowLayout(FlowLayout.CENTER, 10, 10));

        JLabel caption = new JLabel("语言设置/ Language Setting");
        caption.setFont(new Font(Font.DIALOG, Font.PLAIN, 16));
        content.add(caption);

        // Chinese: apply the Chinese texts, then close the dialog.
        JButton chineseButton = new JButton("中文");
        chineseButton.setFont(new Font(Font.DIALOG, Font.PLAIN, 14));
        chineseButton.setPreferredSize(new Dimension(80, 30));
        chineseButton.addActionListener(event -> {
            LanguageSetting.runChineseText();
            languageDialog.dispose();
        });
        content.add(chineseButton);

        // English: apply the English texts, then close the dialog.
        JButton englishButton = new JButton("English");
        englishButton.setFont(new Font(Font.DIALOG, Font.PLAIN, 14));
        englishButton.setPreferredSize(new Dimension(80, 30));
        englishButton.addActionListener(event -> {
            LanguageSetting.runEnglishText();
            languageDialog.dispose();
        });
        content.add(englishButton);

        languageDialog.add(content);
        // Blocks here until the modal dialog is disposed by one of the buttons.
        languageDialog.setVisible(true);

        // Reached only after a language has been chosen.
        setSize(1024, 768);

        initApplication();

        setVisible(true);
    });
}
75 |
76 | private void initApplication() {
77 |
78 | // 这里设置基本属性
79 | {
80 |
81 | this.setSize(703, 820);
82 |
83 | // 设置界面标题
84 | this.setTitle(LanguageSetting.c1);
85 |
86 | // 设置软件置顶
87 | this.setAlwaysOnTop(true);
88 |
89 | // 设置界面居中
90 | this.setLocationRelativeTo(null);
91 |
92 | // 游戏关闭那么就结束运行
93 | this.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
94 |
95 | // 取消默认的界面内部的居中方式,这样下面图片才可以设置坐标
96 | this.setLayout(null);
97 |
98 | // 不允许resize
99 | this.setResizable(false);
100 | }
101 |
102 | // 这里设置按钮
103 | {
104 | JMenuBar menuBar = new JMenuBar();
105 | JMenu fileMenu = new JMenu(LanguageSetting.b1);
106 | JMenuItem reset = new JMenuItem(LanguageSetting.b2);
107 | JMenuItem clearCache = new JMenuItem(LanguageSetting.b3);
108 |
109 |
110 | clearCache.addActionListener(new ActionListener() {
111 | @Override
112 | public void actionPerformed(ActionEvent e) {
113 | GUI.this.setVisible(false);
114 | DevelopTool.clearCache();
115 | GUI.this.setVisible(true);
116 | System.out.println("[已成功清除缓存]");
117 | }
118 | });
119 | reset.addActionListener(new ActionListener() {
120 | @Override
121 | public void actionPerformed(ActionEvent e) {
122 | showJDialog1();
123 | }
124 | });
125 |
126 | fileMenu.add(reset);
127 | fileMenu.add(clearCache);
128 | menuBar.add(fileMenu);
129 |
130 | this.getContentPane().add(menuBar);
131 | this.setJMenuBar(menuBar);
132 | }
133 |
134 |
135 |
136 | // 这里设置界面内的内容
137 | {
138 | final int panelX = 20;
139 | final int panelY = 20;
140 |
141 | this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
142 |
143 | // 创建用于放置28*28个小方块的面板
144 | JPanel newJPanel = new JPanel();
145 | newJPanel.setLayout(null); // 设置为null布局,便于手动定位小方块
146 |
147 | int squareSize = 10; // 每个小方块的边长(包含边框)
148 | int spacing = 0; // 小方块之间的间隔,用于显示边框效果
149 |
150 | for (int i = 0; i < 28; i++) {
151 | for (int j = 0; j < 28; j++) {
152 | // 创建每个小方块(以JPanel模拟)
153 | JPanel squarePanel = new JPanel() {
154 | @Override
155 | protected void paintComponent(Graphics g) {
156 | super.paintComponent(g);
157 | g.setColor(new Color(0,0,0));
158 | // 绘制小方块的边框,这里绘制四条边
159 | g.drawRect(0, 0, getWidth(), getHeight());
160 | }
161 | };
162 | squarePanel.setBackground(new Color(0, 0, 0));
163 | squarePanel.setBounds(i * (squareSize + spacing), j * (squareSize + spacing), squareSize, squareSize);
164 | newJPanel.add(squarePanel);
165 | }
166 | }
167 | newJPanel.setBounds(panelX, panelY, 280, 280);
168 |
169 |
170 | // 键盘监听
171 | Random rand = new Random();
172 | newJPanel.addMouseMotionListener(new MouseMotionAdapter() {
173 | @Override
174 | public void mouseDragged(MouseEvent e) {
175 | for (int dx = -5; dx <= 5; dx++) {
176 | for (int dy = -5; dy <= 5; dy++) {
177 | if (e.getX() + dx >= 1 && e.getX() + dx <= 279 || e.getY() + dy >= 1 && e.getY() + dy <= 279)
178 | updatePanel(e.getX(), e.getY(), dx, dy, newJPanel);
179 | }
180 | }
181 | }
182 | });
183 | newJPanel.addMouseListener(new MouseAdapter() {
184 | @Override
185 | public void mousePressed(MouseEvent e) {
186 | for (int dx = -5; dx <= 5; dx++) {
187 | for (int dy = -5; dy <= 5; dy++) {
188 | if (e.getX() + dx >= 1 && e.getX() + dx <= 279 || e.getY() + dy >= 1 && e.getY() + dy <= 279)
189 | updatePanel(e.getX(), e.getY(), dx, dy, newJPanel);
190 | }
191 | }
192 | }
193 | });
194 |
195 |
196 | // 将包含小方块的面板添加到窗口的内容面板
197 | this.getContentPane().add(newJPanel);
198 |
199 |
200 |
201 |
202 | JButton reset = new JButton(LanguageSetting.a1);
203 | reset.addActionListener(new ActionListener() {
204 | @Override
205 | public void actionPerformed(ActionEvent e) {
206 | for(int x=0;x<28;x++)
207 | {
208 | for(int y=0;y<28;y++)
209 | {
210 | newJPanel.getComponentAt(x*10,y*10).setBackground(new Color(0,0,0));
211 | }
212 | }
213 |
214 | System.out.println("已重置画布!");
215 | }
216 | });
217 | reset.setBounds(20 ,310,130,30);
218 | this.getContentPane().add(reset);
219 |
220 | JButton stopButton = new JButton(LanguageSetting.a9);
221 | stopButton.addActionListener(new ActionListener() {
222 | @Override
223 | public void actionPerformed(ActionEvent e) {
224 | stopButtonPressed = true;
225 | }
226 | });
227 | stopButton.setBounds(330,370,150,20);
228 | this.getContentPane().add(stopButton);
229 |
230 | JButton recordButton = new JButton(LanguageSetting.a10);
231 | recordButton.addActionListener(new ActionListener() {
232 | @Override
233 | public void actionPerformed(ActionEvent e) {
234 | try {
235 | Processing.record();
236 | System.out.println("[成功导出模型到桌面!]");
237 | } catch (IOException ex) {
238 | throw new RuntimeException(ex);
239 | }
240 | }
241 | });
242 | recordButton.setBounds(330,410,350,20);
243 | this.getContentPane().add(recordButton);
244 |
245 | JButton loadButton = new JButton(LanguageSetting.a11);
246 | loadButton.addActionListener(new ActionListener() {
247 | @Override
248 | public void actionPerformed(ActionEvent e) {
249 | GUI.this.setVisible(false);
250 | try {
251 | Processing.readModel();
252 | } catch (IOException ex) {
253 | throw new RuntimeException(ex);
254 | }
255 | System.out.println("[已成功导入模型!]");
256 | GUI.this.setVisible(true);
257 | }
258 | });
259 | loadButton.setBounds(330,440,350,20);
260 | this.getContentPane().add(loadButton);
261 |
262 | JButton setAsTrainImg = new JButton(LanguageSetting.a5);
263 | setAsTrainImg.setBounds(330,120,150,20);
264 | setAsTrainImg.addActionListener(new ActionListener() {
265 | @Override
266 | public void actionPerformed(ActionEvent e) {
267 | JTextField jt = (JTextField) GUI.this.getContentPane().getComponentAt(330,200);
268 | GUI.BATCH_SET_SIZE = Integer.parseInt(jt.getText());
269 | Thread thread = new Thread(new Runnable() {
270 | @Override
271 | public void run() {
272 | try {
273 | Processing.run();
274 | } catch (IOException ex) {
275 | throw new RuntimeException(ex);
276 | }
277 | }
278 | });
279 | thread.start();
280 | }
281 | });
282 | this.getContentPane().add(setAsTrainImg);
283 |
284 | JLabel text1 = new JLabel(LanguageSetting.a6);
285 | text1.setBounds(330,140,170,20);
286 | this.getContentPane().add(text1);
287 |
288 | JTextField input1 = new JTextField();
289 | input1.setText("10000");
290 | input1.setBounds(330, 170, 75, 40);
291 |
292 | input1.setFont(new Font("custom", Font.BOLD,25));
293 | this.getContentPane().add(input1);
294 |
295 | JButton getImgFromFolder = new JButton(LanguageSetting.a7);
296 | getImgFromFolder.setBounds(330, 220, 350, 20);
297 | getImgFromFolder.addActionListener(new ActionListener() {
298 | @Override
299 | public void actionPerformed(ActionEvent e) {
300 | try {
301 | getRandomImgFromFile();
302 | Processing.getDigit(GUI.this);
303 | } catch (IOException ex) {
304 | throw new RuntimeException(ex);
305 | }
306 | }
307 | });
308 |
309 |
310 | this.getContentPane().add(getImgFromFolder);
311 |
312 | // 训练该样本的准确性
313 | JButton verifyThisModel = new JButton(LanguageSetting.a8);
314 | verifyThisModel.setBounds(330, 260, 350, 20);
315 | verifyThisModel.addActionListener(new ActionListener() {
316 | @Override
317 | public void actionPerformed(ActionEvent e) {
318 | // 创建一个新线程来执行长时间运行的任务
319 | Thread thread = new Thread(new Runnable() {
320 | @Override
321 | public void run() {
322 | try {
323 | Processing.VerifyAccuracy(GUI.this);
324 | } catch (IOException ex) {
325 | throw new RuntimeException(ex);
326 | }
327 | }
328 | });
329 | thread.start();
330 | }
331 | });
332 | this.getContentPane().add(verifyThisModel);
333 |
334 |
335 | JLabel correctNumber = new JLabel("0");
336 | correctNumber.setFont(new Font("custom", Font.BOLD,18));
337 | correctNumber.setBounds(430,280,150,30);
338 | this.getContentPane().add(correctNumber);
339 |
340 | JLabel per10000 = new JLabel("/ 10000");
341 | per10000.setFont(new Font("custom", Font.BOLD,18));
342 | per10000.setBounds(480,280,150,30);
343 | this.getContentPane().add(per10000);
344 |
345 | JLabel countText = new JLabel("Count:");
346 | countText.setFont(new Font("custom", Font.BOLD,18));
347 | countText.setBounds(330,280,150,30);
348 | this.getContentPane().add(countText);
349 |
350 |
351 | JLabel accuracyNumber = new JLabel("00.00%");
352 | accuracyNumber.setFont(new Font("custom", Font.BOLD,18));
353 | accuracyNumber.setBounds(430,320,150,30);
354 | this.getContentPane().add(accuracyNumber);
355 |
356 | JLabel accracyText = new JLabel("Accuracy:");
357 | accracyText.setFont(new Font("custom", Font.BOLD,18));
358 | accracyText.setBounds(330,300,150,30);
359 | this.getContentPane().add(accracyText);
360 |
361 |
362 |
363 |
364 |
365 |
366 |
367 | JButton getPredict = new JButton(LanguageSetting.a3);
368 | getPredict.setBounds(20 ,370,280,20);
369 | getPredict.addActionListener(new ActionListener() {
370 | @Override
371 | public void actionPerformed(ActionEvent e) {
372 | System.out.println("已经更新获取的数据");
373 | Processing.getDigit(GUI.this);
374 | }
375 | });
376 | this.getContentPane().add(getPredict);
377 |
378 |
379 | // 设置随机图片按钮
380 | JButton randomButton = new JButton(LanguageSetting.a2);
381 | randomButton.addActionListener(new ActionListener() {
382 | @Override
383 | public void actionPerformed(ActionEvent e) {
384 | for(int x=0;x<28;x++)
385 | {
386 | for(int y=0;y<28;y++)
387 | {
388 | int randColor = (int)(Math.random()*255.00);
389 |
390 | newJPanel.getComponentAt(x*10,y*10).setBackground(new Color(randColor,randColor,randColor));
391 | }
392 | }
393 | }
394 | });
395 | randomButton.setBounds(20 +280 - 130 ,310,130,30);
396 | this.getContentPane().add(randomButton);
397 |
398 | // 加入0~9预测
399 | {
400 | JPanel predictPanel = new JPanel();
401 | predictPanel.setLayout(null);
402 | Font font = new Font("custom", Font.BOLD,20);
403 | JLabel digit0 = new JLabel("[0]");
404 | JLabel digit1 = new JLabel("[1]");
405 | JLabel digit2 = new JLabel("[2]");
406 | JLabel digit3 = new JLabel("[3]");
407 | JLabel digit4 = new JLabel("[4]");
408 | JLabel digit5 = new JLabel("[5]");
409 | JLabel digit6 = new JLabel("[6]");
410 | JLabel digit7 = new JLabel("[7]");
411 | JLabel digit8 = new JLabel("[8]");
412 | JLabel digit9 = new JLabel("[9]");
413 | JLabel preDigit0 = new JLabel("0.00%");
414 | JLabel preDigit1 = new JLabel("0.00%");
415 | JLabel preDigit2 = new JLabel("0.00%");
416 | JLabel preDigit3 = new JLabel("0.00%");
417 | JLabel preDigit4 = new JLabel("0.00%");
418 | JLabel preDigit5 = new JLabel("0.00%");
419 | JLabel preDigit6 = new JLabel("0.00%");
420 | JLabel preDigit7 = new JLabel("0.00%");
421 | JLabel preDigit8 = new JLabel("0.00%");
422 | JLabel preDigit9 = new JLabel("0.00%");
423 | digit0.setFont(font);
424 | digit1.setFont(font);
425 | digit2.setFont(font);
426 | digit3.setFont(font);
427 | digit4.setFont(font);
428 | digit5.setFont(font);
429 | digit6.setFont(font);
430 | digit7.setFont(font);
431 | digit8.setFont(font);
432 | digit9.setFont(font);
433 |
434 | preDigit0.setFont(font);
435 | preDigit1.setFont(font);
436 | preDigit2.setFont(font);
437 | preDigit3.setFont(font);
438 | preDigit4.setFont(font);
439 | preDigit5.setFont(font);
440 | preDigit6.setFont(font);
441 | preDigit7.setFont(font);
442 | preDigit8.setFont(font);
443 | preDigit9.setFont(font);
444 |
445 |
446 | int YBIAS = 30;
447 | digit0.setBounds(0,0*YBIAS,35,24);
448 | digit1.setBounds(0,1*YBIAS,35,24);
449 | digit2.setBounds(0,2*YBIAS,35,24);
450 | digit3.setBounds(0,3*YBIAS,35,24);
451 | digit4.setBounds(0,4*YBIAS,35,24);
452 | digit5.setBounds(0,5*YBIAS,35,24);
453 | digit6.setBounds(0,6*YBIAS,35,24);
454 | digit7.setBounds(0,7*YBIAS,35,24);
455 | digit8.setBounds(0,8*YBIAS,35,24);
456 | digit9.setBounds(0,9*YBIAS,35,24);
457 | preDigit0.setBounds(35,0*YBIAS,140,24);
458 | preDigit1.setBounds(35,1*YBIAS,140,24);
459 | preDigit2.setBounds(35,2*YBIAS,140,24);
460 | preDigit3.setBounds(35,3*YBIAS,140,24);
461 | preDigit4.setBounds(35,4*YBIAS,140,24);
462 | preDigit5.setBounds(35,5*YBIAS,140,24);
463 | preDigit6.setBounds(35,6*YBIAS,140,24);
464 | preDigit7.setBounds(35,7*YBIAS,140,24);
465 | preDigit8.setBounds(35,8*YBIAS,140,24);
466 | preDigit9.setBounds(35,9*YBIAS,140,24);
467 |
468 | predictPanel.add(digit0);
469 | predictPanel.add(digit1);
470 | predictPanel.add(digit2);
471 | predictPanel.add(digit3);
472 | predictPanel.add(digit4);
473 | predictPanel.add(digit5);
474 | predictPanel.add(digit6);
475 | predictPanel.add(digit7);
476 | predictPanel.add(digit8);
477 | predictPanel.add(digit9);
478 |
479 | predictPanel.add(preDigit0);
480 | predictPanel.add(preDigit1);
481 | predictPanel.add(preDigit2);
482 | predictPanel.add(preDigit3);
483 | predictPanel.add(preDigit4);
484 | predictPanel.add(preDigit5);
485 | predictPanel.add(preDigit6);
486 | predictPanel.add(preDigit7);
487 | predictPanel.add(preDigit8);
488 | predictPanel.add(preDigit9);
489 |
490 |
491 | predictPanel.setBounds(20,400,300,480);
492 | this.getContentPane().add(predictPanel);
493 | }
494 |
495 | // 加入参数调整
496 | {
497 | JLabel lambdaText = new JLabel (LanguageSetting.a12);
498 | JLabel alphaText = new JLabel (LanguageSetting.a14);
499 | JLabel stopVarText = new JLabel (LanguageSetting.a15);
500 | JLabel miniBatchSize = new JLabel (LanguageSetting.a16);
501 |
502 | lambdaText.setFont(new Font("customize",Font.BOLD,17));
503 | alphaText.setFont(new Font("customize",Font.BOLD,17));
504 | stopVarText.setFont(new Font("customize",Font.BOLD,17));
505 | miniBatchSize.setFont(new Font("customize",Font.BOLD,17));
506 |
507 | lambdaText.setBounds(330,470,350,20);
508 | alphaText.setBounds(330,530,350,20);
509 | stopVarText.setBounds(330,590,350,20);
510 | miniBatchSize.setBounds(330,650,350,20);
511 |
512 | this.getContentPane().add(lambdaText);
513 | this.getContentPane().add(alphaText);
514 | this.getContentPane().add(stopVarText);
515 | this.getContentPane().add(miniBatchSize);
516 |
517 | JTextField lambda = new JTextField();
518 | lambda.setText(String.valueOf(Processing.LAMBDA_L1));
519 | int constant1 = 30;
520 | lambda.setFont(new Font("customize",Font.BOLD,17));
521 | lambda.setBounds(330,470+constant1,100,25);
522 | this.getContentPane().add(lambda);
523 |
524 | JTextField alpha = new JTextField();
525 | alpha.setText(String.valueOf(Processing.ALPHA));
526 | alpha.setFont(new Font("customize",Font.BOLD,17));
527 | alpha.setBounds(330,530+constant1,100,25);
528 | this.getContentPane().add(alpha);
529 |
530 | JTextField stopEpoch = new JTextField();
531 | stopEpoch.setText(String.valueOf(Processing.EPOCH_TOTAL));
532 | stopEpoch.setFont(new Font("customize",Font.BOLD,17));
533 | stopEpoch.setBounds(330,590+constant1,100,25);
534 | this.getContentPane().add(stopEpoch);
535 |
536 | JTextField miniBatch = new JTextField();
537 | miniBatch.setText(String.valueOf(Processing.DEFINE_BATCH_SIZE));
538 | miniBatch.setFont(new Font("customize",Font.BOLD,17));
539 | miniBatch.setBounds(330,650+constant1,100,25);
540 | this.getContentPane().add(miniBatch);
541 |
542 |
543 | JButton update1 = new JButton(LanguageSetting.a13);
544 | update1.addActionListener(new ActionListener() {
545 | @Override
546 | public void actionPerformed(ActionEvent e) {
547 | System.out.println("已更改正则化L1、L2 Lambda为"+lambda.getText());
548 | lambda.setForeground(Color.black);
549 | Processing.LAMBDA_L1 = Double.parseDouble(lambda.getText());
550 | Processing.LAMBDA_L2 = Double.parseDouble(lambda.getText());
551 | }
552 | });
553 | update1.setFont(new Font("customize",Font.BOLD,17));
554 | update1.setBounds(330+115,470+constant1,100,25);
555 | this.getContentPane().add(update1);
556 |
557 | JButton update2 = new JButton(LanguageSetting.a13);
558 | update2.addActionListener(new ActionListener() {
559 | @Override
560 | public void actionPerformed(ActionEvent e) {
561 | alpha.setForeground(Color.black);
562 | Processing.ALPHA = Double.parseDouble(alpha.getText());
563 | System.out.println("已更改学习率Alpha为"+Processing.ALPHA);
564 | }
565 | });
566 | update2.setFont(new Font("customize",Font.BOLD,17));
567 | update2.setBounds(330+115,530+constant1,100,25);
568 | this.getContentPane().add(update2);
569 |
570 | JButton update3 = new JButton(LanguageSetting.a13);
571 | update3.addActionListener(new ActionListener() {
572 | @Override
573 | public void actionPerformed(ActionEvent e) {
574 | System.out.println("已更改Epoch轮数为"+stopEpoch.getText());
575 | stopEpoch.setForeground(Color.black);
576 | Processing.EPOCH_TOTAL = (int) Double.parseDouble(stopEpoch.getText());
577 | }
578 | });
579 | update3.setFont(new Font("customize",Font.BOLD,17));
580 | update3.setBounds(330+115,590+constant1,100,25);
581 | this.getContentPane().add(update3);
582 |
583 | JButton update4 = new JButton(LanguageSetting.a13);
584 | update4.addActionListener(new ActionListener() {
585 | @Override
586 | public void actionPerformed(ActionEvent e) {
587 | System.out.println("已更改Batch值为"+miniBatch.getText());
588 | miniBatch.setForeground(Color.black);
589 | Processing.EPOCH_TOTAL = (int) Double.parseDouble(miniBatch.getText());
590 | }
591 | });
592 | update4.setFont(new Font("customize",Font.BOLD,17));
593 | update4.setBounds(330+115,650+constant1,100,25);
594 | this.getContentPane().add(update4);
595 |
596 |
597 |
598 | }
599 |
600 |
601 | this.setVisible(true);
602 | }
603 |
604 | // 设置浓度滑条
605 | {
606 | // 创建滑条(范围:0-100,初始值:0)
607 | JSlider slider = new JSlider(0, 100, 0);
608 | slider.setMajorTickSpacing(10); // 每10个单位显示一个大刻度
609 | slider.setMinorTickSpacing(1); // 每1个单位显示一个小刻度
610 | slider.setPaintTicks(true); // 显示刻度线
611 | slider.setPaintLabels(true); // 显示刻度标签
612 | slider.setValue(91);
613 |
614 | slider.addMouseListener(new MouseAdapter() {
615 | @Override
616 | public void mouseReleased(MouseEvent e) {
617 | CONSTANT_K = (101-slider.getValue())*35;
618 |
619 | }
620 | });
621 |
622 | // 创建按钮
623 | JButton defButton = new JButton(LanguageSetting.a4);
624 |
625 | // 添加按钮点击事件
626 |
627 | defButton.addActionListener(new ActionListener() {
628 |
629 | @Override
630 | public void actionPerformed(ActionEvent e) {
631 | System.out.println("已还原为默认值!");
632 | CONSTANT_K = 300;
633 | }
634 | });
635 | // 创建容器JPanel并设置布局
636 | JPanel panel = new JPanel();
637 | panel.setLayout(new FlowLayout());
638 |
639 | panel.setBounds(330,10,200,200);
640 | // 将滑条和按钮添加到JPanel
641 | panel.add(slider);
642 | panel.add(defButton);
643 |
644 | // 将panel添加到JFrame(假设你已经创建了JFrame)
645 | add(panel);
646 |
647 | }
648 |
649 |
650 | }
651 |
652 | private void getRandomImgFromFile() throws IOException {
653 | JPanel newJPanel = (JPanel) this.getContentPane().getComponentAt(20,20);
654 | Random rand = new Random();
655 | double[][] array = DevelopTool.imgToArray("resources/TrainImg/"+rand.nextInt(GUI.BATCH_SET_SIZE)+".png");
656 | for(int x=0;x<28;x++)
657 | {
658 | for (int y=0;y<28;y++)
659 | {
660 | int RGB = (int)(array[x][y]*255.00);
661 | newJPanel.getComponentAt(y*10,x*10).setBackground(new Color(RGB,RGB,RGB));
662 | }
663 | }
664 | }
665 |
666 |
667 | private void updatePanel(int x, int y, int dx, int dy, JPanel newJPanel) {
668 | // 找不到零件就返回
669 | if (newJPanel.getComponentAt(x + dx, y + dy) == null) return;
670 | int RGB = (int) ((dx * dx + dy * dy / 50.0) / (double) CONSTANT_K * 255.0); // 非常小的一个数字,例如0.21
671 | int originalRGB = newJPanel.getComponentAt(x + dx, y + dy).getBackground().getRed(); //0, 或者150这种
672 | if (originalRGB + RGB > 255) {
673 | newJPanel.getComponentAt(x + dx, y + dy).setBackground(new Color(255,255,255));
674 | return;
675 | }
676 |
677 | newJPanel.getComponentAt(x + dx, y + dy).setBackground(new Color(originalRGB + RGB, originalRGB + RGB, originalRGB + RGB));
678 | }
679 |
680 | private void showJDialog1() {
681 | // 创建对话框
682 | JDialog dialog = new JDialog(this, "警告", true);
683 |
684 | // 创建面板用于放置组件,这里设置为null布局,以便进行绝对定位
685 | JPanel panel = new JPanel(null);
686 |
687 | JLabel text = new JLabel("你确定要这么做吗?可能会等待较长时间");
688 |
689 | JButton yesButton = new JButton("Yes");
690 | yesButton.addActionListener(new ActionListener() {
691 | @Override
692 | public void actionPerformed(ActionEvent e) {
693 | dialog.setVisible(false);
694 | dialog.dispose();
695 | GUI.this.setVisible(false);
696 | try {
697 | DevelopTool.regenerateResourcesFile();
698 | } catch (IOException ex) {
699 | throw new RuntimeException(ex);
700 | }
701 | GUI.this.setVisible(true);
702 | }
703 | });
704 |
705 | JButton noButton = new JButton("No");
706 | noButton.addActionListener(new ActionListener() {
707 | @Override
708 | public void actionPerformed(ActionEvent e) {
709 | dialog.dispose();
710 | }
711 | });
712 |
713 | // 设置按钮的大小和位置
714 | yesButton.setBounds(10, 30, 80, 30);
715 | noButton.setBounds(150, 30, 80, 30);
716 |
717 | // 设置文本标签的位置和大小,这里简单设置了下,可根据实际调整
718 | text.setBounds(10, 10, 300, 20);
719 |
720 | // 将组件添加到面板中
721 | panel.add(text);
722 | panel.add(yesButton);
723 | panel.add(noButton);
724 |
725 | // 将面板添加到对话框中
726 | dialog.add(panel);
727 |
728 | // 设置对话框大小
729 | dialog.setSize(275, 110);
730 |
731 | // 设置对话框相对于当前窗口的位置
732 | dialog.setLocationRelativeTo(this);
733 |
734 | // 设置对话框始终在最上层显示
735 | dialog.setAlwaysOnTop(true);
736 | // 不可resize
737 | dialog.setResizable(false);
738 |
739 | // 让对话框可见
740 | dialog.setVisible(true);
741 | }
742 | }
743 |
--------------------------------------------------------------------------------
/src/functions/LanguageSetting.java:
--------------------------------------------------------------------------------
1 | package functions;
2 |
/**
 * Holder for the user-interface strings of the application.
 *
 * <p>All fields are {@code null} until {@link #runChineseText()} or
 * {@link #runEnglishText()} has been called; each call overwrites every field.
 */
public class LanguageSetting {
    private LanguageSetting() {}

    // Button and label captions (a*), settings menu captions (b*), window title (c1).
    public static String a1, a2, a3, a4, a5, a6, a7, a8;
    public static String a9, a10, a11, a12, a13, a14, a15, a16;
    public static String b1, b2, b3;
    public static String c1;

    /** Switches every UI string to its Chinese version. */
    public static void runChineseText() {
        assign(
                "重置画布",
                "随机画布",
                "预测此图像",
                "还原为默认值",
                "开始训练",
                "设置训练样本量",
                "从训练集中获取图像",
                "以验证集验证模型准确度",
                "停止训练",
                "导出当前模型到桌面",
                "读取先前模型",
                "正则化L1 & L2 值 ->",
                "更改",
                "学习率 Alpha 值->",
                "设置总训练轮数 ->",
                "设置 Batch 值 ->",
                "设置",
                "重新刷新资源文件",
                "清除所有图片和缓存数据库(.txt数据)",
                "Java手写数字识别");
    }

    /** Switches every UI string to its English version. */
    public static void runEnglishText() {
        assign(
                "Reset Canvas",
                "Random Canvas",
                "Predict This Image",
                "Restore to Default Values",
                "Start Training",
                "Set the Training Sample Size",
                "Get Images from the Training Set",
                "Verify the Model Accuracy with the Validation Set",
                "Stop Training",
                "Export the Current Model to the Desktop",
                "Read the Previous Model",
                "Regularization L1 & L2 Values ->",
                "Change",
                "Learning Rate Alpha Value ->",
                "Set the Total Training Times ->",
                "Set the Batch Value ->",
                "Settings",
                "Refresh Resource Files",
                "Clear All Images and Cache Databases(.txt files)",
                "Java Hand Writing Digits Recognition");
    }

    /** Stores the 20 strings, given in the order a1..a16, b1..b3, c1, into the fields. */
    private static void assign(String... texts) {
        a1 = texts[0];
        a2 = texts[1];
        a3 = texts[2];
        a4 = texts[3];
        a5 = texts[4];
        a6 = texts[5];
        a7 = texts[6];
        a8 = texts[7];
        a9 = texts[8];
        a10 = texts[9];
        a11 = texts[10];
        a12 = texts[11];
        a13 = texts[12];
        a14 = texts[13];
        a15 = texts[14];
        a16 = texts[15];
        b1 = texts[16];
        b2 = texts[17];
        b3 = texts[18];
        c1 = texts[19];
    }
}
71 |
--------------------------------------------------------------------------------
/src/functions/MNISTReader.java:
--------------------------------------------------------------------------------
1 | package functions;
2 |
3 | import java.io.*;
4 | import java.util.ArrayList;
5 |
/**
 * This class implements a reader for the MNIST dataset of handwritten digits. The dataset is found
 * at http://yann.lecun.com/exdb/mnist/.
 *
 * <p>Every entry of a returned list is one sample stored as a jagged array: the first
 * {@code numCols} rows hold the pixel values divided by 256.0 (in file order), and the last
 * row is a one-element array holding the label.
 *
 * Author: Gabe Johnson
 *
 * Modified by: James Thornton & Zhibin Jiang
 */
public class MNISTReader {

    /** Magic number expected at the start of a label file. */
    private static final int LABEL_MAGIC = 2049;
    /** Magic number expected at the start of an image file. */
    private static final int IMAGE_MAGIC = 2051;

    /**
     * Intentionally empty entry point, kept so that existing launch configurations still work.
     *
     * @param args unused
     * @throws IOException never thrown; kept for signature compatibility
     */
    public static void main(String[] args) throws IOException {

    }

    /**
     * Reads the training labels and images from {@code resources/database}.
     *
     * @return one jagged array per sample (pixels plus a trailing one-element label row)
     * @throws IOException if a file cannot be read, has a wrong magic number, or the two
     *                     files disagree on the number of entries
     */
    public static ArrayList<double[][]> getMNISTTrainDataFromFile() throws IOException {
        return readDataSet(
                "resources/database/train-labels-idx1-ubyte",
                "resources/database/train-images-idx3-ubyte",
                "[Reading from training database. Process] \t");
    }

    /**
     * Reads the verification (t10k) labels and images from {@code resources/database}.
     *
     * @return one jagged array per sample (pixels plus a trailing one-element label row)
     * @throws IOException if a file cannot be read, has a wrong magic number, or the two
     *                     files disagree on the number of entries
     */
    public static ArrayList<double[][]> getMNISTVerifyDataFromFile() throws IOException {
        return readDataSet(
                "resources/database/t10k-labels.idx1-ubyte",
                "resources/database/t10k-images.idx3-ubyte",
                "[Reading from verify database. Process] \t");
    }

    /**
     * Shared implementation for both data sets.
     *
     * @param labelPath      path of the label file
     * @param imagePath      path of the image file
     * @param progressPrefix text printed in front of the percentage every 1000 samples
     */
    private static ArrayList<double[][]> readDataSet(String labelPath, String imagePath, String progressPrefix)
            throws IOException {
        System.out.println("Retrieving MNIST data...");

        // try-with-resources closes both files even when a header check or a read fails;
        // buffering avoids one underlying file read per pixel byte.
        try (DataInputStream labels = new DataInputStream(new BufferedInputStream(new FileInputStream(labelPath)));
             DataInputStream images = new DataInputStream(new BufferedInputStream(new FileInputStream(imagePath)))) {

            int magicNumber = labels.readInt();
            if (magicNumber != LABEL_MAGIC) {
                // Reported as an exception rather than System.exit(0), which terminated the
                // whole JVM with a success status from inside a reader method.
                throw new IOException("Label file has wrong magic number: " + magicNumber + " (should be 2049)");
            }
            magicNumber = images.readInt();
            if (magicNumber != IMAGE_MAGIC) {
                throw new IOException("Image file has wrong magic number: " + magicNumber + " (should be 2051)");
            }

            int numLabels = labels.readInt();
            int numImages = images.readInt();
            int numRows = images.readInt();
            int numCols = images.readInt();
            if (numLabels != numImages) {
                throw new IOException("Image file and label file do not contain the same number of entries."
                        + " Label file contains: " + numLabels
                        + " Image file contains: " + numImages);
            }

            ArrayList<double[][]> result = new ArrayList<>();
            int numLabelsRead = 0;
            while (labels.available() > 0 && numLabelsRead < numLabels) {
                int current = numLabelsRead + 1;
                if (current % 1000 == 0) {
                    // Percentage is relative to the entry count announced in the file header.
                    System.out.println(progressPrefix
                            + String.format("%.2f", ((double) current / (double) numLabels * 100.00)) + "%");
                }

                byte label = labels.readByte();
                numLabelsRead++;

                // One extra row at the end carries the label.
                double[][] image = new double[numCols + 1][numRows];
                for (int colIdx = 0; colIdx < numCols; colIdx++) {
                    for (int rowIdx = 0; rowIdx < numRows; rowIdx++) {
                        image[colIdx][rowIdx] = images.readUnsignedByte() / 256.0;
                    }
                }
                image[numCols] = new double[] {label};

                result.add(image);
            }
            return result;
        }
    }

}
--------------------------------------------------------------------------------
/src/functions/Processing.java:
--------------------------------------------------------------------------------
1 | package functions;
2 |
3 |
4 | import javax.swing.*;
5 | import java.awt.*;
6 | import java.io.*;
7 | import java.util.ArrayList;
8 | import java.util.Random;
9 |
10 | // 历史最小:0.030220112
11 | // Updated: 2025/1/10
12 | public class Processing {
13 |
14 | /**
15 | * 默认变量
16 | */
17 | // 输出模型文件夹
18 | static final String EXPORT_PART = System.getProperty("user.home") + "\\Desktop\\" + "Model_Data.txt";
19 | // 读取模型文件
20 | static final String IMPORT_PATH = "resources/model/Model_Data.txt";
21 | /**
22 | * 三层神经网络,结构组成如下:
23 | * 输入层:X,用于接收外部输入数据。
24 | * 隐藏层:A 层与 B 层,负责对输入数据进行处理和特征提取。
25 | * 输出层:Y,输出最终的网络计算结果。
26 | * 网络包含四层结构,配置有三权重、三偏置用于调整计算逻辑。
27 | */
28 | static double[][] weightXA = new double[784][20];
29 | static double[][] weightAB = new double[20][20];
30 | static double[][] weightBY = new double[20][10];
31 | static double[] biasA = new double[20];
32 | static double[] biasB = new double[20];
33 | static double[] biasY = new double[10];
34 | /**
35 | * 这里是Adam优化器的数据
36 | */
37 | static double BETA1 = 0.9;
38 | static double BETA2 = 0.999;
39 | static double ALPHA = 0.001;
40 | static double EPSILON = 1.0E-08;
41 | static double LAMBDA_L1 = 0.001;
42 | static double LAMBDA_L2 = 0.001;
43 |
44 | static double[][] weightXA_Mt = new double[784][20];
45 | static double[][] weightXA_Vt = new double[784][20];
46 | static int weightXA_Iteration = 0;
47 | static double[][] weightAB_Mt = new double[20][20];
48 | static double[][] weightAB_Vt = new double[20][20];
49 | static int weightAB_Iteration = 0;
50 | static double[][] weightBY_Mt = new double[20][10];
51 | static double[][] weightBY_Vt = new double[20][10];
52 | static int weightBY_Iteration = 0;
53 | static double[] biasA_Mt = new double[20];
54 | static double[] biasA_Vt = new double[20];
55 | static int biasA_Iteration = 0;
56 | static double[] biasB_Mt = new double[20];
57 | static double[] biasB_Vt = new double[20];
58 | static int biasB_Iteration = 0;
59 | static double[] biasY_Mt = new double[20];
60 | static double[] biasY_Vt = new double[20];
61 | static int biasY_Iteration = 0;
62 | // 总训练轮数
63 | static int EPOCH_TOTAL = 5000;
64 | // 每个Batch大小
65 | static int DEFINE_BATCH_SIZE = 64;
66 |
67 | /**
68 | * 其余储存变量
69 | */
70 | // 储存最近10次的损失函数值
71 | static ArrayList lossValueHistory = new ArrayList<>();
72 |
73 | static {
74 | initialize(); // 初始化
75 | }
76 |
77 |
78 | /**
79 | * run()用于开始训练模型”
80 | */
81 | public static void run() throws IOException {
82 | System.out.println("开始训练...");
83 | System.out.println("总计样本量" + GUI.BATCH_SET_SIZE);
84 | System.out.println("Batch = " + DEFINE_BATCH_SIZE);
85 | GUI.stopButtonPressed = false;
86 | Long start = System.currentTimeMillis();
87 | Random rand = new Random(0);
88 |
89 | ArrayList allBatch = DevelopTool.getTrainBatch();
90 | allBatch = normalization(allBatch); // 预处理
91 |
92 | int EPOCH_CURRENT = 0;
93 | if (allBatch.size() > 10000 && EPOCH_TOTAL != EPOCH_CURRENT) {
94 | System.err.println(String.format("[警告] 数据集过大[%d]", allBatch.size()) + ",每轮Epoch可能时间较长,请耐心等待");
95 | }
96 | double lossValue = getLossValueForBatchSet(allBatch); // 当前损失函数。在while执行完更新
97 |
98 | // 停止按钮和Epoch到达训练两个判断条件
99 | while (!GUI.stopButtonPressed && EPOCH_TOTAL > EPOCH_CURRENT) {
100 |
101 | System.out.println(String.format("损失函数[%.16f]", lossValue) + String.format("\t最近10次损失方差[%.4f] ×10^-4", lossFunctionVar() * 10000.00) + "\tEpoch = " + (++EPOCH_CURRENT) + String.format("进度[%.2f%%]", (double) EPOCH_CURRENT / (double) EPOCH_TOTAL * 100.00));
102 | ArrayList batch = new ArrayList<>();
103 | // 这个for运行完之后,就代表一个EPOCH结束
104 | for (int i = 0; i < GUI.BATCH_SET_SIZE / DEFINE_BATCH_SIZE; i++) {
105 | for (int j = 0; j < DEFINE_BATCH_SIZE; j++) {
106 | // 投入图像
107 | int randomIndex = rand.nextInt(GUI.BATCH_SET_SIZE);
108 | batch.add(allBatch.get(randomIndex));
109 | }
110 | trainBatch(batch);
111 | batch.clear();
112 | //new Scanner(System.in).nextLine();
113 | }
114 | // 更新lossValueHistory
115 | lossValue = getLossValueForBatchSet(allBatch);
116 | if (lossValueHistory.size() > 10) lossValueHistory.removeFirst();
117 | lossValueHistory.add(lossValue);
118 | }
119 | record();
120 | Long end = System.currentTimeMillis();
121 | System.out.println("用时:" + String.format("%.3f", (double) (end - start) / 1000.00) + "s");
122 | System.out.println("[训练完成!]");
123 | GUI.stopButtonPressed = false;
124 | }
125 |
126 | /**
127 | * 用于对一个batch进行梯度下降
128 | */
129 | private static void trainBatch(ArrayList batch) {
130 | // 这里先计算所有batch的网络
131 | // 在同一个batch中,所用到的权重、偏置、神经元数据都必须是上一batch的。
132 | // 因此在这里我们复制了全局中的权重、偏置、神经元数据,然后对于更新完的权重,我们存在这里
133 | // 但是在权重更新过程中我们不适用这里的weight,而是weight
134 | // 在整个batch更新完之后,我们才给weight 赋值为 我们更新的 weight
135 | double[][] XList = new double[batch.size()][784];
136 | double[][] AList = new double[batch.size()][20];
137 | double[][] BList = new double[batch.size()][20];
138 | double[][] YList = new double[batch.size()][10];
139 |
140 | double[][] gradient_weightXA = new double[784][20];
141 | double[][] gradient_weightAB = new double[20][20];
142 | double[][] gradient_weightBY = new double[20][10];
143 | double[] gradient_biasA = new double[20];
144 | double[] gradient_biasB = new double[20];
145 | double[] gradient_biasY = new double[10];
146 |
147 | ++weightXA_Iteration;
148 | ++weightAB_Iteration;
149 | ++weightBY_Iteration;
150 | ++biasA_Iteration;
151 | ++biasB_Iteration;
152 | ++biasY_Iteration;
153 |
154 | for (int numberOfBatch = 0; numberOfBatch < batch.size(); numberOfBatch++) {
155 | double[] X = new double[784];
156 | double[] A = new double[20];
157 | double[] B = new double[20];
158 | double[] Y = new double[10];
159 | double[][] currentImg = batch.get(numberOfBatch);
160 | // 存到X
161 | int counter = 0;
162 | for (int x = 0; x < 28; x++) {
163 | for (int y = 0; y < 28; y++) {
164 | X[counter++] = currentImg[x][y];
165 | }
166 | }
167 | // 计算隐藏层A
168 | for (int i = 0; i < 20; i++) {
169 | for (int left = 0; left < 784; left++) {
170 | A[i] += X[left] * weightXA[left][i];
171 | }
172 | // 添加偏置后,压入sigmoid
173 | A[i] += biasA[i];
174 | A[i] = sigmoid(A[i]);
175 | }
176 | // 计算隐藏层B
177 | for (int i = 0; i < 20; i++) {
178 | for (int left = 0; left < 20; left++) {
179 | B[i] += A[left] * weightAB[left][i];
180 | }
181 | B[i] += biasB[i];
182 | B[i] = sigmoid(B[i]);
183 | }
184 | // 计算Y
185 | for (int i = 0; i < 10; i++) {
186 | for (int left = 0; left < 20; left++) {
187 | Y[i] += B[left] * weightBY[left][i];
188 | }
189 | Y[i] += biasY[i];
190 | Y[i] = sigmoid(Y[i]);
191 | }
192 | XList[numberOfBatch] = arrayCopy(X);
193 | AList[numberOfBatch] = arrayCopy(A);
194 | BList[numberOfBatch] = arrayCopy(B);
195 | YList[numberOfBatch] = arrayCopy(Y);
196 | }
197 |
198 | // 设置梯度
199 | for (int n = 0; n < batch.size(); n++) {
200 | ArrayList storage1 = new ArrayList<>(); // 0~10
201 | ArrayList storage2 = new ArrayList<>(); // 0~20 : i
202 | ArrayList storage3 = new ArrayList<>(); // 0~20 : r
203 |
204 | // biasY
205 | double label = batch.get(n)[28][0];
206 | for (int j = 0; j < 10; j++) {
207 | double ygt;
208 | if (label == j) ygt = 1.0;
209 | else ygt = 0.0;
210 | double value = (YList[n][j] - ygt) * YList[n][j] * (1.0 - YList[n][j]);
211 | gradient_biasY[j] += value;
212 | storage1.add(value);
213 | }
214 | // weightBY
215 | for (int i = 0; i < 20; i++) {
216 | double temp = 0; // 添加temp
217 | for (int j = 0; j < 10; j++) {
218 | double value = storage1.get(j) * BList[n][i];
219 | gradient_weightBY[i][j] += value;
220 |
221 | temp += storage1.get(j) * weightBY[i][j] * BList[n][i] * (1.0 - BList[n][i]); // 累加
222 | }
223 | storage2.add(temp);
224 | }
225 | // biasB
226 | for (int i = 0; i < 20; i++) {
227 | gradient_biasB[i] += storage2.get(i);
228 | }
229 | // weightAB
230 | for (int r = 0; r < 20; r++) {
231 | double temp = 0;
232 | for (int i = 0; i < 20; i++) {
233 | gradient_weightAB[r][i] += storage2.get(i) * AList[n][r];
234 | temp += storage2.get(i) * AList[n][r] * weightAB[r][i] * (1.0 - AList[n][r]);
235 | }
236 | storage3.add(temp);
237 | }
238 | // biasA
239 | for (int r = 0; r < 20; r++) {
240 | gradient_biasA[r] += storage3.get(r);
241 | }
242 |
243 | // weightXA
244 | for (int l = 0; l < 784; l++) {
245 | for (int r = 0; r < 20; r++) {
246 | gradient_weightXA[l][r] += storage3.get(r) * XList[n][l];
247 | }
248 | }
249 | }
250 |
251 |
252 | // 给梯度添加L1和L2正则化 // 这有必要吗
253 | {
254 | // gradient_weightBY
255 | for (int i = 0; i < 20; i++) {
256 | for (int j = 0; j < 10; j++) {
257 |
258 | gradient_weightBY[i][j] += 2 * LAMBDA_L2 * weightBY[i][j] + LAMBDA_L1 * sign(weightBY[i][j]);
259 | }
260 | }
261 | // gradient_weightAB
262 | for (int right = 0; right < 20; right++) {
263 | for (int i = 0; i < 20; i++) {
264 | gradient_weightAB[right][i] += 2 * LAMBDA_L2 * weightAB[right][i] + LAMBDA_L1 * sign(weightAB[right][i]);
265 | }
266 | }
267 | // gradient_weightXA
268 | for (int left = 0; left < 784; left++) {
269 | for (int right = 0; right < 20; right++) {
270 | gradient_weightXA[left][right] += 2 * LAMBDA_L2 * weightXA[left][right] + LAMBDA_L1 * sign(weightXA[left][right]);
271 | }
272 | }
273 | }
274 |
275 |
276 | // Adam 优化器对权重和偏置进行下降
277 |
278 | // biasY
279 | for (int j = 0; j < 10; j++) {
280 | biasY_Mt[j] = BETA1 * biasY_Mt[j] + (1.0 - BETA1) * gradient_biasY[j];
281 | biasY_Vt[j] = BETA2 * biasY_Vt[j] + (1.0 - BETA2) * gradient_biasY[j] * gradient_biasY[j];
282 | double MtHat = biasY_Mt[j] / (1.0 - Math.pow(BETA1, biasY_Iteration));
283 | double VtHat = biasY_Vt[j] / (1.0 - Math.pow(BETA2, biasY_Iteration));
284 | biasY[j] = biasY[j] - (ALPHA * MtHat) / (Math.sqrt(VtHat) + EPSILON);
285 | }
286 | // weight_BY
287 | for (int i = 0; i < 20; i++) {
288 | for (int j = 0; j < 10; j++) {
289 | weightBY_Mt[i][j] = BETA1 * weightBY_Mt[i][j] + (1.0 - BETA1) * gradient_weightBY[i][j];
290 | weightBY_Vt[i][j] = BETA2 * weightBY_Vt[i][j] + (1.0 - BETA2) * gradient_weightBY[i][j] * gradient_weightBY[i][j];
291 | double MtHat = weightBY_Mt[i][j] / (1.0 - Math.pow(BETA1, weightBY_Iteration));
292 | double VtHat = weightBY_Vt[i][j] / (1.0 - Math.pow(BETA2, weightBY_Iteration));
293 | weightBY[i][j] = weightBY[i][j] - (ALPHA * MtHat) / (Math.sqrt(VtHat) + EPSILON);
294 |
295 | }
296 | }
297 | // biasB
298 | for (int i = 0; i < 20; i++) {
299 | biasB_Mt[i] = BETA1 * biasB_Mt[i] + (1.0 - BETA1) * gradient_biasB[i];
300 | biasB_Vt[i] = BETA2 * biasB_Vt[i] + (1.0 - BETA2) * gradient_biasB[i] * gradient_biasB[i];
301 |
302 | double MtHat = biasB_Mt[i] / (1.0 - Math.pow(BETA1, biasB_Iteration));
303 | double VtHat = biasB_Vt[i] / (1.0 - Math.pow(BETA2, biasB_Iteration));
304 |
305 | biasB[i] = biasB[i] - (ALPHA * MtHat) / (Math.sqrt(VtHat) + EPSILON);
306 | }
307 | // weightAB
308 | for (int right = 0; right < 20; right++) {
309 | for (int i = 0; i < 20; i++) {
310 | weightAB_Mt[right][i] = BETA1 * weightAB_Mt[right][i] + (1.0 - BETA1) * gradient_weightAB[right][i];
311 | weightAB_Vt[right][i] = BETA2 * weightAB_Vt[right][i] + (1.0 - BETA2) * gradient_weightAB[right][i] * gradient_weightAB[right][i];
312 |
313 | double MtHat = weightAB_Mt[right][i] / (1.0 - Math.pow(BETA1, weightAB_Iteration));
314 | double VtHat = weightAB_Vt[right][i] / (1.0 - Math.pow(BETA2, weightAB_Iteration));
315 |
316 | weightAB[right][i] = weightAB[right][i] - (ALPHA * MtHat) / (Math.sqrt(VtHat) + EPSILON);
317 | }
318 | }
319 | // biasA
320 | for (int right = 0; right < 20; right++) {
321 | biasA_Mt[right] = BETA1 * biasA_Mt[right] + (1.0 - BETA1) * gradient_biasA[right];
322 | biasA_Vt[right] = BETA2 * biasA_Vt[right] + (1.0 - BETA2) * gradient_biasA[right] * gradient_biasA[right];
323 |
324 | double MtHat = biasA_Mt[right] / (1.0 - Math.pow(BETA1, biasA_Iteration));
325 | double VtHat = biasA_Vt[right] / (1.0 - Math.pow(BETA2, biasA_Iteration));
326 |
327 | biasA[right] = biasA[right] - (ALPHA * MtHat) / (Math.sqrt(VtHat) + EPSILON);
328 | }
329 | // weightXA
330 | for (int left = 0; left < 784; left++) {
331 | for (int right = 0; right < 20; right++) {
332 | weightXA_Mt[left][right] = BETA1 * weightXA_Mt[left][right] + (1.0 - BETA1) * gradient_weightXA[left][right];
333 | weightXA_Vt[left][right] = BETA2 * weightXA_Vt[left][right] + (1.0 - BETA2) * gradient_weightXA[left][right] * gradient_weightXA[left][right];
334 |
335 | double MtHat = weightXA_Mt[left][right] / (1.0 - Math.pow(BETA1, weightXA_Iteration));
336 | double VtHat = weightXA_Vt[left][right] / (1.0 - Math.pow(BETA2, weightXA_Iteration));
337 |
338 | weightXA[left][right] = weightXA[left][right] - (ALPHA * MtHat) / (Math.sqrt(VtHat) + EPSILON);
339 | }
340 | }
341 |
342 | }
343 |
344 | /**
345 | * 获取GUI的数字,并且输出判断
346 | */
347 | public static void getDigit(JFrame frame) {
348 | JPanel newPanel = (JPanel) frame.getContentPane().getComponentAt(20, 20);
349 | // 黑色是0,白色是1
350 | double[][] array = new double[28][28];
351 | for (int x = 0; x < 28; x++) {
352 | for (int y = 0; y < 28; y++) {
353 | array[x][y] = (double) (newPanel.getComponentAt(y * 10, x * 10).getBackground().getRed()) / 255.0;
354 | }
355 | }
356 | array = normalization(array);
357 | // 重新绘制
358 | for (int x = 0; x < 28; x++) {
359 | for (int y = 0; y < 28; y++) {
360 | int RBG = (int) (255.00 * array[x][y]);
361 | if (RBG < 0) RBG = 0;
362 | newPanel.getComponentAt(y * 10, x * 10).setBackground(new Color(RBG, RBG, RBG));
363 | }
364 | }
365 |
366 |
367 | double[] X = new double[784];
368 | double[] A = new double[20];
369 | double[] B = new double[20];
370 | double[] Y = new double[10];
371 | double[][] currentImg = array;
372 | // 存到X
373 | int counter = 0;
374 | for (int x = 0; x < 28; x++) {
375 | for (int y = 0; y < 28; y++) {
376 | X[counter++] = currentImg[x][y];
377 | }
378 | }
379 | // 计算隐藏层A
380 | for (int i = 0; i < 20; i++) {
381 | for (int left = 0; left < 784; left++) {
382 | A[i] += X[left] * weightXA[left][i];
383 | }
384 | // 添加偏置后,压入sigmoid
385 | A[i] += biasA[i];
386 | A[i] = sigmoid(A[i]);
387 | }
388 | // 计算隐藏层B
389 | for (int i = 0; i < 20; i++) {
390 | for (int left = 0; left < 20; left++) {
391 | B[i] += A[left] * weightAB[left][i];
392 | }
393 | B[i] += biasB[i];
394 | B[i] = sigmoid(B[i]);
395 | }
396 | // 计算Y
397 | for (int i = 0; i < 10; i++) {
398 | for (int left = 0; left < 20; left++) {
399 | Y[i] += B[left] * weightBY[left][i];
400 | }
401 | Y[i] += biasY[i];
402 | Y[i] = sigmoid(Y[i]);
403 | }
404 |
405 | JPanel predictPanel = (JPanel) frame.getContentPane().getComponentAt(20, 400);
406 | for (int i = 0; i < 10; i++) {
407 | String temp = String.format("%.2f", Y[i] * 100) + "%";
408 | JLabel tempLabel = (JLabel) predictPanel.getComponentAt(35, i * 30);
409 | JLabel tempLabel2 = (JLabel) predictPanel.getComponentAt(0, i * 30);
410 | tempLabel.setText(temp);
411 | int constant = 100;
412 | int RGB = 255 - (constant + (int) (Y[i] * (255.0 - (double) constant)));
413 | tempLabel.setForeground(new Color(RGB, RGB, RGB));
414 | tempLabel2.setForeground(new Color(RGB, RGB, RGB));
415 | }
416 | int maxIndex = 0;
417 | for (int i = 0; i < 10; i++) {
418 | if (Y[i] > Y[maxIndex]) {
419 | maxIndex = i;
420 | }
421 | }
422 | JLabel tempLabel = (JLabel) predictPanel.getComponentAt(35, maxIndex * 30);
423 | JLabel tempLabel2 = (JLabel) predictPanel.getComponentAt(0, maxIndex * 30);
424 | tempLabel.setForeground(Color.RED);
425 | tempLabel2.setForeground(Color.RED);
426 |
427 |
428 | }
429 |
430 | /**
431 | * 验证模型准确度
432 | */
433 | public static void VerifyAccuracy(JFrame jframe) throws IOException {
434 | int totalCorrectNumber = 0;
435 | ArrayList allBatch = DevelopTool.getVerifyBatch();
436 | allBatch = normalization(allBatch);
437 | int loopCounter = 0;
438 | for (int numberOfBatch = 0; numberOfBatch < allBatch.size(); numberOfBatch++) {
439 | ++loopCounter;
440 |
441 | double[] X = new double[784];
442 | double[] A = new double[20];
443 | double[] B = new double[20];
444 | double[] Y = new double[10];
445 | double[][] currentImg = allBatch.get(numberOfBatch);
446 | // 存到X
447 | int counter = 0;
448 | for (int x = 0; x < 28; x++) {
449 | for (int y = 0; y < 28; y++) {
450 | X[counter++] = currentImg[x][y];
451 | }
452 | }
453 | // 计算隐藏层A
454 | for (int i = 0; i < 20; i++) {
455 | for (int left = 0; left < 784; left++) {
456 | A[i] += X[left] * weightXA[left][i];
457 | }
458 | // 添加偏置后,压入sigmoid
459 | A[i] += biasA[i];
460 | A[i] = sigmoid(A[i]);
461 | }
462 | // 计算隐藏层B
463 | for (int i = 0; i < 20; i++) {
464 | for (int left = 0; left < 20; left++) {
465 | B[i] += A[left] * weightAB[left][i];
466 | }
467 | B[i] += biasB[i];
468 | B[i] = sigmoid(B[i]);
469 | }
470 | // 计算Y
471 | for (int i = 0; i < 10; i++) {
472 | for (int left = 0; left < 20; left++) {
473 | Y[i] += B[left] * weightBY[left][i];
474 | }
475 | Y[i] += biasY[i];
476 | Y[i] = sigmoid(Y[i]);
477 | }
478 |
479 |
480 | JPanel paintingPanel = (JPanel) jframe.getContentPane().getComponentAt(20, 20);
481 | // 重新绘制画板
482 | for (int x = 0; x < 28; x++) {
483 | for (int y = 0; y < 28; y++) {
484 | int RGB = (int) (currentImg[x][y] * 255.00);
485 | paintingPanel.getComponentAt(y * 10, x * 10).setBackground(new Color(RGB, RGB, RGB));
486 | }
487 | }
488 | // 重新绘制digit
489 | JPanel predictPanel = (JPanel) jframe.getContentPane().getComponentAt(20, 400);
490 | for (int i = 0; i < 10; i++) {
491 | String temp = String.format("%.2f", Y[i] * 100) + "%";
492 | JLabel tempLabel = (JLabel) predictPanel.getComponentAt(35, i * 30);
493 | JLabel tempLabel2 = (JLabel) predictPanel.getComponentAt(0, i * 30);
494 | tempLabel.setText(temp);
495 | int constant = 100;
496 | int RGB = 255 - (constant + (int) (Y[i] * (255.0 - (double) constant)));
497 | tempLabel.setForeground(new Color(RGB, RGB, RGB));
498 | tempLabel2.setForeground(new Color(RGB, RGB, RGB));
499 | }
500 | // 获取最大值Index
501 | int maxIndex = 0;
502 | for (int i = 0; i < 10; i++) {
503 | if (Y[i] > Y[maxIndex]) {
504 | maxIndex = i;
505 | }
506 | }
507 | int label = (int) allBatch.get(numberOfBatch)[28][0];
508 |
509 | if (label == maxIndex) ++totalCorrectNumber;
510 |
511 | JLabel countNum = (JLabel) jframe.getContentPane().getComponentAt(430, 280);
512 | countNum.setText("" + numberOfBatch);
513 |
514 |
515 | JLabel accuracyValue = (JLabel) jframe.getContentPane().getComponentAt(430, 320);
516 | accuracyValue.setText(String.format("%.2f", (double) totalCorrectNumber / (double) loopCounter * 100.00) + "%");
517 |
518 |
519 | JLabel temp = (JLabel) predictPanel.getComponentAt(35, maxIndex * 30);
520 | JLabel temp2 = (JLabel) predictPanel.getComponentAt(0, maxIndex * 30);
521 | temp.setForeground(Color.RED);
522 | temp2.setForeground(Color.RED);
523 |
524 | JLabel tempLabel = (JLabel) predictPanel.getComponentAt(35, label * 30);
525 | JLabel tempLabel2 = (JLabel) predictPanel.getComponentAt(0, label * 30);
526 | tempLabel.setForeground(Color.GREEN);
527 | tempLabel2.setForeground(Color.GREEN);
528 | }
529 | }
530 |
531 |
532 | // 复制数组函数
533 | private static double[][] arrayCopy(double[][] sourceArray) {
534 | double[][] returnArray = new double[sourceArray.length][sourceArray[0].length];
535 | for (int i = 0; i < sourceArray.length; i++) {
536 | System.arraycopy(sourceArray[i], 0, returnArray[i], 0, sourceArray[0].length);
537 | }
538 | return returnArray;
539 | }
540 |
541 | private static double[] arrayCopy(double[] sourceArray) {
542 | double[] returnArray = new double[sourceArray.length];
543 | System.arraycopy(sourceArray, 0, returnArray, 0, sourceArray.length);
544 | return returnArray;
545 | }
546 |
547 |
548 | // sign函数
549 | private static double sign(double weight) {
550 | if (weight < 0) return -1.0;
551 | else if (weight > 0) return 1.0;
552 | else return 0.0;
553 | }
554 |
555 |
556 | // 损失函数方差
557 | private static double lossFunctionVar() {
558 | double finalResult = 0.0;
559 | for (int i = 0; i < lossValueHistory.size(); i++) {
560 | finalResult += lossValueHistory.get(i);
561 | }
562 | return finalResult / (double) lossValueHistory.size();
563 | }
564 |
565 |
566 | // 计算BatchSet总损失
567 | private static double getLossValueForBatchSet(ArrayList allBatch) {
568 | double loss = 0.0;
569 | // 遍历整个BatchSet
570 | for (int numberOfBatch = 0; numberOfBatch < allBatch.size(); numberOfBatch++) {
571 | double[][] currentBatch = allBatch.get(numberOfBatch);
572 | double[] X = new double[784];
573 | double[] A = new double[20];
574 | double[] B = new double[20];
575 | double[] Y = new double[10];
576 | // 存到X
577 | int counter = 0;
578 | for (int x = 0; x < 28; x++) {
579 | for (int y = 0; y < 28; y++) {
580 | X[counter++] = currentBatch[x][y];
581 | }
582 | }
583 | // 计算隐藏层A
584 | for (int i = 0; i < 20; i++) {
585 | for (int left = 0; left < 784; left++) {
586 | A[i] += X[left] * weightXA[left][i];
587 | }
588 | // 添加偏置后,压入sigmoid
589 | A[i] += biasA[i];
590 | A[i] = sigmoid(A[i]);
591 | }
592 | // 计算隐藏层B
593 | for (int i = 0; i < 20; i++) {
594 | for (int left = 0; left < 20; left++) {
595 | B[i] += A[left] * weightAB[left][i];
596 | }
597 | B[i] += biasB[i];
598 | B[i] = sigmoid(B[i]);
599 | }
600 | // 计算Y
601 | for (int i = 0; i < 10; i++) {
602 | for (int left = 0; left < 20; left++) {
603 | Y[i] += B[left] * weightBY[left][i];
604 | }
605 | Y[i] += biasY[i];
606 | Y[i] = sigmoid(Y[i]);
607 | }
608 |
609 | // 计算损失
610 | int label = (int) currentBatch[28][0];
611 | double miniBatchLoss = 0;
612 | for (int i = 0; i < 10; i++) {
613 | double possibility;
614 | if (i == label) possibility = 1.0;
615 | else possibility = 0.0;
616 | miniBatchLoss += Math.pow(possibility - Y[i], 2);
617 | }
618 | loss += miniBatchLoss;
619 | }
620 | return loss / (double) allBatch.size();
621 | }
622 |
623 |
624 | // sigmoid 函数
625 | private static double sigmoid(double x) {
626 | return 1.00 / (Math.exp(-x) + 1.00);
627 | }
628 |
629 |
630 | // 预处理
631 | public static ArrayList normalization(ArrayList allBatch) {
632 | // 先处理图像
633 | for (int numberOfBatch = 0; numberOfBatch < allBatch.size(); numberOfBatch++) {
634 | allBatch.set(numberOfBatch, resizeImg(allBatch.get(numberOfBatch)));
635 | }
636 |
637 | // 然后归一化
638 | System.out.println("[Normalization Start...]");
639 | ArrayList normalizedBatch = new ArrayList<>();
640 | for (int i = 0; i < allBatch.size(); i++) {
641 | double[][] currentBatch = allBatch.get(i);
642 | double max = currentBatch[0][0];
643 | double min = currentBatch[0][0];
644 | for (int x = 0; x < currentBatch.length - 1; x++) {
645 | for (int y = 0; y < currentBatch[0].length; y++) {
646 | if (currentBatch[x][y] > max) {
647 | max = currentBatch[x][y];
648 | }
649 | if (currentBatch[x][y] < min) {
650 | min = currentBatch[x][y];
651 | }
652 | }
653 | }
654 | for (int x = 0; x < currentBatch.length - 1; x++) {
655 | for (int y = 0; y < currentBatch[0].length; y++) {
656 | currentBatch[x][y] = (currentBatch[x][y] - min) / (max - min);
657 | }
658 | }
659 | normalizedBatch.add(currentBatch);
660 | }
661 | System.out.println("[Normalization Finished]");
662 | return normalizedBatch;
663 |
664 | }
665 |
666 | public static double[][] normalization(double[][] img) {
667 | // 先处理图像
668 | img = resizeImg(img);
669 |
670 | // 然后归一化
671 | double max = img[0][0];
672 | double min = img[0][0];
673 | for (int x = 0; x < img.length - 1; x++) {
674 | for (int y = 0; y < img[0].length; y++) {
675 | if (img[x][y] > max) {
676 | max = img[x][y];
677 | }
678 | if (img[x][y] != 0 && img[x][y] < min) {
679 | min = img[x][y];
680 | }
681 | }
682 | }
683 | for (int x = 0; x < img.length - 1; x++) {
684 | for (int y = 0; y < img[0].length; y++) {
685 | img[x][y] = (img[x][y] - min) / (max - min);
686 | }
687 | }
688 | return img;
689 |
690 | }
691 |
692 |
693 | private static double[][] resizeImg(double[][] array) {
694 | int left = -1;
695 | int right = -1;
696 | int top = -1;
697 | int bottom = -1;
698 |
699 | for (int i = 0; i < 28; i++) {
700 | if (checkColumn(array, i)) {
701 | left = i;
702 | break;
703 | }
704 | }
705 | for (int i = 27; i >= 0; i--) {
706 | if (checkColumn(array, i)) {
707 | right = i;
708 | break;
709 | }
710 | }
711 |
712 | for (int i = 0; i < 28; i++) {
713 | if (checkRow(array, i)) {
714 | top = i;
715 | break;
716 | }
717 | }
718 |
719 | for (int i = 27; i >= 0; i--) {
720 | if (checkRow(array, i)) {
721 | bottom = i;
722 | break;
723 | }
724 | }
725 |
726 |
727 | if (left == -1 || right == -1 || top == -1 || bottom == -1) {
728 | return new double[29][28];
729 | }
730 |
731 | double[][] B;
732 | int height = Math.abs(bottom - top) + 1;
733 | int weight = Math.abs(left - right) + 1;
734 | double[][] A = new double[height][weight];
735 | for (int x = top; x <= bottom; x++) {
736 | if (right + 1 - left >= 0) System.arraycopy(array[x], left, A[x - top], 0, right + 1 - left);
737 | }
738 |
739 |
740 | if (height >= weight) {
741 | B = new double[28][(int) (Math.round((weight) * (28.00 / (height))))];
742 | } else {
743 | B = new double[(int) (Math.round((height) * (28.00 / (weight))))][28];
744 | }
745 |
746 | for (int x = 0; x < B.length; x++) {
747 | for (int y = 0; y < B[0].length; y++) {
748 | int w1 = weight;
749 | int h1 = height;
750 | int b1 = B.length;
751 | int b2 = B[0].length;
752 | int castX = (int) ((double) x * ((double) height / (double) B.length));
753 | int castY = (int) ((double) y * ((double) weight / (double) B[0].length));
754 | B[x][y] = A[castX][castY];
755 | }
756 | }
757 | for (int x = 0; x < 28; x++) {
758 | for (int y = 0; y < 28; y++) {
759 | array[x][y] = 0.0;
760 | }
761 | }
762 | if (B.length == 28) {
763 | int bias = (int) (Math.round((28.0 - (double) B[0].length) / 2.0));
764 | for (int x = 0; x < 28; x++) {
765 | System.arraycopy(B[x], 0, array[x], bias, B[0].length);
766 | }
767 | } else if (B[0].length == 28) {
768 | int bias = (int) (Math.round((28.0 - (double) B.length) / 2.0));
769 | for (int x = 0; x < B.length; x++) {
770 | System.arraycopy(B[x], 0, array[x + bias], 0, 28);
771 | }
772 | }
773 | return array;
774 | }
775 |
776 | private static boolean checkColumn(double[][] array, int column) {
777 | for (int x = 0; x < 28; x++) {
778 | if (array[x][column] != 0) return true;
779 | }
780 | return false;
781 | }
782 |
783 | private static boolean checkRow(double[][] array, int row) {
784 | for (int x = 0; x < 28; x++) {
785 | if (array[row][x] != 0) {
786 | return true;
787 | }
788 | }
789 | return false;
790 | }
791 |
792 | /**
793 | * record()用于导出模型到桌面
794 | */
795 | public static void record() throws IOException {
796 | // 如果有了就删除,以便覆盖
797 | File file = new File(EXPORT_PART);
798 | if (file.exists()) {
799 | file.delete();
800 | }
801 |
802 | BufferedWriter writer = new BufferedWriter(new FileWriter(EXPORT_PART, true));
803 | // 权重
804 | for (int i = 0; i < weightXA.length; i++) {
805 | for (int j = 0; j < weightXA[i].length; j++) {
806 | writer.write(weightXA[i][j] + "\n");
807 | }
808 | }
809 | for (int i = 0; i < weightAB.length; i++) {
810 | for (int j = 0; j < weightAB[i].length; j++) {
811 | writer.write(weightAB[i][j] + "\n");
812 | }
813 | }
814 | for (int i = 0; i < weightBY.length; i++) {
815 | for (int j = 0; j < weightBY[i].length; j++) {
816 | writer.write(weightBY[i][j] + "\n");
817 | }
818 | }
819 | // 偏置
820 | for (int i = 0; i < biasA.length; i++) {
821 | writer.write(biasA[i] + "\n");
822 | }
823 | for (int i = 0; i < biasB.length; i++) {
824 | writer.write(biasB[i] + "\n");
825 | }
826 | for (int i = 0; i < biasY.length; i++) {
827 | writer.write(biasY[i] + "\n");
828 | }
829 | // 权重的Adam数据
830 | for (int i = 0; i < weightXA_Mt.length; i++) {
831 | for (int j = 0; j < weightXA_Mt[i].length; j++) {
832 | writer.write(weightXA_Mt[i][j] + "\n");
833 | }
834 | }
835 | for (int i = 0; i < weightXA_Vt.length; i++) {
836 | for (int j = 0; j < weightXA_Vt[i].length; j++) {
837 | writer.write(weightXA_Vt[i][j] + "\n");
838 | }
839 | }
840 | for (int i = 0; i < weightAB_Mt.length; i++) {
841 | for (int j = 0; j < weightAB_Mt[i].length; j++) {
842 | writer.write(weightAB_Mt[i][j] + "\n");
843 | }
844 | }
845 | for (int i = 0; i < weightAB_Vt.length; i++) {
846 | for (int j = 0; j < weightAB_Vt[i].length; j++) {
847 | writer.write(weightAB_Vt[i][j] + "\n");
848 | }
849 | }
850 | for (int i = 0; i < weightBY_Mt.length; i++) {
851 | for (int j = 0; j < weightBY_Mt[i].length; j++) {
852 | writer.write(weightBY_Mt[i][j] + "\n");
853 | }
854 | }
855 | for (int i = 0; i < weightBY_Vt.length; i++) {
856 | for (int j = 0; j < weightBY_Vt[i].length; j++) {
857 | writer.write(weightBY_Vt[i][j] + "\n");
858 | }
859 | }
860 | // 偏置的Adam数据
861 | for (int i = 0; i < biasA_Mt.length; i++) {
862 | writer.write(biasA_Mt[i] + "\n");
863 | }
864 | for (int i = 0; i < biasA_Vt.length; i++) {
865 | writer.write(biasA_Vt[i] + "\n");
866 | }
867 | for (int i = 0; i < biasB_Mt.length; i++) {
868 | writer.write(biasB_Mt[i] + "\n");
869 | }
870 | for (int i = 0; i < biasB_Vt.length; i++) {
871 | writer.write(biasB_Vt[i] + "\n");
872 | }
873 | for (int i = 0; i < biasY_Mt.length; i++) {
874 | writer.write(biasY_Mt[i] + "\n");
875 | }
876 | for (int i = 0; i < biasY_Vt.length; i++) {
877 | writer.write(biasY_Vt[i] + "\n");
878 | }
879 | // 迭代数
880 | writer.write(weightXA_Iteration + "\n");
881 | writer.write(weightAB_Iteration + "\n");
882 | writer.write(weightBY_Iteration + "\n");
883 | writer.write(biasA_Iteration + "\n");
884 | writer.write(biasB_Iteration + "\n");
885 | writer.write(biasY_Iteration + "\n");
886 |
887 | writer.close();
888 | }
889 |
890 | /**
891 | * readModel()用于读取model文件夹的模型
892 | */
893 | public static void readModel() throws IOException {
894 | BufferedReader reader = new BufferedReader(new FileReader(IMPORT_PATH));
895 | // 权重
896 | for (int i = 0; i < weightXA.length; i++) {
897 | for (int j = 0; j < weightXA[i].length; j++) {
898 | weightXA[i][j] = Double.parseDouble(reader.readLine());
899 | }
900 | }
901 | for (int i = 0; i < weightAB.length; i++) {
902 | for (int j = 0; j < weightAB[i].length; j++) {
903 | weightAB[i][j] = Double.parseDouble(reader.readLine());
904 | }
905 | }
906 | for (int i = 0; i < weightBY.length; i++) {
907 | for (int j = 0; j < weightBY[i].length; j++) {
908 | weightBY[i][j] = Double.parseDouble(reader.readLine());
909 | }
910 | }
911 | // 偏置
912 | for (int i = 0; i < biasA.length; i++) {
913 | biasA[i] = Double.parseDouble(reader.readLine());
914 | }
915 | for (int i = 0; i < biasB.length; i++) {
916 | biasB[i] = Double.parseDouble(reader.readLine());
917 | }
918 | for (int i = 0; i < biasY.length; i++) {
919 | biasY[i] = Double.parseDouble(reader.readLine());
920 | }
921 | // 权重的Adam数据
922 | for (int i = 0; i < weightXA_Mt.length; i++) {
923 | for (int j = 0; j < weightXA_Mt[i].length; j++) {
924 | weightXA_Mt[i][j] = Double.parseDouble(reader.readLine());
925 | }
926 | }
927 | for (int i = 0; i < weightXA_Vt.length; i++) {
928 | for (int j = 0; j < weightXA_Vt[i].length; j++) {
929 | weightXA_Vt[i][j] = Double.parseDouble(reader.readLine());
930 | }
931 | }
932 | for (int i = 0; i < weightAB_Mt.length; i++) {
933 | for (int j = 0; j < weightAB_Mt[i].length; j++) {
934 | weightAB_Mt[i][j] = Double.parseDouble(reader.readLine());
935 | }
936 | }
937 | for (int i = 0; i < weightAB_Vt.length; i++) {
938 | for (int j = 0; j < weightAB_Vt[i].length; j++) {
939 | weightAB_Vt[i][j] = Double.parseDouble(reader.readLine());
940 | }
941 | }
942 | for (int i = 0; i < weightBY_Mt.length; i++) {
943 | for (int j = 0; j < weightBY_Mt[i].length; j++) {
944 | weightBY_Mt[i][j] = Double.parseDouble(reader.readLine());
945 | }
946 | }
947 | for (int i = 0; i < weightBY_Vt.length; i++) {
948 | for (int j = 0; j < weightBY_Vt[i].length; j++) {
949 | weightBY_Vt[i][j] = Double.parseDouble(reader.readLine());
950 | }
951 | }
952 | // 偏置的Adam数据
953 | for (int i = 0; i < biasA_Mt.length; i++) {
954 | biasA_Mt[i] = Double.parseDouble(reader.readLine());
955 | }
956 | for (int i = 0; i < biasA_Vt.length; i++) {
957 | biasA_Vt[i] = Double.parseDouble(reader.readLine());
958 | }
959 | for (int i = 0; i < biasB_Mt.length; i++) {
960 | biasB_Mt[i] = Double.parseDouble(reader.readLine());
961 | }
962 | for (int i = 0; i < biasB_Vt.length; i++) {
963 | biasB_Vt[i] = Double.parseDouble(reader.readLine());
964 | }
965 | for (int i = 0; i < biasY_Mt.length; i++) {
966 | biasY_Mt[i] = Double.parseDouble(reader.readLine());
967 | }
968 | for (int i = 0; i < biasY_Vt.length; i++) {
969 | biasY_Vt[i] = Double.parseDouble(reader.readLine());
970 | }
971 | // 迭代数
972 | weightXA_Iteration = Integer.parseInt(reader.readLine());
973 | weightAB_Iteration = Integer.parseInt(reader.readLine());
974 | weightBY_Iteration = Integer.parseInt(reader.readLine());
975 | biasA_Iteration = Integer.parseInt(reader.readLine());
976 | biasB_Iteration = Integer.parseInt(reader.readLine());
977 | biasY_Iteration = Integer.parseInt(reader.readLine());
978 |
979 | reader.close();
980 | }
981 |
982 | /**
983 | * 初始化
984 | */
985 | private static void initialize() {
986 | weightXA = initializeWeights(784, 20);
987 | weightAB = initializeWeights(20, 20);
988 | weightBY = initializeWeights(20, 10);
989 |
990 | biasA = initializeBias(20);
991 | biasB = initializeBias(20);
992 | biasY = initializeBias(10);
993 | }
994 |
995 | public static double[][] initializeWeights(int fromSize, int toSize) {
996 | double[][] weights = new double[fromSize][toSize];
997 | double limit = Math.sqrt(6.0 / fromSize);
998 |
999 | for (int i = 0; i < fromSize; i++) {
1000 | for (int j = 0; j < toSize; j++) {
1001 | Random random1 = new Random();
1002 | Random random2 = new Random(random1.nextInt());
1003 | // 使用均匀分布的He初始化
1004 | weights[i][j] = (random2.nextDouble() * 2.0 * limit - limit);
1005 | }
1006 | }
1007 | return weights;
1008 | }
1009 |
1010 | public static double[] initializeBias(int size) {
1011 | double[] bias = new double[size];
1012 |
1013 |
1014 | for (int m = 0; m < size; m++) {
1015 | Random random1 = new Random();
1016 | Random random2 = new Random(random1.nextInt());
1017 | // 使用正态分布初始化偏置,均值为0,标准差为0.01,可根据实际调整
1018 | bias[m] = random2.nextGaussian() * 0.01;
1019 | }
1020 | return bias;
1021 | }
1022 | }
1023 |
--------------------------------------------------------------------------------
/如何在Java中实现手写数字识别神经网络.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/如何在Java中实现手写数字识别神经网络.pdf
--------------------------------------------------------------------------------