├── How to implement a neural network for handwritten digit recognition in Java.pdf
├── README.md
├── imageResource
    ├── image-20241217173028697.png
    ├── image-20241217184752242.png
    ├── image-20241217185102809.png
    ├── image-20241217185807044.png
    ├── image-20241217185809888.png
    ├── image-20241217185844494.png
    ├── image-20241217190426681.png
    └── image-20241217190518353.png
├── resources
    ├── TrainImg
    │   └── txt.txt
    ├── VerifyImg
    │   └── txt.txt
    ├── database
    │   ├── t10k-images.idx3-ubyte
    │   ├── t10k-labels.idx1-ubyte
    │   ├── train-images-idx3-ubyte
    │   └── train-labels-idx1-ubyte
    └── model
    │   └── Model_Data.txt
├── src
    ├── Application.java
    └── functions
    │   ├── DevelopTool.java
    │   ├── GUI.java
    │   ├── LanguageSetting.java
    │   ├── MNISTReader.java
    │   └── Processing.java
└── 如何在Java中实现手写数字识别神经网络.pdf


/How to implement a neural network for handwritten digit recognition in Java.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/How to implement a neural network for handwritten digit recognition in Java.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Java Hand Writing Digit Recognition
2 |
3 | > # Java手写数字识别
4 |
5 | Read [How to implement a neural network for handwritten digit recognition in Java.pdf](https://github.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/blob/main/How%20to%20implement%20a%20neural%20network%20for%20handwritten%20digit%20recognition%20in%20Java.pdf) for more information!
6 |
7 | > 阅读 [如何在Java中实现手写数字识别神经网络.pdf](https://github.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/blob/main/如何在Java中实现手写数字识别神经网络.pdf) 文件以获得更多信息!
8 |
9 |
10 |
11 | ## Setup
12 |
13 | > ## 初始化
14 |
15 | Import the project, then run the `Application` class to start the program.
16 |
17 | > 导入后运行`Application` 类以启动程序
18 |
19 | image-20241217173028697
20 |
21 | **STEP 1:**
22 |
23 | $\color{red}VERY~~IMPORTANT$
24 |
25 | $\color{red}VERY~~IMPORTANT$
26 |
27 | $\color{red}VERY~~IMPORTANT$
28 |
29 | When you start the program for the first time, click `Settings` -> `Refresh Resource Files Again` in the upper-left corner. This is only needed for the first-time setup.
30 |
31 | > 初次启动程序请单击左上角 `设置` -> `重新刷新资源文件` (仅仅针对第一次启动)
32 |
33 | image-20241217184752242
34 |
35 | **STEP 2:**
36 |
37 | Then click `Read the Previous Model` to load the already trained model. The program never loads a model on its own: every time it starts, it randomizes the neural network, so repeat this step after each launch. :)
38 |
39 | > 然后,点击`读取先前模型`来获取已经训练好的模型(每次启动程序,程序都不会读取任何模型,而是把神经网络做一个随机):)
40 |
41 | image-20241217185102809
42 |
43 | **STEP 3:**
44 |
45 | Draw a digit on the canvas, then click the `Predict This Image` button.
46 |
47 | > 尝试写一个数字,然后点击`预测此图像`
48 |
49 | ![image-20241217185844494](./imageResource/image-20241217185844494.png)
50 |
51 | ## Verify Accuracy Of The Current Model
52 |
53 | > ## 验证当前模型的准确度
54 |
55 |
56 |
57 | Click the `Verify the Model Accuracy with the Validation Set` button.
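Verifying a model against the validation set amounts to comparing the predicted digit with the label for every image and reporting the share of matches. The sketch below only illustrates that idea and is not the project's `Processing.VerifyAccuracy` code: the class `AccuracySketch` and its methods `argMax` and `accuracy` are hypothetical names, and it assumes each prediction is an array of scores with one entry per digit.

```java
// Hypothetical helper, not part of the repository.
final class AccuracySketch {

    /** Returns the index of the largest score, i.e. the predicted digit. */
    static int argMax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Returns the fraction of samples whose predicted digit equals the label. */
    static double accuracy(double[][] outputs, int[] labels) {
        if (outputs.length != labels.length) {
            throw new IllegalArgumentException("outputs and labels differ in length");
        }
        if (labels.length == 0) {
            return 0.0; // nothing to verify
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argMax(outputs[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        // Toy data with three classes; a real run would use 10 scores per image.
        double[][] outputs = {
            {0.1, 0.8, 0.1},
            {0.9, 0.05, 0.05},
            {0.2, 0.3, 0.5},
            {0.6, 0.3, 0.1}
        };
        int[] labels = {1, 0, 2, 1};
        System.out.printf("Accuracy: %.2f%%%n", 100.0 * accuracy(outputs, labels));
    }
}
```

With the 10,000-image validation set, the same ratio is what a "Count: n / 10000" display and an accuracy percentage would summarize.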
58 |
59 | > 按下`以验证集验证模型准确度`按钮
60 |
61 | ![image-20241217190426681](./imageResource/image-20241217190426681.png)
62 |
--------------------------------------------------------------------------------
/imageResource/image-20241217173028697.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217173028697.png
--------------------------------------------------------------------------------
/imageResource/image-20241217184752242.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217184752242.png
--------------------------------------------------------------------------------
/imageResource/image-20241217185102809.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217185102809.png
--------------------------------------------------------------------------------
/imageResource/image-20241217185807044.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217185807044.png
--------------------------------------------------------------------------------
/imageResource/image-20241217185809888.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217185809888.png
--------------------------------------------------------------------------------
/imageResource/image-20241217185844494.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217185844494.png
--------------------------------------------------------------------------------
/imageResource/image-20241217190426681.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217190426681.png
--------------------------------------------------------------------------------
/imageResource/image-20241217190518353.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/imageResource/image-20241217190518353.png
--------------------------------------------------------------------------------
/resources/TrainImg/txt.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/resources/TrainImg/txt.txt
--------------------------------------------------------------------------------
/resources/VerifyImg/txt.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/resources/VerifyImg/txt.txt
--------------------------------------------------------------------------------
/resources/database/t10k-images.idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/resources/database/t10k-images.idx3-ubyte
--------------------------------------------------------------------------------
/resources/database/t10k-labels.idx1-ubyte:
--------------------------------------------------------------------------------
[binary label data; not rendered as text]
--------------------------------------------------------------------------------
/resources/database/train-images-idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/resources/database/train-images-idx3-ubyte -------------------------------------------------------------------------------- /resources/database/train-labels-idx1-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/resources/database/train-labels-idx1-ubyte -------------------------------------------------------------------------------- /src/Application.java: -------------------------------------------------------------------------------- 1 | import functions.GUI; 2 | import java.io.IOException; 3 | 4 | public class Application { 5 | public static void main(String[] args) throws IOException { 6 | new GUI(); 7 | } 8 | 9 | } 10 | -------------------------------------------------------------------------------- /src/functions/DevelopTool.java: -------------------------------------------------------------------------------- 1 | package functions; 2 | 3 | import javax.imageio.ImageIO; 4 | import java.awt.image.BufferedImage; 5 | import java.io.*; 6 | import java.util.ArrayList; 7 | import java.util.stream.Stream; 8 | 9 | public class DevelopTool { 10 | static final String PATH_TRAIN_IMG_FOLDER = "resources/TrainImg"; 11 | static final String PATH_VERIFY_IMG_FOLDER = "resources/VerifyImg"; 12 | static final String PATH_TRAIN_TXT = "resources/database/MNIST_TRAIN_DATA.txt"; 13 | static final String PATH_VERIFY_TXT = "resources/database/MNIST_VERIFY_DATA.txt"; 14 | private DevelopTool(){} 15 | 16 | public static void clearCache() { 17 | // 清空图像文件 18 | System.out.println("[Deleting img folder...]"); 19 | deleteAllFilesInFolder(PATH_TRAIN_IMG_FOLDER); 20 | deleteAllFilesInFolder(PATH_VERIFY_IMG_FOLDER); 21 | 22 | // 删除txt文件 23 | System.out.println("[Deleting txt file...]"); 24 
| deleteTxtFile(PATH_TRAIN_TXT); 25 | deleteTxtFile(PATH_VERIFY_TXT); 26 | 27 | } 28 | 29 | // 对外初始化 30 | public static void regenerateResourcesFile() throws IOException { 31 | System.out.println("[Initializing...]"); 32 | 33 | // 清空图像文件 34 | System.out.println("[Deleting img folder...]"); 35 | deleteAllFilesInFolder(PATH_TRAIN_IMG_FOLDER); 36 | deleteAllFilesInFolder(PATH_VERIFY_IMG_FOLDER); 37 | 38 | 39 | // 删除txt文件 40 | System.out.println("[Deleting txt file...]"); 41 | deleteTxtFile(PATH_TRAIN_TXT); 42 | deleteTxtFile(PATH_VERIFY_TXT); 43 | 44 | 45 | // 写入txt 46 | writeTxtFile(PATH_TRAIN_TXT); 47 | writeTxtFile(PATH_VERIFY_TXT); 48 | 49 | // 生成train img文件 50 | genImg(PATH_TRAIN_IMG_FOLDER); 51 | genImg(PATH_VERIFY_IMG_FOLDER); 52 | 53 | System.out.println("[Finish]"); 54 | 55 | 56 | } 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | private static void genImg(String PATH) throws IOException { 66 | BufferedReader reader; 67 | if(PATH.equals(PATH_TRAIN_IMG_FOLDER)) { 68 | reader = new BufferedReader(new FileReader(PATH_TRAIN_TXT)); 69 | } 70 | else 71 | { 72 | reader = new BufferedReader(new FileReader(PATH_VERIFY_TXT)); 73 | } 74 | String line ; 75 | 76 | int countProcess = 1; 77 | 78 | while((line = reader.readLine()) != null){ 79 | if(line.charAt(0) == 'L') { 80 | ++countProcess; 81 | continue; 82 | } 83 | 84 | int length = stringToDoubleArray(line).length; 85 | 86 | double[][] array = new double[length][length]; 87 | double[] temp = stringToDoubleArray(line); 88 | for(int i = 0; i < length; ++i){ 89 | array[0][i] = temp[i]; 90 | } 91 | for(int i = 1; i <= length-1; ++i){ 92 | line = reader.readLine(); 93 | temp = stringToDoubleArray(line); 94 | for(int j = 0; j < length; ++j){ 95 | array[i][j] = temp[j]; 96 | } 97 | } 98 | 99 | 100 | if(PATH.equals(PATH_TRAIN_IMG_FOLDER)) { 101 | if(countProcess%1000==0) 102 | { 103 | System.out.println("[Generating image. 
Process] \t"+String.format("%.2f",((double)countProcess/60000.00*100.00))+"%"); 104 | } 105 | } 106 | else 107 | { 108 | if(countProcess%1000==0) 109 | { 110 | System.out.println("[Generating image. Process] \t"+String.format("%.2f",((double)countProcess/10000.00*100.00))+"%"); 111 | } 112 | } 113 | 114 | 115 | if(PATH.equals(PATH_TRAIN_IMG_FOLDER)) 116 | arrayToImage(array, PATH_TRAIN_IMG_FOLDER, ""+countProcess); 117 | else 118 | arrayToImage(array, PATH_VERIFY_IMG_FOLDER, ""+countProcess); 119 | } 120 | reader.close(); 121 | } 122 | private static void writeTxtFile(String path) throws IOException { 123 | ArrayList<double[][]> mnist; 124 | if(path.equals(PATH_TRAIN_TXT)) { 125 | mnist = MNISTReader.getMNISTTrainDataFromFile(); 126 | } 127 | else 128 | { 129 | mnist = MNISTReader.getMNISTVerifyDataFromFile(); 130 | } 131 | BufferedWriter writer = new BufferedWriter(new FileWriter(path,true)); 132 | 133 | System.out.println("[Writing txt file...]"); 134 | for(int i = 0; i < mnist.size(); i++){ 135 | double[][] arr = mnist.get(i); 136 | for(int j = 0; j < arr.length-1; j++){ 137 | for(int k = 0; k < arr[j].length; k++){ 138 | writer.write(arr[j][k] + " "); 139 | } 140 | writer.newLine(); 141 | } 142 | writer.write("Label ↑:"+mnist.get(i)[mnist.get(i).length-1][0]); 143 | writer.newLine(); 144 | } 145 | writer.close(); 146 | } 147 | // Given a double[][], write it out as an image 148 | private static void arrayToImage(double[][] array, String path, String name) throws IOException { 149 | int height = array.length; 150 | int width = array[0].length; 151 | 152 | // Create a BufferedImage 153 | BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY); 154 | 155 | // Populate the image with the array values 156 | for (int y = 0; y < height; y++) { 157 | for (int x = 0; x < width; x++) { 158 | // Clamp and scale the value from [0, 1] to [0, 255] 159 | int value = (int) Math.round(Math.min(1.0, Math.max(0.0, array[y][x])) * 255); 160 | int gray = (value << 16) | (value << 8) | 
value; // Grayscale value 161 | image.setRGB(x, y, gray); 162 | } 163 | } 164 | 165 | // Create the output file 166 | File outputFile = new File(path, name + ".png"); 167 | 168 | // Ensure the directory exists 169 | if (!outputFile.getParentFile().exists()) { 170 | outputFile.getParentFile().mkdirs(); 171 | } 172 | 173 | // Write the image to file 174 | ImageIO.write(image, "png", outputFile); 175 | } 176 | // 给一个图片 输出double[][] 177 | // 静态方法,传入图片路径,返回double[][]数组 178 | 179 | public static double[][] convertToGrayscale(String path) throws IOException { 180 | // 读取图片 181 | BufferedImage image = ImageIO.read(new File(path)); 182 | 183 | // 获取图片的宽度和高度 184 | int width = image.getWidth(); 185 | int height = image.getHeight(); 186 | 187 | // 创建二维数组存储灰度值 188 | double[][] grayscale = new double[height][width]; 189 | 190 | // 遍历每个像素,获取精确的灰度值 191 | for (int y = 0; y < height; y++) { 192 | for (int x = 0; x < width; x++) { 193 | // 读取像素的ARGB值 194 | int argb = image.getRGB(x, y); 195 | 196 | // 获取灰度值(假设图片为灰度图,RGB通道值相同) 197 | int gray = argb & 0xFF; // 直接取低8位(灰度图中,R=G=B) 198 | 199 | // 将灰度值归一化到 0.0~1.0 范围 200 | grayscale[y][x] = gray / 255.0; 201 | } 202 | } 203 | 204 | return grayscale; 205 | } 206 | // 删除 img文件夹所有子文件 207 | private static void deleteAllFilesInFolder(String path) { 208 | 209 | File folder = new File(path); 210 | 211 | // Check if the path is a valid directory 212 | if (!folder.exists() || !folder.isDirectory()) { 213 | System.err.println("[Warning, Class: DevelopTool] Invalid directory: " + path); 214 | return; 215 | } 216 | 217 | // Get all files in the directory 218 | File[] files = folder.listFiles(); 219 | if (files == null || files.length == 0) { 220 | System.err.println("[Warning, Class: DevelopTool] No files to delete in directory: " + path); 221 | return; 222 | } 223 | 224 | System.err.println("[Deleting data. 
Please wait]"); 225 | // Iterate through the files and delete each one 226 | for (File file : files) { 227 | if (file.isFile()) { 228 | if (!file.delete()) { 229 | System.err.println("[Warning, Class:DevelopTool] Failed to delete file: " + file.getName()); 230 | return; 231 | } 232 | } else if (file.isDirectory()) { 233 | System.err.println("[Warning, Class:DevelopTool] Skipping subdirectory: " + file.getName()); 234 | } 235 | } 236 | 237 | } 238 | // 删除 txt文件 239 | private static void deleteTxtFile(String path) { 240 | File file = new File(path); 241 | if(file.exists()) { 242 | file.delete(); 243 | } 244 | 245 | } 246 | // 获取double 247 | private static double[] stringToDoubleArray(String input) { 248 | return Stream.of(input.split(" ")) 249 | .mapToDouble(Double::parseDouble) 250 | .toArray(); 251 | } 252 | 253 | 254 | // return 28*29 *n Batch 255 | public static ArrayList getTrainBatch() throws IOException { 256 | ArrayList batch = new ArrayList<>(); 257 | 258 | BufferedReader reader = new BufferedReader(new FileReader(PATH_TRAIN_TXT)); 259 | String line = null; 260 | int batchSize = GUI.BATCH_SET_SIZE; 261 | 262 | for(int i = 0;i getVerifyBatch() throws IOException { 279 | ArrayList batch = new ArrayList<>(); 280 | 281 | BufferedReader reader = new BufferedReader(new FileReader(PATH_VERIFY_TXT)); 282 | String line = null; 283 | int batchSize = 10000; // 验证集的大小是10000 284 | 285 | System.out.println(batchSize); 286 | for(int i = 0;i> 24) & 0xff; 317 | int red = (argb >> 16) & 0xff; 318 | int green = (argb >> 8) & 0xff; 319 | int blue = argb & 0xff; 320 | // 灰度值的计算(这里简单平均RGB三个分量,对于灰度图本身RGB相等,也可以直接取其一) 321 | int gray = (red + green + blue) / 3; 322 | // 将灰度值归一化到0到1区间 323 | result[y][x] = (double) gray / 255.0; 324 | } 325 | } 326 | return result; 327 | } 328 | } 329 | -------------------------------------------------------------------------------- /src/functions/GUI.java: -------------------------------------------------------------------------------- 1 | package 
functions; 2 | 3 | import javax.swing.*; 4 | import java.awt.*; 5 | import java.awt.event.*; 6 | import java.io.IOException; 7 | import java.util.Random; 8 | 9 | 10 | public class GUI extends JFrame { 11 | int CONSTANT_K = 300; 12 | static boolean stopButtonPressed = false; 13 | protected static int BATCH_SET_SIZE = 10000; 14 | public GUI() { 15 | // 先弹出语言界面 16 | SwingUtilities.invokeLater(() -> { 17 | // 创建 JDialog 弹窗,并设置为模态对话框 18 | JDialog dialog = new JDialog((Frame) null, "Language:", true); 19 | dialog.setSize(300, 150); 20 | dialog.setResizable(false); 21 | dialog.setLocationRelativeTo(null); 22 | dialog.setDefaultCloseOperation(JDialog.DO_NOTHING_ON_CLOSE); 23 | 24 | JPanel panel = new JPanel(); 25 | panel.setLayout(new FlowLayout(FlowLayout.CENTER, 10, 10)); 26 | 27 | // 使用默认字体设置标签字体 28 | JLabel label = new JLabel("语言设置/ Language Setting"); 29 | label.setFont(new Font(Font.DIALOG, Font.PLAIN, 16)); 30 | panel.add(label); 31 | 32 | JButton chineseButton = new JButton("中文"); 33 | // 使用默认字体设置按钮字体 34 | chineseButton.setFont(new Font(Font.DIALOG, Font.PLAIN, 14)); 35 | chineseButton.setPreferredSize(new Dimension(80, 30)); 36 | chineseButton.addActionListener(new ActionListener() { 37 | @Override 38 | public void actionPerformed(ActionEvent e) { 39 | // 这里可以添加选择中文后的具体逻辑,比如设置程序语言为中文等 40 | LanguageSetting.runChineseText(); 41 | // 关闭对话框,解除阻塞,后续代码才能继续执行 42 | dialog.dispose(); 43 | } 44 | }); 45 | panel.add(chineseButton); 46 | 47 | JButton englishButton = new JButton("English"); 48 | // 使用默认字体设置按钮字体 49 | englishButton.setFont(new Font(Font.DIALOG, Font.PLAIN, 14)); 50 | englishButton.setPreferredSize(new Dimension(80, 30)); 51 | englishButton.addActionListener(new ActionListener() { 52 | @Override 53 | public void actionPerformed(ActionEvent e) { 54 | // 这里可以添加选择英文后的具体逻辑,比如设置程序语言为英文等 55 | LanguageSetting.runEnglishText(); 56 | // 关闭对话框,解除阻塞,后续代码才能继续执行 57 | dialog.dispose(); 58 | } 59 | }); 60 | panel.add(englishButton); 61 | 62 | dialog.add(panel); 63 | 
dialog.setVisible(true); 64 | 65 | // 这里的代码在对话框关闭前不会执行,因为对话框是模态的,阻塞了此处代码的运行 66 | GUI.this.setSize(1024, 768); 67 | 68 | initApplication(); 69 | 70 | GUI.this.setVisible(true); 71 | }); 72 | 73 | 74 | } 75 | 76 | private void initApplication() { 77 | 78 | // 这里设置基本属性 79 | { 80 | 81 | this.setSize(703, 820); 82 | 83 | // 设置界面标题 84 | this.setTitle(LanguageSetting.c1); 85 | 86 | // 设置软件置顶 87 | this.setAlwaysOnTop(true); 88 | 89 | // 设置界面居中 90 | this.setLocationRelativeTo(null); 91 | 92 | // 游戏关闭那么就结束运行 93 | this.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE); 94 | 95 | // 取消默认的界面内部的居中方式,这样下面图片才可以设置坐标 96 | this.setLayout(null); 97 | 98 | // 不允许resize 99 | this.setResizable(false); 100 | } 101 | 102 | // 这里设置按钮 103 | { 104 | JMenuBar menuBar = new JMenuBar(); 105 | JMenu fileMenu = new JMenu(LanguageSetting.b1); 106 | JMenuItem reset = new JMenuItem(LanguageSetting.b2); 107 | JMenuItem clearCache = new JMenuItem(LanguageSetting.b3); 108 | 109 | 110 | clearCache.addActionListener(new ActionListener() { 111 | @Override 112 | public void actionPerformed(ActionEvent e) { 113 | GUI.this.setVisible(false); 114 | DevelopTool.clearCache(); 115 | GUI.this.setVisible(true); 116 | System.out.println("[已成功清除缓存]"); 117 | } 118 | }); 119 | reset.addActionListener(new ActionListener() { 120 | @Override 121 | public void actionPerformed(ActionEvent e) { 122 | showJDialog1(); 123 | } 124 | }); 125 | 126 | fileMenu.add(reset); 127 | fileMenu.add(clearCache); 128 | menuBar.add(fileMenu); 129 | 130 | this.getContentPane().add(menuBar); 131 | this.setJMenuBar(menuBar); 132 | } 133 | 134 | 135 | 136 | // 这里设置界面内的内容 137 | { 138 | final int panelX = 20; 139 | final int panelY = 20; 140 | 141 | this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 142 | 143 | // 创建用于放置28*28个小方块的面板 144 | JPanel newJPanel = new JPanel(); 145 | newJPanel.setLayout(null); // 设置为null布局,便于手动定位小方块 146 | 147 | int squareSize = 10; // 每个小方块的边长(包含边框) 148 | int spacing = 0; // 小方块之间的间隔,用于显示边框效果 149 | 150 | for (int 
i = 0; i < 28; i++) { 151 | for (int j = 0; j < 28; j++) { 152 | // Create each small square (simulated with a JPanel) 153 | JPanel squarePanel = new JPanel() { 154 | @Override 155 | protected void paintComponent(Graphics g) { 156 | super.paintComponent(g); 157 | g.setColor(new Color(0,0,0)); 158 | // Draw the border of the square (all four edges) 159 | g.drawRect(0, 0, getWidth(), getHeight()); 160 | } 161 | }; 162 | squarePanel.setBackground(new Color(0, 0, 0)); 163 | squarePanel.setBounds(i * (squareSize + spacing), j * (squareSize + spacing), squareSize, squareSize); 164 | newJPanel.add(squarePanel); 165 | } 166 | } 167 | newJPanel.setBounds(panelX, panelY, 280, 280); 168 | 169 | 170 | // Mouse listeners: paint every offset within 5 pixels of the cursor, but only where both coordinates stay inside the canvas 171 | Random rand = new Random(); 172 | newJPanel.addMouseMotionListener(new MouseMotionAdapter() { 173 | @Override 174 | public void mouseDragged(MouseEvent e) { 175 | for (int dx = -5; dx <= 5; dx++) { 176 | for (int dy = -5; dy <= 5; dy++) { 177 | if ((e.getX() + dx >= 1 && e.getX() + dx <= 279) && (e.getY() + dy >= 1 && e.getY() + dy <= 279)) 178 | updatePanel(e.getX(), e.getY(), dx, dy, newJPanel); 179 | } 180 | } 181 | } 182 | }); 183 | newJPanel.addMouseListener(new MouseAdapter() { 184 | @Override 185 | public void mousePressed(MouseEvent e) { 186 | for (int dx = -5; dx <= 5; dx++) { 187 | for (int dy = -5; dy <= 5; dy++) { 188 | if ((e.getX() + dx >= 1 && e.getX() + dx <= 279) && (e.getY() + dy >= 1 && e.getY() + dy <= 279)) 189 | updatePanel(e.getX(), e.getY(), dx, dy, newJPanel); 190 | } 191 | } 192 | } 193 | }); 194 | 195 | 196 | // Add the panel containing the squares to the window's content pane 197 | this.getContentPane().add(newJPanel); 198 | 199 | 200 | 201 | 202 | JButton reset = new JButton(LanguageSetting.a1); 203 | reset.addActionListener(new ActionListener() { 204 | @Override 205 | public void actionPerformed(ActionEvent e) { 206 | for(int x=0;x<28;x++) 207 | { 208 | for(int y=0;y<28;y++) 209 | { 210 | newJPanel.getComponentAt(x*10,y*10).setBackground(new Color(0,0,0)); 211 | } 212 | } 213 | 214 | System.out.println("已重置画布!"); 215 | } 216 | }); 217 | 
reset.setBounds(20 ,310,130,30); 218 | this.getContentPane().add(reset); 219 | 220 | JButton stopButton = new JButton(LanguageSetting.a9); 221 | stopButton.addActionListener(new ActionListener() { 222 | @Override 223 | public void actionPerformed(ActionEvent e) { 224 | stopButtonPressed = true; 225 | } 226 | }); 227 | stopButton.setBounds(330,370,150,20); 228 | this.getContentPane().add(stopButton); 229 | 230 | JButton recordButton = new JButton(LanguageSetting.a10); 231 | recordButton.addActionListener(new ActionListener() { 232 | @Override 233 | public void actionPerformed(ActionEvent e) { 234 | try { 235 | Processing.record(); 236 | System.out.println("[成功导出模型到桌面!]"); 237 | } catch (IOException ex) { 238 | throw new RuntimeException(ex); 239 | } 240 | } 241 | }); 242 | recordButton.setBounds(330,410,350,20); 243 | this.getContentPane().add(recordButton); 244 | 245 | JButton loadButton = new JButton(LanguageSetting.a11); 246 | loadButton.addActionListener(new ActionListener() { 247 | @Override 248 | public void actionPerformed(ActionEvent e) { 249 | GUI.this.setVisible(false); 250 | try { 251 | Processing.readModel(); 252 | } catch (IOException ex) { 253 | throw new RuntimeException(ex); 254 | } 255 | System.out.println("[已成功导入模型!]"); 256 | GUI.this.setVisible(true); 257 | } 258 | }); 259 | loadButton.setBounds(330,440,350,20); 260 | this.getContentPane().add(loadButton); 261 | 262 | JButton setAsTrainImg = new JButton(LanguageSetting.a5); 263 | setAsTrainImg.setBounds(330,120,150,20); 264 | setAsTrainImg.addActionListener(new ActionListener() { 265 | @Override 266 | public void actionPerformed(ActionEvent e) { 267 | JTextField jt = (JTextField) GUI.this.getContentPane().getComponentAt(330,200); 268 | GUI.BATCH_SET_SIZE = Integer.parseInt(jt.getText()); 269 | Thread thread = new Thread(new Runnable() { 270 | @Override 271 | public void run() { 272 | try { 273 | Processing.run(); 274 | } catch (IOException ex) { 275 | throw new RuntimeException(ex); 276 | } 277 | 
} 278 | }); 279 | thread.start(); 280 | } 281 | }); 282 | this.getContentPane().add(setAsTrainImg); 283 | 284 | JLabel text1 = new JLabel(LanguageSetting.a6); 285 | text1.setBounds(330,140,170,20); 286 | this.getContentPane().add(text1); 287 | 288 | JTextField input1 = new JTextField(); 289 | input1.setText("10000"); 290 | input1.setBounds(330, 170, 75, 40); 291 | 292 | input1.setFont(new Font("custom", Font.BOLD,25)); 293 | this.getContentPane().add(input1); 294 | 295 | JButton getImgFromFolder = new JButton(LanguageSetting.a7); 296 | getImgFromFolder.setBounds(330, 220, 350, 20); 297 | getImgFromFolder.addActionListener(new ActionListener() { 298 | @Override 299 | public void actionPerformed(ActionEvent e) { 300 | try { 301 | getRandomImgFromFile(); 302 | Processing.getDigit(GUI.this); 303 | } catch (IOException ex) { 304 | throw new RuntimeException(ex); 305 | } 306 | } 307 | }); 308 | 309 | 310 | this.getContentPane().add(getImgFromFolder); 311 | 312 | // 训练该样本的准确性 313 | JButton verifyThisModel = new JButton(LanguageSetting.a8); 314 | verifyThisModel.setBounds(330, 260, 350, 20); 315 | verifyThisModel.addActionListener(new ActionListener() { 316 | @Override 317 | public void actionPerformed(ActionEvent e) { 318 | // 创建一个新线程来执行长时间运行的任务 319 | Thread thread = new Thread(new Runnable() { 320 | @Override 321 | public void run() { 322 | try { 323 | Processing.VerifyAccuracy(GUI.this); 324 | } catch (IOException ex) { 325 | throw new RuntimeException(ex); 326 | } 327 | } 328 | }); 329 | thread.start(); 330 | } 331 | }); 332 | this.getContentPane().add(verifyThisModel); 333 | 334 | 335 | JLabel correctNumber = new JLabel("0"); 336 | correctNumber.setFont(new Font("custom", Font.BOLD,18)); 337 | correctNumber.setBounds(430,280,150,30); 338 | this.getContentPane().add(correctNumber); 339 | 340 | JLabel per10000 = new JLabel("/ 10000"); 341 | per10000.setFont(new Font("custom", Font.BOLD,18)); 342 | per10000.setBounds(480,280,150,30); 343 | 
this.getContentPane().add(per10000); 344 | 345 | JLabel countText = new JLabel("Count:"); 346 | countText.setFont(new Font("custom", Font.BOLD,18)); 347 | countText.setBounds(330,280,150,30); 348 | this.getContentPane().add(countText); 349 | 350 | 351 | JLabel accuracyNumber = new JLabel("00.00%"); 352 | accuracyNumber.setFont(new Font("custom", Font.BOLD,18)); 353 | accuracyNumber.setBounds(430,320,150,30); 354 | this.getContentPane().add(accuracyNumber); 355 | 356 | JLabel accracyText = new JLabel("Accuracy:"); 357 | accracyText.setFont(new Font("custom", Font.BOLD,18)); 358 | accracyText.setBounds(330,300,150,30); 359 | this.getContentPane().add(accracyText); 360 | 361 | 362 | 363 | 364 | 365 | 366 | 367 | JButton getPredict = new JButton(LanguageSetting.a3); 368 | getPredict.setBounds(20 ,370,280,20); 369 | getPredict.addActionListener(new ActionListener() { 370 | @Override 371 | public void actionPerformed(ActionEvent e) { 372 | System.out.println("已经更新获取的数据"); 373 | Processing.getDigit(GUI.this); 374 | } 375 | }); 376 | this.getContentPane().add(getPredict); 377 | 378 | 379 | // 设置随机图片按钮 380 | JButton randomButton = new JButton(LanguageSetting.a2); 381 | randomButton.addActionListener(new ActionListener() { 382 | @Override 383 | public void actionPerformed(ActionEvent e) { 384 | for(int x=0;x<28;x++) 385 | { 386 | for(int y=0;y<28;y++) 387 | { 388 | int randColor = (int)(Math.random()*255.00); 389 | 390 | newJPanel.getComponentAt(x*10,y*10).setBackground(new Color(randColor,randColor,randColor)); 391 | } 392 | } 393 | } 394 | }); 395 | randomButton.setBounds(20 +280 - 130 ,310,130,30); 396 | this.getContentPane().add(randomButton); 397 | 398 | // 加入0~9预测 399 | { 400 | JPanel predictPanel = new JPanel(); 401 | predictPanel.setLayout(null); 402 | Font font = new Font("custom", Font.BOLD,20); 403 | JLabel digit0 = new JLabel("[0]"); 404 | JLabel digit1 = new JLabel("[1]"); 405 | JLabel digit2 = new JLabel("[2]"); 406 | JLabel digit3 = new JLabel("[3]"); 407 | 
JLabel digit4 = new JLabel("[4]"); 408 | JLabel digit5 = new JLabel("[5]"); 409 | JLabel digit6 = new JLabel("[6]"); 410 | JLabel digit7 = new JLabel("[7]"); 411 | JLabel digit8 = new JLabel("[8]"); 412 | JLabel digit9 = new JLabel("[9]"); 413 | JLabel preDigit0 = new JLabel("0.00%"); 414 | JLabel preDigit1 = new JLabel("0.00%"); 415 | JLabel preDigit2 = new JLabel("0.00%"); 416 | JLabel preDigit3 = new JLabel("0.00%"); 417 | JLabel preDigit4 = new JLabel("0.00%"); 418 | JLabel preDigit5 = new JLabel("0.00%"); 419 | JLabel preDigit6 = new JLabel("0.00%"); 420 | JLabel preDigit7 = new JLabel("0.00%"); 421 | JLabel preDigit8 = new JLabel("0.00%"); 422 | JLabel preDigit9 = new JLabel("0.00%"); 423 | digit0.setFont(font); 424 | digit1.setFont(font); 425 | digit2.setFont(font); 426 | digit3.setFont(font); 427 | digit4.setFont(font); 428 | digit5.setFont(font); 429 | digit6.setFont(font); 430 | digit7.setFont(font); 431 | digit8.setFont(font); 432 | digit9.setFont(font); 433 | 434 | preDigit0.setFont(font); 435 | preDigit1.setFont(font); 436 | preDigit2.setFont(font); 437 | preDigit3.setFont(font); 438 | preDigit4.setFont(font); 439 | preDigit5.setFont(font); 440 | preDigit6.setFont(font); 441 | preDigit7.setFont(font); 442 | preDigit8.setFont(font); 443 | preDigit9.setFont(font); 444 | 445 | 446 | int YBIAS = 30; 447 | digit0.setBounds(0,0*YBIAS,35,24); 448 | digit1.setBounds(0,1*YBIAS,35,24); 449 | digit2.setBounds(0,2*YBIAS,35,24); 450 | digit3.setBounds(0,3*YBIAS,35,24); 451 | digit4.setBounds(0,4*YBIAS,35,24); 452 | digit5.setBounds(0,5*YBIAS,35,24); 453 | digit6.setBounds(0,6*YBIAS,35,24); 454 | digit7.setBounds(0,7*YBIAS,35,24); 455 | digit8.setBounds(0,8*YBIAS,35,24); 456 | digit9.setBounds(0,9*YBIAS,35,24); 457 | preDigit0.setBounds(35,0*YBIAS,140,24); 458 | preDigit1.setBounds(35,1*YBIAS,140,24); 459 | preDigit2.setBounds(35,2*YBIAS,140,24); 460 | preDigit3.setBounds(35,3*YBIAS,140,24); 461 | preDigit4.setBounds(35,4*YBIAS,140,24); 462 | 
preDigit5.setBounds(35,5*YBIAS,140,24); 463 | preDigit6.setBounds(35,6*YBIAS,140,24); 464 | preDigit7.setBounds(35,7*YBIAS,140,24); 465 | preDigit8.setBounds(35,8*YBIAS,140,24); 466 | preDigit9.setBounds(35,9*YBIAS,140,24); 467 | 468 | predictPanel.add(digit0); 469 | predictPanel.add(digit1); 470 | predictPanel.add(digit2); 471 | predictPanel.add(digit3); 472 | predictPanel.add(digit4); 473 | predictPanel.add(digit5); 474 | predictPanel.add(digit6); 475 | predictPanel.add(digit7); 476 | predictPanel.add(digit8); 477 | predictPanel.add(digit9); 478 | 479 | predictPanel.add(preDigit0); 480 | predictPanel.add(preDigit1); 481 | predictPanel.add(preDigit2); 482 | predictPanel.add(preDigit3); 483 | predictPanel.add(preDigit4); 484 | predictPanel.add(preDigit5); 485 | predictPanel.add(preDigit6); 486 | predictPanel.add(preDigit7); 487 | predictPanel.add(preDigit8); 488 | predictPanel.add(preDigit9); 489 | 490 | 491 | predictPanel.setBounds(20,400,300,480); 492 | this.getContentPane().add(predictPanel); 493 | } 494 | 495 | // 加入参数调整 496 | { 497 | JLabel lambdaText = new JLabel (LanguageSetting.a12); 498 | JLabel alphaText = new JLabel (LanguageSetting.a14); 499 | JLabel stopVarText = new JLabel (LanguageSetting.a15); 500 | JLabel miniBatchSize = new JLabel (LanguageSetting.a16); 501 | 502 | lambdaText.setFont(new Font("customize",Font.BOLD,17)); 503 | alphaText.setFont(new Font("customize",Font.BOLD,17)); 504 | stopVarText.setFont(new Font("customize",Font.BOLD,17)); 505 | miniBatchSize.setFont(new Font("customize",Font.BOLD,17)); 506 | 507 | lambdaText.setBounds(330,470,350,20); 508 | alphaText.setBounds(330,530,350,20); 509 | stopVarText.setBounds(330,590,350,20); 510 | miniBatchSize.setBounds(330,650,350,20); 511 | 512 | this.getContentPane().add(lambdaText); 513 | this.getContentPane().add(alphaText); 514 | this.getContentPane().add(stopVarText); 515 | this.getContentPane().add(miniBatchSize); 516 | 517 | JTextField lambda = new JTextField(); 518 | 
lambda.setText(String.valueOf(Processing.LAMBDA_L1)); 519 | int constant1 = 30; 520 | lambda.setFont(new Font("customize",Font.BOLD,17)); 521 | lambda.setBounds(330,470+constant1,100,25); 522 | this.getContentPane().add(lambda); 523 | 524 | JTextField alpha = new JTextField(); 525 | alpha.setText(String.valueOf(Processing.ALPHA)); 526 | alpha.setFont(new Font("customize",Font.BOLD,17)); 527 | alpha.setBounds(330,530+constant1,100,25); 528 | this.getContentPane().add(alpha); 529 | 530 | JTextField stopEpoch = new JTextField(); 531 | stopEpoch.setText(String.valueOf(Processing.EPOCH_TOTAL)); 532 | stopEpoch.setFont(new Font("customize",Font.BOLD,17)); 533 | stopEpoch.setBounds(330,590+constant1,100,25); 534 | this.getContentPane().add(stopEpoch); 535 | 536 | JTextField miniBatch = new JTextField(); 537 | miniBatch.setText(String.valueOf(Processing.DEFINE_BATCH_SIZE)); 538 | miniBatch.setFont(new Font("customize",Font.BOLD,17)); 539 | miniBatch.setBounds(330,650+constant1,100,25); 540 | this.getContentPane().add(miniBatch); 541 | 542 | 543 | JButton update1 = new JButton(LanguageSetting.a13); 544 | update1.addActionListener(new ActionListener() { 545 | @Override 546 | public void actionPerformed(ActionEvent e) { 547 | System.out.println("已更改正则化L1、L2 Lambda为"+lambda.getText()); 548 | lambda.setForeground(Color.black); 549 | Processing.LAMBDA_L1 = Double.parseDouble(lambda.getText()); 550 | Processing.LAMBDA_L2 = Double.parseDouble(lambda.getText()); 551 | } 552 | }); 553 | update1.setFont(new Font("customize",Font.BOLD,17)); 554 | update1.setBounds(330+115,470+constant1,100,25); 555 | this.getContentPane().add(update1); 556 | 557 | JButton update2 = new JButton(LanguageSetting.a13); 558 | update2.addActionListener(new ActionListener() { 559 | @Override 560 | public void actionPerformed(ActionEvent e) { 561 | alpha.setForeground(Color.black); 562 | Processing.ALPHA = Double.parseDouble(alpha.getText()); 563 | System.out.println("已更改学习率Alpha为"+Processing.ALPHA); 564 | } 
565 | }); 566 | update2.setFont(new Font("customize",Font.BOLD,17)); 567 | update2.setBounds(330+115,530+constant1,100,25); 568 | this.getContentPane().add(update2); 569 | 570 | JButton update3 = new JButton(LanguageSetting.a13); 571 | update3.addActionListener(new ActionListener() { 572 | @Override 573 | public void actionPerformed(ActionEvent e) { 574 | System.out.println("Epoch count changed to "+stopEpoch.getText()); 575 | stopEpoch.setForeground(Color.black); 576 | Processing.EPOCH_TOTAL = (int) Double.parseDouble(stopEpoch.getText()); 577 | } 578 | }); 579 | update3.setFont(new Font("customize",Font.BOLD,17)); 580 | update3.setBounds(330+115,590+constant1,100,25); 581 | this.getContentPane().add(update3); 582 | 583 | JButton update4 = new JButton(LanguageSetting.a13); 584 | update4.addActionListener(new ActionListener() { 585 | @Override 586 | public void actionPerformed(ActionEvent e) { 587 | System.out.println("Batch size changed to "+miniBatch.getText()); 588 | miniBatch.setForeground(Color.black); 589 | Processing.DEFINE_BATCH_SIZE = (int) Double.parseDouble(miniBatch.getText()); 590 | } 591 | }); 592 | update4.setFont(new Font("customize",Font.BOLD,17)); 593 | update4.setBounds(330+115,650+constant1,100,25); 594 | this.getContentPane().add(update4); 595 | 596 | 597 | 598 | } 599 | 600 | 601 | this.setVisible(true); 602 | } 603 | 604 | // Set up the brush-intensity slider 605 | { 606 | // Create the slider (range: 0-100, initial value: 0) 607 | JSlider slider = new JSlider(0, 100, 0); 608 | slider.setMajorTickSpacing(10); // draw a major tick every 10 units 609 | slider.setMinorTickSpacing(1); // draw a minor tick every 1 unit 610 | slider.setPaintTicks(true); // show tick marks 611 | slider.setPaintLabels(true); // show tick labels 612 | slider.setValue(91); 613 | 614 | slider.addMouseListener(new MouseAdapter() { 615 | @Override 616 | public void mouseReleased(MouseEvent e) { 617 | CONSTANT_K = (101-slider.getValue())*35; 618 | 619 | } 620 | }); 621 | 622 | // Create the button 623 | JButton defButton = new JButton(LanguageSetting.a4); 624 | 625 | // Add the button click handler 626 | 627 | defButton.addActionListener(new 
ActionListener() { 628 | 629 | @Override 630 | public void actionPerformed(ActionEvent e) { 631 | System.out.println("已还原为默认值!"); 632 | CONSTANT_K = 300; 633 | } 634 | }); 635 | // 创建容器JPanel并设置布局 636 | JPanel panel = new JPanel(); 637 | panel.setLayout(new FlowLayout()); 638 | 639 | panel.setBounds(330,10,200,200); 640 | // 将滑条和按钮添加到JPanel 641 | panel.add(slider); 642 | panel.add(defButton); 643 | 644 | // 将panel添加到JFrame(假设你已经创建了JFrame) 645 | add(panel); 646 | 647 | } 648 | 649 | 650 | } 651 | 652 | private void getRandomImgFromFile() throws IOException { 653 | JPanel newJPanel = (JPanel) this.getContentPane().getComponentAt(20,20); 654 | Random rand = new Random(); 655 | double[][] array = DevelopTool.imgToArray("resources/TrainImg/"+rand.nextInt(GUI.BATCH_SET_SIZE)+".png"); 656 | for(int x=0;x<28;x++) 657 | { 658 | for (int y=0;y<28;y++) 659 | { 660 | int RGB = (int)(array[x][y]*255.00); 661 | newJPanel.getComponentAt(y*10,x*10).setBackground(new Color(RGB,RGB,RGB)); 662 | } 663 | } 664 | } 665 | 666 | 667 | private void updatePanel(int x, int y, int dx, int dy, JPanel newJPanel) { 668 | // 找不到零件就返回 669 | if (newJPanel.getComponentAt(x + dx, y + dy) == null) return; 670 | int RGB = (int) ((dx * dx + dy * dy / 50.0) / (double) CONSTANT_K * 255.0); // 非常小的一个数字,例如0.21 671 | int originalRGB = newJPanel.getComponentAt(x + dx, y + dy).getBackground().getRed(); //0, 或者150这种 672 | if (originalRGB + RGB > 255) { 673 | newJPanel.getComponentAt(x + dx, y + dy).setBackground(new Color(255,255,255)); 674 | return; 675 | } 676 | 677 | newJPanel.getComponentAt(x + dx, y + dy).setBackground(new Color(originalRGB + RGB, originalRGB + RGB, originalRGB + RGB)); 678 | } 679 | 680 | private void showJDialog1() { 681 | // 创建对话框 682 | JDialog dialog = new JDialog(this, "警告", true); 683 | 684 | // 创建面板用于放置组件,这里设置为null布局,以便进行绝对定位 685 | JPanel panel = new JPanel(null); 686 | 687 | JLabel text = new JLabel("你确定要这么做吗?可能会等待较长时间"); 688 | 689 | JButton yesButton = new JButton("Yes"); 690 | 
yesButton.addActionListener(new ActionListener() { 691 | @Override 692 | public void actionPerformed(ActionEvent e) { 693 | dialog.setVisible(false); 694 | dialog.dispose(); 695 | GUI.this.setVisible(false); 696 | try { 697 | DevelopTool.regenerateResourcesFile(); 698 | } catch (IOException ex) { 699 | throw new RuntimeException(ex); 700 | } 701 | GUI.this.setVisible(true); 702 | } 703 | }); 704 | 705 | JButton noButton = new JButton("No"); 706 | noButton.addActionListener(new ActionListener() { 707 | @Override 708 | public void actionPerformed(ActionEvent e) { 709 | dialog.dispose(); 710 | } 711 | }); 712 | 713 | // 设置按钮的大小和位置 714 | yesButton.setBounds(10, 30, 80, 30); 715 | noButton.setBounds(150, 30, 80, 30); 716 | 717 | // 设置文本标签的位置和大小,这里简单设置了下,可根据实际调整 718 | text.setBounds(10, 10, 300, 20); 719 | 720 | // 将组件添加到面板中 721 | panel.add(text); 722 | panel.add(yesButton); 723 | panel.add(noButton); 724 | 725 | // 将面板添加到对话框中 726 | dialog.add(panel); 727 | 728 | // 设置对话框大小 729 | dialog.setSize(275, 110); 730 | 731 | // 设置对话框相对于当前窗口的位置 732 | dialog.setLocationRelativeTo(this); 733 | 734 | // 设置对话框始终在最上层显示 735 | dialog.setAlwaysOnTop(true); 736 | // 不可resize 737 | dialog.setResizable(false); 738 | 739 | // 让对话框可见 740 | dialog.setVisible(true); 741 | } 742 | } 743 | -------------------------------------------------------------------------------- /src/functions/LanguageSetting.java: -------------------------------------------------------------------------------- 1 | package functions; 2 | 3 | public class LanguageSetting { 4 | private LanguageSetting() {} 5 | public static String a1; 6 | public static String a2; 7 | public static String a3; 8 | public static String a4; 9 | public static String a5; 10 | public static String a6; 11 | public static String a7; 12 | public static String a8; 13 | public static String a9; 14 | public static String a10; 15 | public static String a11; 16 | public static String a12; 17 | public static String a13; 18 | public static String a14; 19 | 
public static String a15; 20 | public static String a16; 21 | public static String b1; 22 | public static String b2; 23 | public static String b3; 24 | public static String c1; 25 | 26 | public static void runChineseText(){ 27 | a1 = "重置画布"; 28 | a2 = "随机画布"; 29 | a3 = "预测此图像"; 30 | a4 = "还原为默认值"; 31 | a5 = "开始训练"; 32 | a6 = "设置训练样本量"; 33 | a7 = "从训练集中获取图像"; 34 | a8 = "以验证集验证模型准确度"; 35 | a9 = "停止训练"; 36 | a10 = "导出当前模型到桌面"; 37 | a11 = "读取先前模型"; 38 | a12 = "正则化L1 & L2 值 ->"; 39 | a13 = "更改"; 40 | a14 = "学习率 Alpha 值->"; 41 | a15 = "设置总训练轮数 ->"; 42 | a16 = "设置 Batch 值 ->"; 43 | b1 = "设置"; 44 | b2 = "重新刷新资源文件"; 45 | b3 = "清除所有图片和缓存数据库(.txt数据)"; 46 | c1 = "Java手写数字识别"; 47 | } 48 | public static void runEnglishText() { 49 | a1 = "Reset Canvas"; 50 | a2 = "Random Canvas"; 51 | a3 = "Predict This Image"; 52 | a4 = "Restore to Default Values"; 53 | a5 = "Start Training"; 54 | a6 = "Set the Training Sample Size"; 55 | a7 = "Get Images from the Training Set"; 56 | a8 = "Verify the Model Accuracy with the Validation Set"; 57 | a9 = "Stop Training"; 58 | a10 = "Export the Current Model to the Desktop"; 59 | a11 = "Read the Previous Model"; 60 | a12 = "Regularization L1 & L2 Values ->"; 61 | a13 = "Change"; 62 | a14 = "Learning Rate Alpha Value ->"; 63 | a15 = "Set the Total Training Times ->"; 64 | a16 = "Set the Batch Value ->"; 65 | b1 = "Settings"; 66 | b2 = "Refresh Resource Files"; 67 | b3 = "Clear All Images and Cache Databases(.txt files)"; 68 | c1 = "Java Hand Writing Digits Recognition"; 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/functions/MNISTReader.java: -------------------------------------------------------------------------------- 1 | package functions; 2 | 3 | import java.io.*; 4 | import java.util.ArrayList; 5 | 6 | /** 7 | * This class implements a reader for the MNIST dataset of handwritten digits. The dataset is found 8 | * at http://yann.lecun.com/exdb/mnist/. 
9 | * 10 | * Author: Gabe Johnson 11 | * 12 | * Modified by: James Thornton & Zhibin Jiang 13 | */ 14 | public class MNISTReader { 15 | 16 | 17 | /** 18 | * @param args 19 | * args[0]: label file; args[1]: data file. 20 | * @throws IOException 21 | */ 22 | 23 | public static void main(String[] args) throws IOException { 24 | 25 | } 26 | 27 | /** 28 | * @throws IOException 29 | */ 30 | 31 | 32 | public static ArrayList getMNISTTrainDataFromFile() throws IOException { 33 | String[] args = {"resources/database/train-labels-idx1-ubyte","resources/database/train-images-idx3-ubyte"}; 34 | System.out.println("Retrieving MNIST data..."); 35 | 36 | DataInputStream labels = new DataInputStream(new FileInputStream(args[0])); 37 | DataInputStream images = new DataInputStream(new FileInputStream(args[1])); 38 | 39 | int magicNumber = labels.readInt(); 40 | if (magicNumber != 2049) { 41 | System.err.println("Label file has wrong magic number: " + magicNumber + " (should be 2049)"); 42 | System.exit(0); 43 | } 44 | magicNumber = images.readInt(); 45 | if (magicNumber != 2051) { 46 | System.err.println("Image file has wrong magic number: " + magicNumber + " (should be 2051)"); 47 | System.exit(0); 48 | } 49 | 50 | int numLabels = labels.readInt(); 51 | int numImages = images.readInt(); 52 | int numRows = images.readInt(); 53 | int numCols = images.readInt(); 54 | if (numLabels != numImages) { 55 | System.err.println("Image file and label file do not contain the same number of entries."); 56 | System.err.println("Label file contains: " + numLabels); 57 | System.err.println("Image file contains: " + numImages); 58 | System.exit(0); 59 | } 60 | 61 | long start = System.currentTimeMillis(); 62 | int numLabelsRead = 0; 63 | int numImagesRead = 0; 64 | 65 | ArrayList result = new ArrayList(); 66 | 67 | int counter = 0; 68 | while (labels.available() > 0 && numLabelsRead < numLabels) { 69 | counter++; 70 | 71 | if(counter%1000==0) 72 | { 73 | System.out.println("[Reading from training 
database. Process] \t"+String.format("%.2f",((double)counter/60000.00*100.00))+"%"); 74 | } 75 | 76 | 77 | //if(counter==3) break; // functions.test in case 78 | byte label = labels.readByte(); 79 | numLabelsRead++; 80 | double[][] image = new double[numCols+1][numRows]; 81 | for (int colIdx = 0; colIdx < numCols; colIdx++) { 82 | for (int rowIdx = 0; rowIdx < numRows; rowIdx++) { 83 | image[colIdx][rowIdx] = images.readUnsignedByte()/256.0; 84 | } 85 | } 86 | numImagesRead++; 87 | 88 | double[] labelArray = new double[1]; 89 | labelArray[0] = label; 90 | image[numCols] = labelArray; 91 | 92 | result.add(image); 93 | } 94 | // At this point, 'label' and 'image' agree and you can do whatever you like with them. 95 | labels.close(); 96 | images.close(); 97 | return result; 98 | } 99 | public static ArrayList getMNISTVerifyDataFromFile() throws IOException { 100 | String[] args = {"resources/database/t10k-labels.idx1-ubyte","resources/database/t10k-images.idx3-ubyte"}; 101 | System.out.println("Retrieving MNIST data..."); 102 | 103 | DataInputStream labels = new DataInputStream(new FileInputStream(args[0])); 104 | DataInputStream images = new DataInputStream(new FileInputStream(args[1])); 105 | 106 | int magicNumber = labels.readInt(); 107 | if (magicNumber != 2049) { 108 | System.err.println("Label file has wrong magic number: " + magicNumber + " (should be 2049)"); 109 | System.exit(0); 110 | } 111 | magicNumber = images.readInt(); 112 | if (magicNumber != 2051) { 113 | System.err.println("Image file has wrong magic number: " + magicNumber + " (should be 2051)"); 114 | System.exit(0); 115 | } 116 | 117 | int numLabels = labels.readInt(); 118 | int numImages = images.readInt(); 119 | int numRows = images.readInt(); 120 | int numCols = images.readInt(); 121 | if (numLabels != numImages) { 122 | System.err.println("Image file and label file do not contain the same number of entries."); 123 | System.err.println("Label file contains: " + numLabels); 124 | 
System.err.println("Image file contains: " + numImages); 125 | System.exit(0); 126 | } 127 | 128 | long start = System.currentTimeMillis(); 129 | int numLabelsRead = 0; 130 | int numImagesRead = 0; 131 | 132 | ArrayList result = new ArrayList(); 133 | 134 | int counter = 0; 135 | while (labels.available() > 0 && numLabelsRead < numLabels) { 136 | counter++; 137 | 138 | if(counter%1000==0) 139 | { 140 | System.out.println("[Reading from verify database. Process] \t"+String.format("%.2f",((double)counter/10000.00*100.00))+"%"); 141 | } 142 | 143 | 144 | //if(counter==3) break; // functions.test in case 145 | byte label = labels.readByte(); 146 | numLabelsRead++; 147 | double[][] image = new double[numCols+1][numRows]; 148 | for (int colIdx = 0; colIdx < numCols; colIdx++) { 149 | for (int rowIdx = 0; rowIdx < numRows; rowIdx++) { 150 | image[colIdx][rowIdx] = images.readUnsignedByte()/256.0; 151 | } 152 | } 153 | numImagesRead++; 154 | 155 | double[] labelArray = new double[1]; 156 | labelArray[0] = label; 157 | image[numCols] = labelArray; 158 | 159 | result.add(image); 160 | } 161 | // At this point, 'label' and 'image' agree and you can do whatever you like with them. 
162 | labels.close(); 163 | images.close(); 164 | return result; 165 | } 166 | 167 | } -------------------------------------------------------------------------------- /src/functions/Processing.java: -------------------------------------------------------------------------------- 1 | package functions; 2 | 3 | 4 | import javax.swing.*; 5 | import java.awt.*; 6 | import java.io.*; 7 | import java.util.ArrayList; 8 | import java.util.Random; 9 | 10 | // 历史最小:0.030220112 11 | // Updated: 2025/1/10 12 | public class Processing { 13 | 14 | /** 15 | * 默认变量 16 | */ 17 | // 输出模型文件夹 18 | static final String EXPORT_PART = System.getProperty("user.home") + "\\Desktop\\" + "Model_Data.txt"; 19 | // 读取模型文件 20 | static final String IMPORT_PATH = "resources/model/Model_Data.txt"; 21 | /** 22 | * 三层神经网络,结构组成如下: 23 | * 输入层:X,用于接收外部输入数据。 24 | * 隐藏层:A 层与 B 层,负责对输入数据进行处理和特征提取。 25 | * 输出层:Y,输出最终的网络计算结果。 26 | * 网络包含四层结构,配置有三权重、三偏置用于调整计算逻辑。 27 | */ 28 | static double[][] weightXA = new double[784][20]; 29 | static double[][] weightAB = new double[20][20]; 30 | static double[][] weightBY = new double[20][10]; 31 | static double[] biasA = new double[20]; 32 | static double[] biasB = new double[20]; 33 | static double[] biasY = new double[10]; 34 | /** 35 | * 这里是Adam优化器的数据 36 | */ 37 | static double BETA1 = 0.9; 38 | static double BETA2 = 0.999; 39 | static double ALPHA = 0.001; 40 | static double EPSILON = 1.0E-08; 41 | static double LAMBDA_L1 = 0.001; 42 | static double LAMBDA_L2 = 0.001; 43 | 44 | static double[][] weightXA_Mt = new double[784][20]; 45 | static double[][] weightXA_Vt = new double[784][20]; 46 | static int weightXA_Iteration = 0; 47 | static double[][] weightAB_Mt = new double[20][20]; 48 | static double[][] weightAB_Vt = new double[20][20]; 49 | static int weightAB_Iteration = 0; 50 | static double[][] weightBY_Mt = new double[20][10]; 51 | static double[][] weightBY_Vt = new double[20][10]; 52 | static int weightBY_Iteration = 0; 53 | static double[] biasA_Mt = new 
double[20]; 54 | static double[] biasA_Vt = new double[20]; 55 | static int biasA_Iteration = 0; 56 | static double[] biasB_Mt = new double[20]; 57 | static double[] biasB_Vt = new double[20]; 58 | static int biasB_Iteration = 0; 59 | static double[] biasY_Mt = new double[20]; 60 | static double[] biasY_Vt = new double[20]; 61 | static int biasY_Iteration = 0; 62 | // Total number of training epochs 63 | static int EPOCH_TOTAL = 5000; 64 | // Size of each batch 65 | static int DEFINE_BATCH_SIZE = 64; 66 | 67 | /** 68 | * Other stored state 69 | */ 70 | // Loss values of the most recent 10 epochs 71 | static ArrayList<Double> lossValueHistory = new ArrayList<>(); 72 | 73 | static { 74 | initialize(); // initialize the network 75 | } 76 | 77 | 78 | /** 79 | * run() starts training the model. 80 | */ 81 | public static void run() throws IOException { 82 | System.out.println("Training started..."); 83 | System.out.println("Total sample count: " + GUI.BATCH_SET_SIZE); 84 | System.out.println("Batch = " + DEFINE_BATCH_SIZE); 85 | GUI.stopButtonPressed = false; 86 | long start = System.currentTimeMillis(); 87 | Random rand = new Random(0); 88 | 89 | ArrayList<double[][]> allBatch = DevelopTool.getTrainBatch(); 90 | allBatch = normalization(allBatch); // preprocessing 91 | 92 | int EPOCH_CURRENT = 0; 93 | if (allBatch.size() > 10000 && EPOCH_TOTAL != EPOCH_CURRENT) { 94 | System.err.println(String.format("[Warning] Large dataset [%d]", allBatch.size()) + "; each epoch may take a long time, please be patient"); 95 | } 96 | double lossValue = getLossValueForBatchSet(allBatch); // current loss; refreshed at the end of each while iteration 97 | 98 | // Two stop conditions: the stop button was pressed, or the epoch limit was reached 99 | while (!GUI.stopButtonPressed && EPOCH_TOTAL > EPOCH_CURRENT) { 100 | 101 | System.out.println(String.format("Loss[%.16f]", lossValue) + String.format("\tVariance of last 10 losses[%.4f] ×10^-4", lossFunctionVar() * 10000.00) + "\tEpoch = " + (++EPOCH_CURRENT) + String.format(" Progress[%.2f%%]", (double) EPOCH_CURRENT / (double) EPOCH_TOTAL * 100.00)); 102 | ArrayList<double[][]> batch = new ArrayList<>(); 103 | // One full pass of this for loop is one epoch 104 | for (int i = 0; i < GUI.BATCH_SET_SIZE / DEFINE_BATCH_SIZE; i++) { 105 | for (int j = 0; j < DEFINE_BATCH_SIZE; j++) { 106 | // Pick a random image for the batch 107 | int 
randomIndex = rand.nextInt(GUI.BATCH_SET_SIZE); 108 | batch.add(allBatch.get(randomIndex)); 109 | } 110 | trainBatch(batch); 111 | batch.clear(); 112 | //new Scanner(System.in).nextLine(); 113 | } 114 | // Update lossValueHistory, keeping only the 10 most recent values 115 | lossValue = getLossValueForBatchSet(allBatch); 116 | if (lossValueHistory.size() >= 10) lossValueHistory.removeFirst(); 117 | lossValueHistory.add(lossValue); 118 | } 119 | record(); 120 | long end = System.currentTimeMillis(); 121 | System.out.println("Elapsed: " + String.format("%.3f", (double) (end - start) / 1000.00) + "s"); 122 | System.out.println("[Training finished!]"); 123 | GUI.stopButtonPressed = false; 124 | } 125 | 126 | /** 127 | * Runs one gradient-descent step on a single batch. 128 | */ 129 | private static void trainBatch(ArrayList<double[][]> batch) { 130 | // First run the forward pass for every sample in the batch. 131 | // Within one batch, all weights, biases, and activations must come from the state left by the previous batch. 132 | // Gradients are therefore accumulated in the local arrays below, 133 | // and the global weights are left untouched while the gradients are being computed. 134 | // The global weights and biases are updated only after the whole batch has been processed. 135 | double[][] XList = new double[batch.size()][784]; 136 | double[][] AList = new double[batch.size()][20]; 137 | double[][] BList = new double[batch.size()][20]; 138 | double[][] YList = new double[batch.size()][10]; 139 | 140 | double[][] gradient_weightXA = new double[784][20]; 141 | double[][] gradient_weightAB = new double[20][20]; 142 | double[][] gradient_weightBY = new double[20][10]; 143 | double[] gradient_biasA = new double[20]; 144 | double[] gradient_biasB = new double[20]; 145 | double[] gradient_biasY = new double[10]; 146 | 147 | ++weightXA_Iteration; 148 | ++weightAB_Iteration; 149 | ++weightBY_Iteration; 150 | ++biasA_Iteration; 151 | ++biasB_Iteration; 152 | ++biasY_Iteration; 153 | 154 | for (int numberOfBatch = 0; numberOfBatch < batch.size(); numberOfBatch++) { 155 | double[] X = new double[784]; 156 | double[] A = new double[20]; 157 | double[] B = new double[20]; 158 | double[] Y = new double[10]; 159 | double[][] currentImg = batch.get(numberOfBatch); 160 | // Flatten the image into X 161 | int counter = 0; 162 | 
for (int x = 0; x < 28; x++) { 163 | for (int y = 0; y < 28; y++) { 164 | X[counter++] = currentImg[x][y]; 165 | } 166 | } 167 | // Compute hidden layer A 168 | for (int i = 0; i < 20; i++) { 169 | for (int left = 0; left < 784; left++) { 170 | A[i] += X[left] * weightXA[left][i]; 171 | } 172 | // Add the bias, then apply sigmoid 173 | A[i] += biasA[i]; 174 | A[i] = sigmoid(A[i]); 175 | } 176 | // Compute hidden layer B 177 | for (int i = 0; i < 20; i++) { 178 | for (int left = 0; left < 20; left++) { 179 | B[i] += A[left] * weightAB[left][i]; 180 | } 181 | B[i] += biasB[i]; 182 | B[i] = sigmoid(B[i]); 183 | } 184 | // Compute output layer Y 185 | for (int i = 0; i < 10; i++) { 186 | for (int left = 0; left < 20; left++) { 187 | Y[i] += B[left] * weightBY[left][i]; 188 | } 189 | Y[i] += biasY[i]; 190 | Y[i] = sigmoid(Y[i]); 191 | } 192 | XList[numberOfBatch] = arrayCopy(X); 193 | AList[numberOfBatch] = arrayCopy(A); 194 | BList[numberOfBatch] = arrayCopy(B); 195 | YList[numberOfBatch] = arrayCopy(Y); 196 | } 197 | 198 | // Accumulate the gradients 199 | for (int n = 0; n < batch.size(); n++) { 200 | ArrayList<Double> storage1 = new ArrayList<>(); // output-layer deltas, j in 0..9 201 | ArrayList<Double> storage2 = new ArrayList<>(); // hidden-layer B deltas, i in 0..19 202 | ArrayList<Double> storage3 = new ArrayList<>(); // hidden-layer A deltas, r in 0..19 203 | 204 | // biasY 205 | double label = batch.get(n)[28][0]; 206 | for (int j = 0; j < 10; j++) { 207 | double ygt; 208 | if (label == j) ygt = 1.0; 209 | else ygt = 0.0; 210 | double value = (YList[n][j] - ygt) * YList[n][j] * (1.0 - YList[n][j]); 211 | gradient_biasY[j] += value; 212 | storage1.add(value); 213 | } 214 | // weightBY 215 | for (int i = 0; i < 20; i++) { 216 | double temp = 0; // delta of B[i], summed over all outputs j 217 | for (int j = 0; j < 10; j++) { 218 | double value = storage1.get(j) * BList[n][i]; 219 | gradient_weightBY[i][j] += value; 220 | 221 | temp += storage1.get(j) * weightBY[i][j] * BList[n][i] * (1.0 - BList[n][i]); // accumulate 222 | } 223 | storage2.add(temp); 224 | } 225 | // biasB 226 | for (int i = 0; i < 20; i++) { 227 | gradient_biasB[i] += storage2.get(i); 228 | } 229 | // weightAB 230 | for 
(int r = 0; r < 20; r++) { 231 | double temp = 0; 232 | for (int i = 0; i < 20; i++) { 233 | gradient_weightAB[r][i] += storage2.get(i) * AList[n][r]; 234 | temp += storage2.get(i) * AList[n][r] * weightAB[r][i] * (1.0 - AList[n][r]); 235 | } 236 | storage3.add(temp); 237 | } 238 | // biasA 239 | for (int r = 0; r < 20; r++) { 240 | gradient_biasA[r] += storage3.get(r); 241 | } 242 | 243 | // weightXA 244 | for (int l = 0; l < 784; l++) { 245 | for (int r = 0; r < 20; r++) { 246 | gradient_weightXA[l][r] += storage3.get(r) * XList[n][l]; 247 | } 248 | } 249 | } 250 | 251 | 252 | // 给梯度添加L1和L2正则化 // 这有必要吗 253 | { 254 | // gradient_weightBY 255 | for (int i = 0; i < 20; i++) { 256 | for (int j = 0; j < 10; j++) { 257 | 258 | gradient_weightBY[i][j] += 2 * LAMBDA_L2 * weightBY[i][j] + LAMBDA_L1 * sign(weightBY[i][j]); 259 | } 260 | } 261 | // gradient_weightAB 262 | for (int right = 0; right < 20; right++) { 263 | for (int i = 0; i < 20; i++) { 264 | gradient_weightAB[right][i] += 2 * LAMBDA_L2 * weightAB[right][i] + LAMBDA_L1 * sign(weightAB[right][i]); 265 | } 266 | } 267 | // gradient_weightXA 268 | for (int left = 0; left < 784; left++) { 269 | for (int right = 0; right < 20; right++) { 270 | gradient_weightXA[left][right] += 2 * LAMBDA_L2 * weightXA[left][right] + LAMBDA_L1 * sign(weightXA[left][right]); 271 | } 272 | } 273 | } 274 | 275 | 276 | // Adam 优化器对权重和偏置进行下降 277 | 278 | // biasY 279 | for (int j = 0; j < 10; j++) { 280 | biasY_Mt[j] = BETA1 * biasY_Mt[j] + (1.0 - BETA1) * gradient_biasY[j]; 281 | biasY_Vt[j] = BETA2 * biasY_Vt[j] + (1.0 - BETA2) * gradient_biasY[j] * gradient_biasY[j]; 282 | double MtHat = biasY_Mt[j] / (1.0 - Math.pow(BETA1, biasY_Iteration)); 283 | double VtHat = biasY_Vt[j] / (1.0 - Math.pow(BETA2, biasY_Iteration)); 284 | biasY[j] = biasY[j] - (ALPHA * MtHat) / (Math.sqrt(VtHat) + EPSILON); 285 | } 286 | // weight_BY 287 | for (int i = 0; i < 20; i++) { 288 | for (int j = 0; j < 10; j++) { 289 | weightBY_Mt[i][j] = BETA1 * 
weightBY_Mt[i][j] + (1.0 - BETA1) * gradient_weightBY[i][j]; 290 | weightBY_Vt[i][j] = BETA2 * weightBY_Vt[i][j] + (1.0 - BETA2) * gradient_weightBY[i][j] * gradient_weightBY[i][j]; 291 | double MtHat = weightBY_Mt[i][j] / (1.0 - Math.pow(BETA1, weightBY_Iteration)); 292 | double VtHat = weightBY_Vt[i][j] / (1.0 - Math.pow(BETA2, weightBY_Iteration)); 293 | weightBY[i][j] = weightBY[i][j] - (ALPHA * MtHat) / (Math.sqrt(VtHat) + EPSILON); 294 | 295 | } 296 | } 297 | // biasB 298 | for (int i = 0; i < 20; i++) { 299 | biasB_Mt[i] = BETA1 * biasB_Mt[i] + (1.0 - BETA1) * gradient_biasB[i]; 300 | biasB_Vt[i] = BETA2 * biasB_Vt[i] + (1.0 - BETA2) * gradient_biasB[i] * gradient_biasB[i]; 301 | 302 | double MtHat = biasB_Mt[i] / (1.0 - Math.pow(BETA1, biasB_Iteration)); 303 | double VtHat = biasB_Vt[i] / (1.0 - Math.pow(BETA2, biasB_Iteration)); 304 | 305 | biasB[i] = biasB[i] - (ALPHA * MtHat) / (Math.sqrt(VtHat) + EPSILON); 306 | } 307 | // weightAB 308 | for (int right = 0; right < 20; right++) { 309 | for (int i = 0; i < 20; i++) { 310 | weightAB_Mt[right][i] = BETA1 * weightAB_Mt[right][i] + (1.0 - BETA1) * gradient_weightAB[right][i]; 311 | weightAB_Vt[right][i] = BETA2 * weightAB_Vt[right][i] + (1.0 - BETA2) * gradient_weightAB[right][i] * gradient_weightAB[right][i]; 312 | 313 | double MtHat = weightAB_Mt[right][i] / (1.0 - Math.pow(BETA1, weightAB_Iteration)); 314 | double VtHat = weightAB_Vt[right][i] / (1.0 - Math.pow(BETA2, weightAB_Iteration)); 315 | 316 | weightAB[right][i] = weightAB[right][i] - (ALPHA * MtHat) / (Math.sqrt(VtHat) + EPSILON); 317 | } 318 | } 319 | // biasA 320 | for (int right = 0; right < 20; right++) { 321 | biasA_Mt[right] = BETA1 * biasA_Mt[right] + (1.0 - BETA1) * gradient_biasA[right]; 322 | biasA_Vt[right] = BETA2 * biasA_Vt[right] + (1.0 - BETA2) * gradient_biasA[right] * gradient_biasA[right]; 323 | 324 | double MtHat = biasA_Mt[right] / (1.0 - Math.pow(BETA1, biasA_Iteration)); 325 | double VtHat = biasA_Vt[right] / (1.0 - 
Math.pow(BETA2, biasA_Iteration)); 326 | 327 | biasA[right] = biasA[right] - (ALPHA * MtHat) / (Math.sqrt(VtHat) + EPSILON); 328 | } 329 | // weightXA 330 | for (int left = 0; left < 784; left++) { 331 | for (int right = 0; right < 20; right++) { 332 | weightXA_Mt[left][right] = BETA1 * weightXA_Mt[left][right] + (1.0 - BETA1) * gradient_weightXA[left][right]; 333 | weightXA_Vt[left][right] = BETA2 * weightXA_Vt[left][right] + (1.0 - BETA2) * gradient_weightXA[left][right] * gradient_weightXA[left][right]; 334 | 335 | double MtHat = weightXA_Mt[left][right] / (1.0 - Math.pow(BETA1, weightXA_Iteration)); 336 | double VtHat = weightXA_Vt[left][right] / (1.0 - Math.pow(BETA2, weightXA_Iteration)); 337 | 338 | weightXA[left][right] = weightXA[left][right] - (ALPHA * MtHat) / (Math.sqrt(VtHat) + EPSILON); 339 | } 340 | } 341 | 342 | } 343 | 344 | /** 345 | * 获取GUI的数字,并且输出判断 346 | */ 347 | public static void getDigit(JFrame frame) { 348 | JPanel newPanel = (JPanel) frame.getContentPane().getComponentAt(20, 20); 349 | // 黑色是0,白色是1 350 | double[][] array = new double[28][28]; 351 | for (int x = 0; x < 28; x++) { 352 | for (int y = 0; y < 28; y++) { 353 | array[x][y] = (double) (newPanel.getComponentAt(y * 10, x * 10).getBackground().getRed()) / 255.0; 354 | } 355 | } 356 | array = normalization(array); 357 | // 重新绘制 358 | for (int x = 0; x < 28; x++) { 359 | for (int y = 0; y < 28; y++) { 360 | int RBG = (int) (255.00 * array[x][y]); 361 | if (RBG < 0) RBG = 0; 362 | newPanel.getComponentAt(y * 10, x * 10).setBackground(new Color(RBG, RBG, RBG)); 363 | } 364 | } 365 | 366 | 367 | double[] X = new double[784]; 368 | double[] A = new double[20]; 369 | double[] B = new double[20]; 370 | double[] Y = new double[10]; 371 | double[][] currentImg = array; 372 | // 存到X 373 | int counter = 0; 374 | for (int x = 0; x < 28; x++) { 375 | for (int y = 0; y < 28; y++) { 376 | X[counter++] = currentImg[x][y]; 377 | } 378 | } 379 | // 计算隐藏层A 380 | for (int i = 0; i < 20; i++) { 381 | for 
(int left = 0; left < 784; left++) { 382 | A[i] += X[left] * weightXA[left][i]; 383 | } 384 | // 添加偏置后,压入sigmoid 385 | A[i] += biasA[i]; 386 | A[i] = sigmoid(A[i]); 387 | } 388 | // 计算隐藏层B 389 | for (int i = 0; i < 20; i++) { 390 | for (int left = 0; left < 20; left++) { 391 | B[i] += A[left] * weightAB[left][i]; 392 | } 393 | B[i] += biasB[i]; 394 | B[i] = sigmoid(B[i]); 395 | } 396 | // 计算Y 397 | for (int i = 0; i < 10; i++) { 398 | for (int left = 0; left < 20; left++) { 399 | Y[i] += B[left] * weightBY[left][i]; 400 | } 401 | Y[i] += biasY[i]; 402 | Y[i] = sigmoid(Y[i]); 403 | } 404 | 405 | JPanel predictPanel = (JPanel) frame.getContentPane().getComponentAt(20, 400); 406 | for (int i = 0; i < 10; i++) { 407 | String temp = String.format("%.2f", Y[i] * 100) + "%"; 408 | JLabel tempLabel = (JLabel) predictPanel.getComponentAt(35, i * 30); 409 | JLabel tempLabel2 = (JLabel) predictPanel.getComponentAt(0, i * 30); 410 | tempLabel.setText(temp); 411 | int constant = 100; 412 | int RGB = 255 - (constant + (int) (Y[i] * (255.0 - (double) constant))); 413 | tempLabel.setForeground(new Color(RGB, RGB, RGB)); 414 | tempLabel2.setForeground(new Color(RGB, RGB, RGB)); 415 | } 416 | int maxIndex = 0; 417 | for (int i = 0; i < 10; i++) { 418 | if (Y[i] > Y[maxIndex]) { 419 | maxIndex = i; 420 | } 421 | } 422 | JLabel tempLabel = (JLabel) predictPanel.getComponentAt(35, maxIndex * 30); 423 | JLabel tempLabel2 = (JLabel) predictPanel.getComponentAt(0, maxIndex * 30); 424 | tempLabel.setForeground(Color.RED); 425 | tempLabel2.setForeground(Color.RED); 426 | 427 | 428 | } 429 | 430 | /** 431 | * 验证模型准确度 432 | */ 433 | public static void VerifyAccuracy(JFrame jframe) throws IOException { 434 | int totalCorrectNumber = 0; 435 | ArrayList allBatch = DevelopTool.getVerifyBatch(); 436 | allBatch = normalization(allBatch); 437 | int loopCounter = 0; 438 | for (int numberOfBatch = 0; numberOfBatch < allBatch.size(); numberOfBatch++) { 439 | ++loopCounter; 440 | 441 | double[] X = new 
double[784];
            double[] A = new double[20];
            double[] B = new double[20];
            double[] Y = new double[10];
            double[][] currentImg = allBatch.get(numberOfBatch);
            // Flatten the 28x28 image into X
            int counter = 0;
            for (int x = 0; x < 28; x++) {
                for (int y = 0; y < 28; y++) {
                    X[counter++] = currentImg[x][y];
                }
            }
            // Compute hidden layer A
            for (int i = 0; i < 20; i++) {
                for (int left = 0; left < 784; left++) {
                    A[i] += X[left] * weightXA[left][i];
                }
                // Add the bias, then apply sigmoid
                A[i] += biasA[i];
                A[i] = sigmoid(A[i]);
            }
            // Compute hidden layer B
            for (int i = 0; i < 20; i++) {
                for (int left = 0; left < 20; left++) {
                    B[i] += A[left] * weightAB[left][i];
                }
                B[i] += biasB[i];
                B[i] = sigmoid(B[i]);
            }
            // Compute the output layer Y
            for (int i = 0; i < 10; i++) {
                for (int left = 0; left < 20; left++) {
                    Y[i] += B[left] * weightBY[left][i];
                }
                Y[i] += biasY[i];
                Y[i] = sigmoid(Y[i]);
            }


            JPanel paintingPanel = (JPanel) jframe.getContentPane().getComponentAt(20, 20);
            // Redraw the drawing board with the current image
            for (int x = 0; x < 28; x++) {
                for (int y = 0; y < 28; y++) {
                    int RGB = (int) (currentImg[x][y] * 255.00);
                    paintingPanel.getComponentAt(y * 10, x * 10).setBackground(new Color(RGB, RGB, RGB));
                }
            }
            // Redraw the per-digit scores
            JPanel predictPanel = (JPanel) jframe.getContentPane().getComponentAt(20, 400);
            for (int i = 0; i < 10; i++) {
                String temp = String.format("%.2f", Y[i] * 100) + "%";
                JLabel tempLabel = (JLabel) predictPanel.getComponentAt(35, i * 30);
                JLabel tempLabel2 = (JLabel) predictPanel.getComponentAt(0, i * 30);
                tempLabel.setText(temp);
                int constant = 100;
                int RGB = 255 - (constant + (int) (Y[i] * (255.0 - (double) constant)));
                tempLabel.setForeground(new Color(RGB, RGB, RGB));
                tempLabel2.setForeground(new Color(RGB, RGB, RGB));
            }
            // Find the index of the largest output
            int maxIndex = 0;
            for (int i = 0; i < 10; i++) {
                if (Y[i] > Y[maxIndex]) {
                    maxIndex = i;
                }
            }
            // Row 28 of each entry stores the label in its first cell
            int label = (int) allBatch.get(numberOfBatch)[28][0];

            if (label == maxIndex) ++totalCorrectNumber;

            JLabel countNum = (JLabel) jframe.getContentPane().getComponentAt(430, 280);
            countNum.setText("" + numberOfBatch);


            JLabel accuracyValue = (JLabel) jframe.getContentPane().getComponentAt(430, 320);
            accuracyValue.setText(String.format("%.2f", (double) totalCorrectNumber / (double) loopCounter * 100.00) + "%");


            // Mark the prediction in red
            JLabel temp = (JLabel) predictPanel.getComponentAt(35, maxIndex * 30);
            JLabel temp2 = (JLabel) predictPanel.getComponentAt(0, maxIndex * 30);
            temp.setForeground(Color.RED);
            temp2.setForeground(Color.RED);

            // Mark the true label in green (this overrides the red when the prediction is correct)
            JLabel tempLabel = (JLabel) predictPanel.getComponentAt(35, label * 30);
            JLabel tempLabel2 = (JLabel) predictPanel.getComponentAt(0, label * 30);
            tempLabel.setForeground(Color.GREEN);
            tempLabel2.setForeground(Color.GREEN);
        }
    }


    // Array copy helpers
    private static double[][] arrayCopy(double[][] sourceArray) {
        double[][] returnArray = new double[sourceArray.length][sourceArray[0].length];
        for (int i = 0; i < sourceArray.length; i++) {
            System.arraycopy(sourceArray[i], 0, returnArray[i], 0, sourceArray[0].length);
        }
        return returnArray;
    }

    private static double[] arrayCopy(double[] sourceArray) {
        double[] returnArray = new double[sourceArray.length];
        System.arraycopy(sourceArray, 0, returnArray, 0, sourceArray.length);
        return returnArray;
    }


    // sign function
    private static double sign(double weight) {
        if (weight < 0) return -1.0;
        else if (weight > 0) return 1.0;
        else return 0.0;
    }


    // Mean of the recorded loss history (despite the "Var" in the name, no variance is computed)
    private static double lossFunctionVar() {
        double finalResult = 0.0;
        for (int i = 0;
i < lossValueHistory.size(); i++) {
            finalResult += lossValueHistory.get(i);
        }
        return finalResult / (double) lossValueHistory.size();
    }


    // Compute the loss for the batch set: the sum of squared errors per image, averaged over all images
    private static double getLossValueForBatchSet(ArrayList<double[][]> allBatch) {
        double loss = 0.0;
        // Iterate over the whole batch set
        for (int numberOfBatch = 0; numberOfBatch < allBatch.size(); numberOfBatch++) {
            double[][] currentBatch = allBatch.get(numberOfBatch);
            double[] X = new double[784];
            double[] A = new double[20];
            double[] B = new double[20];
            double[] Y = new double[10];
            // Flatten the 28x28 image into X
            int counter = 0;
            for (int x = 0; x < 28; x++) {
                for (int y = 0; y < 28; y++) {
                    X[counter++] = currentBatch[x][y];
                }
            }
            // Compute hidden layer A
            for (int i = 0; i < 20; i++) {
                for (int left = 0; left < 784; left++) {
                    A[i] += X[left] * weightXA[left][i];
                }
                // Add the bias, then apply sigmoid
                A[i] += biasA[i];
                A[i] = sigmoid(A[i]);
            }
            // Compute hidden layer B
            for (int i = 0; i < 20; i++) {
                for (int left = 0; left < 20; left++) {
                    B[i] += A[left] * weightAB[left][i];
                }
                B[i] += biasB[i];
                B[i] = sigmoid(B[i]);
            }
            // Compute the output layer Y
            for (int i = 0; i < 10; i++) {
                for (int left = 0; left < 20; left++) {
                    Y[i] += B[left] * weightBY[left][i];
                }
                Y[i] += biasY[i];
                Y[i] = sigmoid(Y[i]);
            }

            // Compute the loss against the one-hot target for this image
            int label = (int) currentBatch[28][0];
            double miniBatchLoss = 0;
            for (int i = 0; i < 10; i++) {
                double possibility;
                if (i == label) possibility = 1.0;
                else possibility = 0.0;
                miniBatchLoss += Math.pow(possibility - Y[i], 2);
            }
            loss += miniBatchLoss;
        }
        return loss / (double) allBatch.size();
    }


    // sigmoid function
    private static double sigmoid(double x) {
        return 1.00 / (Math.exp(-x) + 1.00);
    }


    // Preprocessing
    public static ArrayList<double[][]>
normalization(ArrayList<double[][]> allBatch) {
        // First crop, rescale, and center every image
        for (int numberOfBatch = 0; numberOfBatch < allBatch.size(); numberOfBatch++) {
            allBatch.set(numberOfBatch, resizeImg(allBatch.get(numberOfBatch)));
        }

        // Then apply min-max normalization
        System.out.println("[Normalization Start...]");
        ArrayList<double[][]> normalizedBatch = new ArrayList<>();
        for (int i = 0; i < allBatch.size(); i++) {
            double[][] currentBatch = allBatch.get(i);
            double max = currentBatch[0][0];
            double min = currentBatch[0][0];
            // The last row holds the label, so it is excluded (length - 1)
            for (int x = 0; x < currentBatch.length - 1; x++) {
                for (int y = 0; y < currentBatch[0].length; y++) {
                    if (currentBatch[x][y] > max) {
                        max = currentBatch[x][y];
                    }
                    if (currentBatch[x][y] < min) {
                        min = currentBatch[x][y];
                    }
                }
            }
            // Skip scaling for blank or constant images to avoid dividing by zero (NaN pixels)
            double range = max - min;
            if (range != 0) {
                for (int x = 0; x < currentBatch.length - 1; x++) {
                    for (int y = 0; y < currentBatch[0].length; y++) {
                        currentBatch[x][y] = (currentBatch[x][y] - min) / range;
                    }
                }
            }
            normalizedBatch.add(currentBatch);
        }
        System.out.println("[Normalization Finished]");
        return normalizedBatch;

    }

    public static double[][] normalization(double[][] img) {
        // First crop, rescale, and center the image
        img = resizeImg(img);

        // Then apply min-max normalization
        double max = img[0][0];
        double min = img[0][0];
        for (int x = 0; x < img.length - 1; x++) {
            for (int y = 0; y < img[0].length; y++) {
                if (img[x][y] > max) {
                    max = img[x][y];
                }
                if (img[x][y] != 0 && img[x][y] < min) {
                    min = img[x][y];
                }
            }
        }
        // Skip scaling for blank or constant images to avoid dividing by zero (NaN pixels)
        double range = max - min;
        if (range != 0) {
            for (int x = 0; x < img.length - 1; x++) {
                for (int y = 0; y < img[0].length; y++) {
                    img[x][y] = (img[x][y] - min) / range;
                }
            }
        }
        return img;

    }


    // Crop to the bounding box of the non-zero pixels, scale the longer side to 28
    // (nearest-neighbor sampling, aspect ratio kept), and center the result in the 28x28 grid
    private static double[][] resizeImg(double[][] array) {
        int left = -1;
        int right = -1;
        int top = -1;
        int bottom = -1;

        for (int i = 0; i < 28; i++)
{
            if (checkColumn(array, i)) {
                left = i;
                break;
            }
        }
        for (int i = 27; i >= 0; i--) {
            if (checkColumn(array, i)) {
                right = i;
                break;
            }
        }

        for (int i = 0; i < 28; i++) {
            if (checkRow(array, i)) {
                top = i;
                break;
            }
        }

        for (int i = 27; i >= 0; i--) {
            if (checkRow(array, i)) {
                bottom = i;
                break;
            }
        }


        // Blank image: return an all-zero 28x28 grid plus the label row
        if (left == -1 || right == -1 || top == -1 || bottom == -1) {
            return new double[29][28];
        }

        double[][] B;
        int height = Math.abs(bottom - top) + 1;
        int width = Math.abs(left - right) + 1;
        // A is the cropped bounding box
        double[][] A = new double[height][width];
        for (int x = top; x <= bottom; x++) {
            System.arraycopy(array[x], left, A[x - top], 0, width);
        }


        // B is the scaled image: the longer side becomes 28, the shorter side keeps the aspect ratio
        if (height >= width) {
            B = new double[28][(int) (Math.round((width) * (28.00 / (height))))];
        } else {
            B = new double[(int) (Math.round((height) * (28.00 / (width))))][28];
        }

        // Nearest-neighbor sampling from A into B
        for (int x = 0; x < B.length; x++) {
            for (int y = 0; y < B[0].length; y++) {
                int castX = (int) ((double) x * ((double) height / (double) B.length));
                int castY = (int) ((double) y * ((double) width / (double) B[0].length));
                B[x][y] = A[castX][castY];
            }
        }
        // Clear the 28x28 pixel area (the label row, if present, is left untouched)
        for (int x = 0; x < 28; x++) {
            for (int y = 0; y < 28; y++) {
                array[x][y] = 0.0;
            }
        }
        // Copy B back, centered along its shorter side
        if (B.length == 28) {
            int bias = (int) (Math.round((28.0 - (double) B[0].length) / 2.0));
            for (int x = 0; x < 28; x++) {
                System.arraycopy(B[x], 0, array[x], bias, B[0].length);
            }
        } else if (B[0].length == 28) {
            int bias = (int) (Math.round((28.0 - (double) B.length) / 2.0));
            for (int x = 0; x < B.length; x++) {
                System.arraycopy(B[x], 0, array[x + bias], 0, 28);
            }
        }
        return array;
    }

    // True if the given column contains at least one non-zero pixel
    private static boolean checkColumn(double[][] array, int column) {
        for (int x = 0; x < 28; x++) {
            if (array[x][column] != 0) return true;
        }
        return false;
    }

    // True if the given row contains at least one non-zero pixel
    private static boolean checkRow(double[][] array, int row) {
        for (int x = 0; x < 28; x++) {
            if (array[row][x] != 0) {
                return true;
            }
        }
        return false;
    }

    /**
     * record() exports the model to the desktop.
     */
    public static void record() throws IOException {
        // Delete any existing file so that it is overwritten
        File file = new File(EXPORT_PART);
        if (file.exists()) {
            file.delete();
        }

        BufferedWriter writer = new BufferedWriter(new FileWriter(EXPORT_PART, true));
        // Weights
        for (int i = 0; i < weightXA.length; i++) {
            for (int j = 0; j < weightXA[i].length; j++) {
                writer.write(weightXA[i][j] + "\n");
            }
        }
        for (int i = 0; i < weightAB.length; i++) {
            for (int j = 0; j < weightAB[i].length; j++) {
                writer.write(weightAB[i][j] + "\n");
            }
        }
        for (int i = 0; i < weightBY.length; i++) {
            for (int j = 0; j < weightBY[i].length; j++) {
                writer.write(weightBY[i][j] + "\n");
            }
        }
        // Biases
        for (int i = 0; i < biasA.length; i++) {
            writer.write(biasA[i] + "\n");
        }
        for (int i = 0; i < biasB.length; i++) {
            writer.write(biasB[i] + "\n");
        }
        for (int i = 0; i < biasY.length; i++) {
            writer.write(biasY[i] + "\n");
        }
        // Adam state for the weights
        for (int i = 0; i < weightXA_Mt.length; i++) {
            for (int j = 0; j < weightXA_Mt[i].length; j++) {
                writer.write(weightXA_Mt[i][j] + "\n");
            }
        }
        for (int i = 0; i < weightXA_Vt.length; i++) {
            for (int j = 0; j < weightXA_Vt[i].length; j++) {
                writer.write(weightXA_Vt[i][j] + "\n");
            }
        }
        for (int i = 0; i <
weightAB_Mt.length; i++) {
            for (int j = 0; j < weightAB_Mt[i].length; j++) {
                writer.write(weightAB_Mt[i][j] + "\n");
            }
        }
        for (int i = 0; i < weightAB_Vt.length; i++) {
            for (int j = 0; j < weightAB_Vt[i].length; j++) {
                writer.write(weightAB_Vt[i][j] + "\n");
            }
        }
        for (int i = 0; i < weightBY_Mt.length; i++) {
            for (int j = 0; j < weightBY_Mt[i].length; j++) {
                writer.write(weightBY_Mt[i][j] + "\n");
            }
        }
        for (int i = 0; i < weightBY_Vt.length; i++) {
            for (int j = 0; j < weightBY_Vt[i].length; j++) {
                writer.write(weightBY_Vt[i][j] + "\n");
            }
        }
        // Adam state for the biases
        for (int i = 0; i < biasA_Mt.length; i++) {
            writer.write(biasA_Mt[i] + "\n");
        }
        for (int i = 0; i < biasA_Vt.length; i++) {
            writer.write(biasA_Vt[i] + "\n");
        }
        for (int i = 0; i < biasB_Mt.length; i++) {
            writer.write(biasB_Mt[i] + "\n");
        }
        for (int i = 0; i < biasB_Vt.length; i++) {
            writer.write(biasB_Vt[i] + "\n");
        }
        for (int i = 0; i < biasY_Mt.length; i++) {
            writer.write(biasY_Mt[i] + "\n");
        }
        for (int i = 0; i < biasY_Vt.length; i++) {
            writer.write(biasY_Vt[i] + "\n");
        }
        // Iteration counts
        writer.write(weightXA_Iteration + "\n");
        writer.write(weightAB_Iteration + "\n");
        writer.write(weightBY_Iteration + "\n");
        writer.write(biasA_Iteration + "\n");
        writer.write(biasB_Iteration + "\n");
        writer.write(biasY_Iteration + "\n");

        writer.close();
    }

    /**
     * readModel() loads the model from the model folder.
     * Values are read in exactly the order in which record() writes them.
     */
    public static void readModel() throws IOException {
        BufferedReader reader = new BufferedReader(new FileReader(IMPORT_PATH));
        // Weights
        for (int i = 0; i < weightXA.length; i++) {
            for (int j = 0; j < weightXA[i].length; j++) {
                weightXA[i][j] = Double.parseDouble(reader.readLine());
            }
        }
        for (int i = 0;
i < weightAB.length; i++) {
            for (int j = 0; j < weightAB[i].length; j++) {
                weightAB[i][j] = Double.parseDouble(reader.readLine());
            }
        }
        for (int i = 0; i < weightBY.length; i++) {
            for (int j = 0; j < weightBY[i].length; j++) {
                weightBY[i][j] = Double.parseDouble(reader.readLine());
            }
        }
        // Biases
        for (int i = 0; i < biasA.length; i++) {
            biasA[i] = Double.parseDouble(reader.readLine());
        }
        for (int i = 0; i < biasB.length; i++) {
            biasB[i] = Double.parseDouble(reader.readLine());
        }
        for (int i = 0; i < biasY.length; i++) {
            biasY[i] = Double.parseDouble(reader.readLine());
        }
        // Adam state for the weights
        for (int i = 0; i < weightXA_Mt.length; i++) {
            for (int j = 0; j < weightXA_Mt[i].length; j++) {
                weightXA_Mt[i][j] = Double.parseDouble(reader.readLine());
            }
        }
        for (int i = 0; i < weightXA_Vt.length; i++) {
            for (int j = 0; j < weightXA_Vt[i].length; j++) {
                weightXA_Vt[i][j] = Double.parseDouble(reader.readLine());
            }
        }
        for (int i = 0; i < weightAB_Mt.length; i++) {
            for (int j = 0; j < weightAB_Mt[i].length; j++) {
                weightAB_Mt[i][j] = Double.parseDouble(reader.readLine());
            }
        }
        for (int i = 0; i < weightAB_Vt.length; i++) {
            for (int j = 0; j < weightAB_Vt[i].length; j++) {
                weightAB_Vt[i][j] = Double.parseDouble(reader.readLine());
            }
        }
        for (int i = 0; i < weightBY_Mt.length; i++) {
            for (int j = 0; j < weightBY_Mt[i].length; j++) {
                weightBY_Mt[i][j] = Double.parseDouble(reader.readLine());
            }
        }
        for (int i = 0; i < weightBY_Vt.length; i++) {
            for (int j = 0; j < weightBY_Vt[i].length; j++) {
                weightBY_Vt[i][j] = Double.parseDouble(reader.readLine());
            }
        }
        // Adam state for the biases
        for (int i = 0; i < biasA_Mt.length; i++) {
            biasA_Mt[i] = Double.parseDouble(reader.readLine());
        }
        for (int i =
0; i < biasA_Vt.length; i++) {
            biasA_Vt[i] = Double.parseDouble(reader.readLine());
        }
        for (int i = 0; i < biasB_Mt.length; i++) {
            biasB_Mt[i] = Double.parseDouble(reader.readLine());
        }
        for (int i = 0; i < biasB_Vt.length; i++) {
            biasB_Vt[i] = Double.parseDouble(reader.readLine());
        }
        for (int i = 0; i < biasY_Mt.length; i++) {
            biasY_Mt[i] = Double.parseDouble(reader.readLine());
        }
        for (int i = 0; i < biasY_Vt.length; i++) {
            biasY_Vt[i] = Double.parseDouble(reader.readLine());
        }
        // Iteration counts
        weightXA_Iteration = Integer.parseInt(reader.readLine());
        weightAB_Iteration = Integer.parseInt(reader.readLine());
        weightBY_Iteration = Integer.parseInt(reader.readLine());
        biasA_Iteration = Integer.parseInt(reader.readLine());
        biasB_Iteration = Integer.parseInt(reader.readLine());
        biasY_Iteration = Integer.parseInt(reader.readLine());

        reader.close();
    }

    /**
     * Initialization: 784 inputs, two hidden layers of 20 neurons, 10 outputs.
     */
    private static void initialize() {
        weightXA = initializeWeights(784, 20);
        weightAB = initializeWeights(20, 20);
        weightBY = initializeWeights(20, 10);

        biasA = initializeBias(20);
        biasB = initializeBias(20);
        biasY = initializeBias(10);
    }

    public static double[][] initializeWeights(int fromSize, int toSize) {
        double[][] weights = new double[fromSize][toSize];
        double limit = Math.sqrt(6.0 / fromSize);

        for (int i = 0; i < fromSize; i++) {
            for (int j = 0; j < toSize; j++) {
                Random random1 = new Random();
                Random random2 = new Random(random1.nextInt());
                // He initialization with a uniform distribution over [-limit, limit)
                weights[i][j] = (random2.nextDouble() * 2.0 * limit - limit);
            }
        }
        return weights;
    }

    public static double[] initializeBias(int size) {
        double[] bias = new double[size];


        for (int m = 0; m < size; m++) {
            Random random1 = new Random();
            Random random2 = new Random(random1.nextInt());
            // Initialize the bias from a normal distribution with mean 0 and
            // standard deviation 0.01; adjust as needed
            bias[m] = random2.nextGaussian() * 0.01;
        }
        return bias;
    }
}

--------------------------------------------------------------------------------
/如何在Java中实现手写数字识别神经网络.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CH3COOK2023/Java-Hand-Writing-Digit-Recognition-Application/739629831545f76bb7337a2102e93c366c3f5663/如何在Java中实现手写数字识别神经网络.pdf
--------------------------------------------------------------------------------
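
The forward pass in `Processing.java` (hidden layer A, hidden layer B, output Y) is written out in full in the prediction, verification, and loss routines. As an illustration only, the sketch below shows how one fully connected sigmoid layer could be factored into a shared helper. The class `ForwardPassSketch` and the methods `denseSigmoid` and `argMax` are hypothetical names that do not exist in the repository, and the tiny 2-3-2 network in `main` uses made-up weights; only the `[from][to]` weight layout and the "sum, add bias, apply sigmoid" order are taken from the code above.

```java
import java.util.Arrays;

/** Hypothetical helper, not part of the repository. */
final class ForwardPassSketch {

    // Same sigmoid as in Processing.java
    static double sigmoid(double x) {
        return 1.0 / (Math.exp(-x) + 1.0);
    }

    /**
     * One fully connected layer followed by sigmoid.
     * weights is indexed [fromNeuron][toNeuron], like weightXA, weightAB and weightBY.
     */
    static double[] denseSigmoid(double[] input, double[][] weights, double[] bias) {
        double[] output = new double[bias.length];
        for (int to = 0; to < output.length; to++) {
            double sum = 0.0;
            for (int from = 0; from < input.length; from++) {
                sum += input[from] * weights[from][to];
            }
            // Add the bias after the weighted sum, then squash
            output[to] = sigmoid(sum + bias[to]);
        }
        return output;
    }

    /** Index of the largest value; ties resolve to the lowest index. */
    static int argMax(double[] values) {
        int maxIndex = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[maxIndex]) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    public static void main(String[] args) {
        // Made-up 2-3-2 network, only to exercise the helpers
        double[] x = {1.0, 0.0};
        double[][] w1 = {{0.5, -0.5, 0.0}, {0.1, 0.2, 0.3}};
        double[] b1 = {0.0, 0.0, 0.0};
        double[][] w2 = {{1.0, -1.0}, {0.0, 0.0}, {0.0, 0.0}};
        double[] b2 = {0.0, 0.0};

        double[] hidden = denseSigmoid(x, w1, b1);
        double[] y = denseSigmoid(hidden, w2, b2);
        System.out.println(Arrays.toString(y) + " -> class " + argMax(y));
    }
}
```

With a helper of this shape, the 784-20-20-10 pass would reduce to three chained calls, for example `denseSigmoid(denseSigmoid(denseSigmoid(X, weightXA, biasA), weightAB, biasB), weightBY, biasY)`, assuming the arrays keep the dimensions set in `initialize()`.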