├── README.md ├── dataset ├── standardOutput │ └── digitLabels.txt └── weights │ ├── BoundingBoxRegressionAnn.w │ └── BoundingBoxRegressionCnn.w ├── images ├── BoundingBoxRegression.gif ├── after NMS before BBOX Regression.gif ├── example.bmp ├── files.jpg ├── multi predicted box.bmp └── predefine.jpg └── src ├── detection └── BoundingBox.java ├── gui ├── ConvolutionPanel.java ├── DrawingPanel.java ├── ExamplePanel.java ├── MainFrame.java └── PredictPanel.java ├── neuralNetwork ├── Convolution.java ├── DataSet.java ├── Network.java ├── Neuron.java └── Trainer.java └── tool ├── CreateForeOrBackgroundSample.java ├── CreatePictureForObjectDetectionFromMNIST.java ├── DataInfo.java └── ReadFile.java /README.md: -------------------------------------------------------------------------------- 1 | # Bounding-Box-Regression-GUI 2 | This program shows visually how Bounding-Box-Regression works. 3 | ## Coming soon: update for Digit-Recognition-CNN-and-ANN-using-Mnist-with-GUI 4 | https://github.com/timmmGZ/Digit-Recognition-CNN-and-ANN-using-Mnist-with-GUI 5 | In my free time I am going to extend the above program with this Bounding Box Regression and the ROI Align layer (https://github.com/timmmGZ/ROIAlign-Bounding-Box-ROI-Align-of-Mask-RCNN-GUI), turning it into an MNIST object detector. 6 | ## First of all, let's see how it works 7 | ![image](https://github.com/timmmGZ/Bounding-Box-Regression-GUI/blob/master/images/BoundingBoxRegression.gif) 8 | ## Preparation 9 | Download the MNIST digit training set (60000 images) and test set (10000 images) here: 10 | https://drive.google.com/open?id=1VwABcxX0DaQakPpHbMaRQJHlJf3mVONf 11 | 1. Put both files into ../dataset/, then go to the "tool" package. 12 | 2. Run "CreatePictureForObjectDetectionFromMNIST.java"; it creates random pictures from the MNIST datasets and saves them in ../dataset/pictures. 13 | 3.
After step 2, run "CreateForeOrBackgroundSample.java"; it creates random foreground (in ../dataset/foreground) and background (in ../dataset/background) datasets from the pictures produced in step 2. You can see them more clearly in ../dataset/groundTruthExamples, 14 | as below: 15 | ![image](https://github.com/timmmGZ/Bounding-Box-Regression-GUI/blob/master/images/files.jpg) 16 | 17 | 4. Steps 2 and 3 both create label files in ../dataset/standardOutput; the first line of each label file lists its column names. 18 | 5. Make sure you have enough RAM if you want to keep more datasets in memory; see the picture below: 19 | ![image](https://github.com/timmmGZ/Bounding-Box-Regression-GUI/blob/master/images/predefine.jpg) 20 | ## Start 21 | Run MainFrame.java. If you don't want to train the model, click "Menu" then "Read Weight" to load my trained weights, which reach around 94% accuracy on both the train and test sets (70000 samples in total; a prediction counts as correct if the Predicted-Bounding-Box has a higher IOU with the Ground-Truth-Bounding-Box than it has with the input Bounding-Box). 22 | ## Warning 23 | Strictly speaking, the number of predicted boxes should equal the number of classes (e.g.
one Bounding-Box can contain both an apple and a bird), but this program only detects digits, so for convenience and a higher FPS in real-time detection it produces just one predicted box. The picture below shows the advantage of having the usual number of predicted boxes: 24 | ![image](https://raw.githubusercontent.com/timmmGZ/Bounding-Box-Regression-GUI/master/images/multi%20predicted%20box.bmp) 25 | ## Bounding-Box-Regression is used after NMS (Non-Maximum Suppression) 26 | For each picture we get many Bounding-Boxes (Region Proposals), and NMS filters them down to the best ones. The gif below shows what the result looks like if we use only the ROI-Align layer (optional, but recommended) and NMS, without Bounding-Box-Regression: 27 | ROI-Align layer settings: sample size=1, output size=7, feature maps=16 (https://github.com/timmmGZ/ROIAlign-Bounding-Box-ROI-Align-of-Mask-RCNN-GUI) 28 | ![image](https://github.com/timmmGZ/Bounding-Box-Regression-GUI/blob/master/images/after%20NMS%20before%20BBOX%20Regression.gif) 29 | Take one frame from the gif: when the objects are small, as in the picture below, and there are many of them, Bounding-Box-Regression is clearly needed, otherwise the result is a mess.
30 | ![image](https://raw.githubusercontent.com/timmmGZ/Bounding-Box-Regression-GUI/master/images/example.bmp) 31 | -------------------------------------------------------------------------------- /dataset/standardOutput/digitLabels.txt: -------------------------------------------------------------------------------- 1 | 0|1,0,0,0,0,0,0,0,0,0 2 | 1|0,1,0,0,0,0,0,0,0,0 3 | 2|0,0,1,0,0,0,0,0,0,0 4 | 3|0,0,0,1,0,0,0,0,0,0 5 | 4|0,0,0,0,1,0,0,0,0,0 6 | 5|0,0,0,0,0,1,0,0,0,0 7 | 6|0,0,0,0,0,0,1,0,0,0 8 | 7|0,0,0,0,0,0,0,1,0,0 9 | 8|0,0,0,0,0,0,0,0,1,0 10 | 9|0,0,0,0,0,0,0,0,0,1 -------------------------------------------------------------------------------- /dataset/weights/BoundingBoxRegressionAnn.w: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmmGZ/Bounding-Box-Regression-GUI/56f7e03021b42919fc582bb2026f008d75cb88ae/dataset/weights/BoundingBoxRegressionAnn.w -------------------------------------------------------------------------------- /dataset/weights/BoundingBoxRegressionCnn.w: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmmGZ/Bounding-Box-Regression-GUI/56f7e03021b42919fc582bb2026f008d75cb88ae/dataset/weights/BoundingBoxRegressionCnn.w -------------------------------------------------------------------------------- /images/BoundingBoxRegression.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmmGZ/Bounding-Box-Regression-GUI/56f7e03021b42919fc582bb2026f008d75cb88ae/images/BoundingBoxRegression.gif -------------------------------------------------------------------------------- /images/after NMS before BBOX Regression.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmmGZ/Bounding-Box-Regression-GUI/56f7e03021b42919fc582bb2026f008d75cb88ae/images/after NMS before BBOX 
Regression.gif -------------------------------------------------------------------------------- /images/example.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmmGZ/Bounding-Box-Regression-GUI/56f7e03021b42919fc582bb2026f008d75cb88ae/images/example.bmp -------------------------------------------------------------------------------- /images/files.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmmGZ/Bounding-Box-Regression-GUI/56f7e03021b42919fc582bb2026f008d75cb88ae/images/files.jpg -------------------------------------------------------------------------------- /images/multi predicted box.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmmGZ/Bounding-Box-Regression-GUI/56f7e03021b42919fc582bb2026f008d75cb88ae/images/multi predicted box.bmp -------------------------------------------------------------------------------- /images/predefine.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmmGZ/Bounding-Box-Regression-GUI/56f7e03021b42919fc582bb2026f008d75cb88ae/images/predefine.jpg -------------------------------------------------------------------------------- /src/detection/BoundingBox.java: -------------------------------------------------------------------------------- 1 | package detection; 2 | 3 | import java.util.List; 4 | 5 | public class BoundingBox { 6 | private int x, y, w, h; 7 | 8 | public BoundingBox(int cX, int cY, int W, int H) { 9 | x = cX; 10 | y = cY; 11 | w = W; 12 | h = H; 13 | } 14 | 15 | public BoundingBox(List pred) { 16 | this(pred.get(0), pred.get(1), pred.get(2), pred.get(3)); 17 | } 18 | 19 | public int getX() { 20 | return x; 21 | } 22 | 23 | public int getY() { 24 | return y; 25 | } 26 | 27 | public int getW() { 28 | return w; 29 | } 30 | 31 | 
public int getH() { 32 | return h; 33 | } 34 | 35 | public double getWScalar() { 36 | return w; 37 | } 38 | 39 | public double getHScalar() { 40 | return h; 41 | } 42 | 43 | public String toString() { 44 | return x + "," + y + "," + w + "," + h; 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /src/gui/ConvolutionPanel.java: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmmGZ/Bounding-Box-Regression-GUI/56f7e03021b42919fc582bb2026f008d75cb88ae/src/gui/ConvolutionPanel.java -------------------------------------------------------------------------------- /src/gui/DrawingPanel.java: -------------------------------------------------------------------------------- 1 | package gui; 2 | 3 | import java.awt.*; 4 | import java.awt.event.*; 5 | import java.io.IOException; 6 | import java.util.Arrays; 7 | import java.util.List; 8 | import java.util.Random; 9 | import javax.swing.*; 10 | import neuralNetwork.*; 11 | import tool.DataInfo; 12 | 13 | public class DrawingPanel extends JPanel implements MouseMotionListener, MouseListener { 14 | /** 15 | * Github:timmmGZ 2020/02/25 16 | */ 17 | private static final long serialVersionUID = 1L; public static int pixelW = 13, pixelH = 13, resolution = 28; 18 | protected double[][] map, mapDrawing = new double[resolution][resolution]; 19 | public ConvolutionPanel convPanel; 20 | MainFrame mainFrame; 21 | boolean hasGrountruth = false; 22 | Convolution conv = new Convolution(); 23 | List pred, box, gT; 24 | 25 | public DrawingPanel(ConvolutionPanel cP, MainFrame m) throws IOException { 26 | map = new double[resolution][resolution]; 27 | convPanel = cP; 28 | setPreferredSize(new Dimension(364, 364)); 29 | setBackground(Color.WHITE); 30 | addMouseMotionListener(this); 31 | addMouseListener(this); 32 | mainFrame = m; 33 | } 34 | 35 | @Override 36 | protected void paintComponent(Graphics g) { 37 | 
super.paintComponent(g); 38 | for (int y = 0; y < resolution; y++) 39 | for (int x = 0; x < resolution; x++) { 40 | int grey = 255 - (int) (map[y][x] * 255); 41 | g.setColor(new Color(grey, grey, grey)); 42 | g.fillRect(x * pixelW, y * pixelH, pixelW, pixelH); 43 | } 44 | if (pred != null) 45 | drawBoundingBox(); 46 | } 47 | 48 | public DataSet drawFalseWord(List falseSet) { 49 | DataSet random = falseSet.get((int) (Math.random() * falseSet.size())); 50 | map = random.getInputs(); 51 | convPanel.setCurrentWordsConvDrawing(random); 52 | repaint(); 53 | return random; 54 | } 55 | 56 | public DataSet drawRandomWord(int type, int i) {// type=0->testSet, type=1->trainSet, type=2->both 57 | Random r = new Random(); 58 | hasGrountruth = true; 59 | DataSet random = i == -1 ? (type == 0 ? DataInfo.getTestSets().get(r.nextInt(mainFrame.testSize))// 60 | : (type == 1 ? DataInfo.getTrainSets().get(r.nextInt(mainFrame.trainSize)) 61 | : r.nextDouble() > 0.5 ? DataInfo.getTestSets().get(r.nextInt(mainFrame.testSize)) 62 | : DataInfo.getTrainSets().get(r.nextInt(mainFrame.trainSize)))) 63 | : (type == 0 ? DataInfo.getTestSets().get(i) 64 | : (type == 1 ? DataInfo.getTrainSets().get(i) 65 | : i < mainFrame.trainSize ? 
DataInfo.getTrainSets().get(i) 66 | : DataInfo.getTestSets().get(i - mainFrame.trainSize))); 67 | map = random.getInputs(); 68 | convPanel.setCurrentWordsConvDrawing(random); 69 | repaint(); 70 | return random; 71 | } 72 | /** Draws the input (foreground) box, the predicted box and, when hasGrountruth is set, the ground-truth box on the root pane. The input box is mapped onto the whole 364x364 drawing area (28 cells of 13 px) at offset (145, 195); the other boxes are scaled relative to it. */ 73 | void drawBoundingBox() { 74 | mainFrame.repaint(); 75 | int x = 145 + (int) (13.0 * ((pred.get(0) - box.get(0)) / (box.get(2) / 28.0))), 76 | y = 195 + (int) (13.0 * ((pred.get(1) - box.get(1)) / (box.get(3) / 28.0))), 77 | w = (int) (13.0 * (pred.get(2) / (box.get(2) / 28.0))), 78 | h = (int) (13.0 * (pred.get(3) / (box.get(3) / 28.0))); 79 | Graphics2D g = (Graphics2D) getRootPane().getGraphics(); 80 | g.setStroke(new BasicStroke(3.0f)); 81 | g.setFont(new Font("MS Song", Font.BOLD, 30)); 82 | if (hasGrountruth) { 83 | int xGT = 145 + (int) (13.0 * ((gT.get(0) - box.get(0)) / (box.get(2) / 28.0))), 84 | yGT = 195 + (int) (13.0 * ((gT.get(1) - box.get(1)) / (box.get(3) / 28.0))), 85 | wGT = (int) (13.0 * (gT.get(2) / (box.get(2) / 28.0))), 86 | hGT = (int) (13.0 * (gT.get(3) / (box.get(3) / 28.0))); 87 | g.setColor(new Color(0, 220, 0)); 88 | g.drawRect(xGT, yGT, wGT, hGT); 89 | g.drawString("Ground Truth box", xGT, yGT); 90 | } 91 | g.setColor(Color.ORANGE); 92 | g.drawRect(x, y, w, h); 93 | g.drawString("Predicted box", x, y); 94 | g.setColor(Color.BLUE); 95 | // The input box always covers the whole drawing area; one drawRect outlines it (the former drawLine calls traced the same edges). 96 | g.drawRect(145, 195, 364, 364); 97 | g.drawString("Foreground box", 145, 195); 98 | // getGraphics() returns a new Graphics context on every call; release it so each repaint does not leak one. 99 | g.dispose(); 100 | } 101 | 102 | 103 | 104 | 105 | private void setPixel(MouseEvent e) { 106 | try { 107 | hasGrountruth = false; 108 | mapDrawing[e.getY() / pixelH][e.getX() / pixelW] = SwingUtilities.isLeftMouseButton(e) ?
0.8 : 0; 109 | double[][] input = conv.createDrawing(mapDrawing); 110 | map = Convolution.normalization(input); 111 | for (int y = 0; y < map.length; y++) 112 | for (int x = 0; x < map[0].length; x++) 113 | input[y][x] = 1 / (1 + Math.exp(-input[y][x] * 2.5) * 255); 114 | convPanel.setCurrentWordsConvDrawing(new DataSet(input, conv)); 115 | repaint(); 116 | mainFrame.trainer.network.setAllNeuronsInputs(convPanel.getVector()); 117 | mainFrame.predict.setText("Prediction: "); 118 | mainFrame.accuracy.setText("Prediction of your drawing doesn't have true label"); 119 | box = Arrays.asList(0, 0, 28, 28); 120 | pred = mainFrame.boundingBoxDenormalization(mainFrame.network.getOutputs(), box); 121 | mainFrame.probPanel.updateProbs(box, pred, null); 122 | } catch (Exception exc) { 123 | } 124 | } 125 | 126 | @Override 127 | public void mouseDragged(MouseEvent e) { 128 | setPixel(e); 129 | } 130 | 131 | @Override 132 | public void mouseClicked(MouseEvent e) { 133 | setPixel(e); 134 | } 135 | 136 | @Override 137 | public void mouseEntered(MouseEvent e) { 138 | } 139 | 140 | @Override 141 | public void mouseExited(MouseEvent e) { 142 | } 143 | 144 | @Override 145 | public void mousePressed(MouseEvent e) { 146 | } 147 | 148 | @Override 149 | public void mouseReleased(MouseEvent e) { 150 | } 151 | 152 | @Override 153 | public void mouseMoved(MouseEvent e) { 154 | } 155 | 156 | public void clear() { 157 | map = new double[28][28]; 158 | mapDrawing = new double[28][28]; 159 | repaint(); 160 | convPanel.setCurrentWordsConvDrawing(new DataSet(map, conv)); 161 | } 162 | 163 | public void setBoundingBox(List b, List p, List gt) { 164 | pred = p; 165 | box = b; 166 | gT = gt; 167 | 168 | } 169 | 170 | } 171 | -------------------------------------------------------------------------------- /src/gui/ExamplePanel.java: -------------------------------------------------------------------------------- 1 | package gui; 2 | 3 | import java.awt.Dimension; 4 | import java.awt.Graphics; 5 | 
import java.awt.image.BufferedImage; 6 | import java.io.File; 7 | import java.io.IOException; 8 | 9 | import javax.imageio.ImageIO; 10 | import javax.swing.JFrame; 11 | import javax.swing.JPanel; 12 | 13 | import tool.ReadFile; 14 | 15 | public class ExamplePanel extends JPanel { 16 | /** 17 | * Github:timmmGZ 2020/02/25 18 | */ 19 | private static final long serialVersionUID = 1L; 20 | BufferedImage bi; 21 | 22 | public ExamplePanel(MainFrame m) throws IOException { 23 | ReadFile.resolutionW = 320; 24 | ReadFile.resolutionH = 240; 25 | setLayout(null); 26 | setPreferredSize(new Dimension(340, 260)); 27 | setBackground(MainFrame.lightBlue); 28 | setVisible(true); 29 | this.setFocusable(true); 30 | JFrame jf = new JFrame("Example picture in ../groundTruthExamples"); 31 | jf.add(this); 32 | jf.pack(); 33 | jf.setVisible(true); 34 | jf.setResizable(false); 35 | jf.setAlwaysOnTop(true); 36 | jf.setLocationRelativeTo(null); 37 | if (m.currentExampleID != null) 38 | setImage(m.currentExampleID); 39 | } 40 | 41 | @Override 42 | protected void paintComponent(Graphics g) { 43 | super.paintComponent(g); 44 | g.drawImage(bi, 15, 15, null); 45 | } 46 | 47 | public void setImage(String index) throws IOException { 48 | bi = ImageIO.read(new File("dataset/groundTruthExamples/" + index + ".jpg")); 49 | repaint(); 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/gui/MainFrame.java: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmmGZ/Bounding-Box-Regression-GUI/56f7e03021b42919fc582bb2026f008d75cb88ae/src/gui/MainFrame.java -------------------------------------------------------------------------------- /src/gui/PredictPanel.java: -------------------------------------------------------------------------------- 1 | package gui; 2 | 3 | import java.awt.*; 4 | import java.util.ArrayList; 5 | import java.util.List; 6 | 7 | import javax.swing.*; 8 | 9 | public 
class PredictPanel extends JPanel { 10 | /** 11 | * Github:timmmGZ 2020/02/25 12 | */ 13 | private static final long serialVersionUID = 1L; 14 | MainFrame mainFrame; 15 | List texts = new ArrayList<>(); 16 | 17 | public PredictPanel(MainFrame frame) { 18 | setLayout(new BorderLayout()); 19 | mainFrame = frame; 20 | for (int i = 0; i < frame.outputSize; i++) 21 | texts.add(new JTextArea()); 22 | setLeft(); 23 | setRight(); 24 | } 25 | 26 | public void setLeft() { 27 | JPanel panel = new JPanel(); 28 | panel.setLayout(new GridLayout(0, 1)); 29 | panel.setBackground(new Color(0, 206, 209)); 30 | String[] classes = { "x: ", "y: ", "w: ", "h: " }; 31 | for (String word : classes) 32 | panel.add(new JLabel(word, SwingConstants.CENTER)); 33 | add(panel, BorderLayout.WEST); 34 | } 35 | 36 | public void setRight() { 37 | JPanel panel = new JPanel(); 38 | panel.setPreferredSize(new Dimension(100, 0)); 39 | panel.setLayout(new GridLayout(0, 1)); 40 | panel.setBackground(new Color(0, 206, 209)); 41 | for (JTextArea text : texts) { 42 | text.setBackground(new Color(0, 206, 209)); 43 | panel.add(text); 44 | } 45 | add(panel, BorderLayout.CENTER); 46 | } 47 | 48 | public void updateProbs(List b, List p, List gt) { 49 | for (int i = 0; i < b.size(); i++) { 50 | texts.get(i).setText("\n\n\n\n\nInput: " + b.get(i) + "\nPrediction: " + p.get(i) 51 | + (gt == null ? 
"" : "\nGround truth: " + gt.get(i))); 52 | } 53 | 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/neuralNetwork/Convolution.java: -------------------------------------------------------------------------------- 1 | package neuralNetwork; 2 | 3 | import java.util.Arrays; 4 | import java.util.List; 5 | import java.util.stream.Collectors; 6 | 7 | public class Convolution { 8 | public List filters = Arrays.asList(// 9 | new double[][] { { 1, 0, -1 }, { 2, 0, -2 }, { 1, 0, 1 } }, // | 10 | new double[][] { { 1, 2, 1 }, { 0, 0, 0 }, { -1, -2, -1 } }, // —— 11 | new double[][] { { 0, -1, 1 }, { 0, -1, 1 }, { 0, -1, 1 } }, // | 12 | new double[][] { { 0, 1, -1 }, { 0, 1, -1 }, { 0, 1, -1 } }, // | 13 | new double[][] { { 1, 0, -1 }, { 1, 0, -1 }, { 1, 0, -1 } }, // | 14 | // new double[][] { { -1, 0, 1 }, { -1, 0, 1 }, { -1, 0, 1 } }, // | 15 | new double[][] { { 1, -1, 0 }, { 1, -1, 0 }, { 1, -1, 0 } }, // | 16 | new double[][] { { -1, 1, 0 }, { -1, 1, 0 }, { -1, 1, 0 } }, // | 17 | new double[][] { { 0, 0, 0 }, { 1, 1, 1 }, { -1, -1, -1 } }, // —— 18 | new double[][] { { 0, 0, 0 }, { -1, -1, -1 }, { 1, 1, 1 } }, // —— 19 | // new double[][] { { 1, 1, 1 }, { 0, 0, 0 }, { -1, -1, -1 } }, // —— 20 | new double[][] { { -1, -1, -1 }, { 0, 0, 0 }, { 1, 1, 1 } }, // —— 21 | new double[][] { { 1, 1, 1 }, { -1, -1, -1 }, { 0, 0, 0 } }, // —— 22 | new double[][] { { -1, -1, -1 }, { 1, 1, 1 }, { 0, 0, 0 } }, // —— 23 | new double[][] { { 1, 1, 0 }, { 1, 0, -1 }, { 0, -1, -1 } }, /// 24 | new double[][] { { 0, 1, 1 }, { -1, 0, 1 }, { -1, -1, 0 } }, // \ 25 | new double[][] { { 0, -1, -1 }, { 1, 0, -1 }, { 1, 1, 0 } }, // \ 26 | new double[][] { { -1, -1, 0 }, { -1, 0, 1 }, { 0, 1, 1 } }); /// 27 | 28 | public double[][] firstLayerFeatureMapsSum; 29 | public List firstLayerFeaturesVector, secondLayerFeaturesVector, nFeaturesVector; 30 | 31 | public void setupAllLayers(double[][] inputs) { 32 | firstLayerFeaturesVector = 
features(inputs, filters, 1, true, true, 2, 2); 33 | firstLayerFeatureMapsSum = matrixSum(firstLayerFeaturesVector); 34 | nFeaturesVector = secondLayerFeaturesVector = features(firstLayerFeatureMapsSum, filters, 1, true, true, 2, 2); 35 | } 36 | 37 | public void setupAllLayers(double[][] inputs, int i) { 38 | setupAllLayers(inputs); 39 | for (int n = 2; n < i; n++) 40 | nFeaturesVector = features(matrixSum(nFeaturesVector), filters, 1, true, true, 2, 2); 41 | for (int ii = 0; ii < nFeaturesVector.size(); ii++) 42 | nFeaturesVector.set(ii, normalization(nFeaturesVector.get(ii))); 43 | } 44 | 45 | public List getFeatureVector() { 46 | return nFeaturesVector.stream().map(f -> Arrays.stream(f).map(Arrays::stream)// 47 | .map(e -> e.boxed()).flatMap(s -> s)).flatMap(s -> s).collect(Collectors.toList()); 48 | } 49 | 50 | public double[][] matrixSum(List features) { 51 | double[][] outputs = new double[features.get(0).length][features.get(0)[0].length]; 52 | for (double[][] feature : features) 53 | for (int y = 0; y < features.get(0).length; y++) 54 | for (int x = 0; x < features.get(0)[0].length; x++) 55 | outputs[y][x] += feature[y][x]; 56 | return outputs; 57 | } 58 | 59 | public List features(double[][] inputs, List filters, int stride, boolean padding, 60 | boolean relu, int poolingFilterSize, int poolingStride) { 61 | return filters.stream().map(f -> pooling(convolution(inputs, f, stride, padding, relu)// 62 | , poolingFilterSize, poolingStride)).collect(Collectors.toList()); 63 | } 64 | 65 | public static double[][] normalization(double[][] input) { 66 | double max = 0; 67 | for (int y = 0; y < input.length; y++) 68 | for (int x = 0; x < input[0].length; x++) 69 | max = input[y][x] > max ? input[y][x] : max; 70 | double scale = max == 0 ? 
1 : max; 71 | return Arrays.stream(input).map(a -> Arrays.stream(a).map(d -> d / scale).toArray()).toArray(double[][]::new); 72 | } 73 | 74 | public double[][] createDrawing(double[][] input) { 75 | return matrixSum(filters.stream().map(f -> convolution(input, f, 1, true, true)).collect(Collectors.toList())); 76 | } 77 | 78 | public double[][] pooling(double[][] inputs, int filterSize, int stride) { 79 | int sizeH = (int) Math.floor(inputs.length / stride), sizeW = (int) Math.floor(inputs[0].length / stride); 80 | double[][] output = new double[sizeH][sizeW]; 81 | for (int y = 0; y < output.length; y++) 82 | for (int x = 0; x < output[0].length; x++) 83 | output[y][x] = max(inputs, filterSize, x * stride, y * stride); 84 | return output; 85 | } 86 | 87 | public double max(double[][] inputs, int filterSize, int X, int Y) { 88 | double output = 0; 89 | for (int y = 0; y < filterSize; y++) 90 | for (int x = 0; x < filterSize; x++) 91 | output = output > inputs[Y + y][X + x] ? output : inputs[Y + y][X + x]; 92 | return output; 93 | } 94 | 95 | public double[][] convolution(double[][] inputs, double[][] filter, int stride, boolean padding, boolean relu) { 96 | inputs = padding ? pad(inputs, (filter.length - 1) / 2) : inputs; 97 | int sizeH = (int) Math.ceil((inputs.length - filter.length + 1) / stride); 98 | int sizeW = (int) Math.ceil((inputs[0].length - filter.length + 1) / stride); 99 | double[][] output = new double[sizeH][sizeW]; 100 | for (int y = 0; y < output.length; y++) 101 | for (int x = 0; x < output[0].length; x++) { 102 | double conv = convolveOnePixel(inputs, filter, x * stride, y * stride); 103 | output[y][x] = relu ? 
Math.max(0, conv) : conv; 104 | } 105 | return output; 106 | } 107 | 108 | public double convolveOnePixel(double[][] inputs, double[][] filter, int X, int Y) { 109 | double output = 0; 110 | for (int y = 0; y < filter.length; y++) 111 | for (int x = 0; x < filter[0].length; x++) 112 | output += filter[y][x] * inputs[Y + y][X + x]; 113 | return output; 114 | 115 | } 116 | 117 | public double[][] pad(double[][] inputs, int size) { 118 | double[][] output = new double[inputs.length + size * 2][inputs[0].length + size * 2]; 119 | for (int y = 0; y < inputs.length; y++) 120 | System.arraycopy(inputs[y], 0, output[y + size], size, inputs[y].length); 121 | return output; 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /src/neuralNetwork/DataSet.java: -------------------------------------------------------------------------------- 1 | package neuralNetwork; 2 | 3 | import java.util.Arrays; 4 | import java.util.List; 5 | import java.util.stream.Collectors; 6 | 7 | public class DataSet { 8 | 9 | private double[][] inputsAsMatrix; 10 | private List featureVector, inputVector, trueOutputDenormalized, box, trueOuput; 11 | private String label; 12 | 13 | public DataSet(double[][] in, List out, String word, Convolution convolution) { 14 | this(in, convolution); 15 | trueOuput = out; 16 | label = word; 17 | } 18 | 19 | public DataSet(double[][] in, List out, String word, Convolution convolution, List b, 20 | List outDenormalized) { 21 | this(in, out, word, convolution); 22 | box = b; 23 | trueOutputDenormalized = outDenormalized; 24 | } 25 | 26 | public DataSet(double[][] in, Convolution convolution) { 27 | inputsAsMatrix = in; 28 | convolution.setupAllLayers(inputsAsMatrix); 29 | featureVector = convolution.getFeatureVector(); 30 | inputVector = Arrays.stream(inputsAsMatrix).map(Arrays::stream).map(e -> e.boxed()).flatMap(s -> s) 31 | .collect(Collectors.toList()); 32 | } 33 | 34 | public double[][] getInputs() { 35 | return 
inputsAsMatrix; 36 | } 37 | 38 | public List getBox() { 39 | return box; 40 | } 41 | 42 | public List getInputsVector() { 43 | return inputVector; 44 | }// map ->List ->List ->List> ->Stream 45 | 46 | public List getFeatureVector() { 47 | return featureVector; 48 | } 49 | 50 | public List getTrueOutputDenormalized() { 51 | return trueOutputDenormalized; 52 | } 53 | 54 | public List getTrueOutput() { 55 | return trueOuput; 56 | } 57 | 58 | public String getLabel() { 59 | return label; 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/neuralNetwork/Network.java: -------------------------------------------------------------------------------- 1 | package neuralNetwork; 2 | 3 | import java.io.*; 4 | import java.util.ArrayList; 5 | import java.util.List; 6 | import java.util.stream.Collectors; 7 | 8 | import gui.MainFrame; 9 | 10 | public class Network { 11 | private List neurons = new ArrayList<>(); 12 | 13 | public Network(MainFrame mf) { 14 | for (int i = 0; i < mf.outputSize; i++) 15 | neurons.add(new Neuron()); 16 | } 17 | 18 | public void setAllNeuronsInputs(List inputvector) { 19 | neurons.forEach(n -> n.setInputs(inputvector)); 20 | } 21 | 22 | public void modifyAllNeuronsWeights(List label) { 23 | neurons.forEach(n -> n.totalCost += n.modifyWeights(label.get(neurons.indexOf(n)), n.sum())); 24 | } 25 | 26 | public List getNeurons() { 27 | return neurons; 28 | } 29 | 30 | public List getOutputs() { 31 | return neurons.stream().map(n -> n.sum()).collect(Collectors.toList()); 32 | 33 | } 34 | /** Writes every neuron's weights, in neuron order, as doubles (DataOutputStream, big-endian) to dataset/weights/{type}.w. */ 35 | public void saveWeights(String type) { 36 | // Buffered writes; try-with-resources flushes and closes the stream even when a write fails. 37 | try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream( 38 | new FileOutputStream(new File("dataset/weights/" + type + ".w"))))) { 39 | for (Neuron n : neurons) 40 | for (int i = 0; i < n.weights.length; i++) 41 | out.writeDouble(n.weights[i]); 42 | } catch (IOException e) { 43 | // Report the failure (e.g. missing dataset/weights directory) instead of silently dropping it. 44 | e.printStackTrace(); 45 | } 46 | } 47 | 48 | @SuppressWarnings("deprecation")
49 | public void readWeights(String type) { 50 | DataInputStream in = null; 51 | try { 52 | in = new DataInputStream( 53 | new BufferedInputStream(new File("dataset/weights/" + type + ".w").toURL().openStream())); 54 | for (Neuron n : neurons) 55 | for (int i = 0; i < n.weights.length; i++) 56 | n.weights[i] = in.readDouble(); 57 | in.close(); 58 | } catch (IOException e) { 59 | } 60 | } 61 | } 62 | // 150 1.5, 50 1, 50 0.5, 50 0.25, 1 0.1 63 | // 100 2, 50 1.5, 50 1, 50 0.5, 25 0.25 64 | -------------------------------------------------------------------------------- /src/neuralNetwork/Neuron.java: -------------------------------------------------------------------------------- 1 | package neuralNetwork; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | public class Neuron { 7 | ArrayList inputs = new ArrayList<>(); 8 | public static double learningRate = 0.008, totalCAll = 0; 9 | public double weights[] = new double[785], oldWeights[] = new double[785]; 10 | public static boolean useBP = false, keepBestWeights = false, allowCopy = false; 11 | 12 | public double totalCost = 0, lastTotalC = 100000000; 13 | 14 | public double modifyWeights(double trueOutput, double perdictedOutput) { 15 | double error = trueOutput - perdictedOutput; 16 | if (keepBestWeights && allowCopy) { 17 | System.arraycopy(weights, 0, oldWeights, 0, weights.length); 18 | allowCopy = false; 19 | } 20 | double backPropagation = useBP ? 
(-error * perdictedOutput * (1 - perdictedOutput)) : 0; 21 | for (int i = 0; i < inputs.size(); i++) 22 | weights[i] += (learningRate * inputs.get(i) * (error - backPropagation)); 23 | weights[weights.length - 1] += learningRate * (error - backPropagation); 24 | return error * error; 25 | } 26 | 27 | public double sum() { 28 | double sum = 0; 29 | for (int i = 0; i < inputs.size(); i++) 30 | sum += inputs.get(i) * weights[i]; 31 | sum += weights[weights.length - 1]; 32 | return sigmoid(sum); 33 | } 34 | 35 | public void setInputs(List list) { 36 | inputs = new ArrayList<>(list); 37 | } 38 | 39 | public static double sigmoid(Double sum) { 40 | return 1 / (1 + Math.exp(-sum)); 41 | } 42 | 43 | } 44 | -------------------------------------------------------------------------------- /src/neuralNetwork/Trainer.java: -------------------------------------------------------------------------------- 1 | package neuralNetwork; 2 | 3 | import java.io.IOException; 4 | import java.util.List; 5 | 6 | import javax.swing.JTextArea; 7 | 8 | import gui.DrawingPanel; 9 | import gui.MainFrame; 10 | import tool.DataInfo; 11 | 12 | public class Trainer { 13 | public Network network; 14 | public List trainSets; 15 | public static int i = 0, epoch = 1; 16 | public JTextArea lossTextArea = new JTextArea(); 17 | DrawingPanel drawPanel; 18 | int trainSize; 19 | 20 | public Trainer(MainFrame mf) throws IOException { 21 | lossTextArea = mf.lossTextArea; 22 | network = mf.network; 23 | trainSets = DataInfo.getTrainSets(); 24 | drawPanel = mf.drawPanel; 25 | trainSize = mf.trainSize; 26 | } 27 | 28 | public void train(int count, int type) { 29 | for (i = 0; i < count; i++) { 30 | DataSet trainSet = trainSets.get((int) (Math.random() * trainSets.size())); 31 | network.setAllNeuronsInputs(type == 0 ? 
trainSet.getFeatureVector() : trainSet.getInputsVector()); 32 | network.modifyAllNeuronsWeights(trainSet.getTrueOutput()); 33 | if (i % trainSize == trainSize - 1) 34 | UpdateLossAreaAndWeightsDrawing(count); 35 | if (i % trainSize == 0 && Neuron.keepBestWeights) 36 | Neuron.allowCopy = true; 37 | } 38 | } 39 | /** Called once per epoch: appends each neuron's loss (shown as 100 * cost / trainSize) and the total loss to the loss area. With keepBestWeights, a neuron whose loss got worse than in the previous epoch is rolled back to its oldWeights snapshot and keeps its previous loss. Then resets the loss accumulators and repaints the loss area and, when shown, the weights view. */ 40 | public void UpdateLossAreaAndWeightsDrawing(int c) { 41 | StringBuilder sb = new StringBuilder(); 42 | network.getNeurons().forEach(n -> { 43 | int index = network.getNeurons().indexOf(n); 44 | if (n.totalCost > n.lastTotalC && Neuron.keepBestWeights) { 45 | // Copy the snapshot back instead of assigning the array: aliasing weights and oldWeights would make every later snapshot copy the array onto itself, so no later rollback could restore anything. 46 | System.arraycopy(n.oldWeights, 0, n.weights, 0, n.weights.length); 47 | sb.append(index + " neuron loss: " 48 | + MainFrame.df.format((n.totalCost = n.lastTotalC) / (trainSize / 100.0)) + "%(keep) "); 49 | } else 50 | sb.append(index + " neuron loss: " 51 | + MainFrame.df.format((n.lastTotalC = n.totalCost) / (trainSize / 100.0)) + "% "); 52 | sb.append((index % 2 == 0 ? "" : "\n")); 53 | Neuron.totalCAll += n.totalCost; 54 | n.totalCost = 0; 55 | }); 56 | // trainSize / 100.0: integer division would truncate, and would be 0 (an Infinity/NaN loss) when trainSize < 100. 57 | sb.append("Epoch " + epoch++ + " total loss: " + MainFrame.df.format(Neuron.totalCAll / (trainSize / 100.0)) 58 | + "%, " + Neuron.totalCAll); 59 | Neuron.totalCAll = 0; 60 | lossTextArea.setText(sb.toString()); 61 | try { 62 | lossTextArea.update(lossTextArea.getGraphics()); 63 | if (drawPanel.convPanel.showWeights) 64 | drawPanel.convPanel.update(drawPanel.convPanel.getGraphics()); 65 | } catch (Exception e) { 66 | e.printStackTrace(); 67 | } 68 | } 69 | 70 | } 71 | -------------------------------------------------------------------------------- /src/tool/CreateForeOrBackgroundSample.java: -------------------------------------------------------------------------------- 1 | package tool; 2 | 3 | import java.awt.Color; 4 | import java.awt.Graphics; 5 | import java.awt.image.BufferedImage; 6 | import java.io.*; 7 | import java.nio.file.Files; 8 | import java.nio.file.Paths; 9 | import java.util.ArrayList; 10 | import java.util.List; 11 | import java.util.Random; 12 | import
java.util.stream.Collectors; 13 | import java.util.stream.IntStream; 14 | 15 | import javax.imageio.ImageIO; 16 | 17 | import detection.BoundingBox; 18 | 19 | public class CreateForeOrBackgroundSample { 20 | static File boxRegressionLable = new File("dataset/standardOutput/foreground_bBoxRegressionLabels.txt"), 21 | gtBoxLabel = new File("dataset/standardOutput/pictures_groundTruthBoxAndLabel.txt"); // 22 | static BufferedWriter out; 23 | 24 | static int scales[] = { 28, 56, 112, 224 }; 25 | static double ratios[][] = { { 1, 1 }, { 1, 1.5 }, { 1.5, 1 }, { 1, 1.8 }, { 1.8, 1 }, { 1, 2.1 }, { 2.1, 1 }, 26 | { 1, 2.4 }, { 2.4, 1 } }; 27 | static int scalars[] = IntStream.range(0, ratios.length * scales.length * 2) 28 | .map(i -> (int) (scales[i / ratios.length / 2] * ratios[i / 2 % ratios.length][i % 2])).toArray(); 29 | static double W = 28, H = 28, overlapThreshhold = 0.7; 30 | static int resolutionW = 320, resolutionH = 240, girdLen = 8, numPositiveSample = 2, numNegativeSample = 2, 31 | countF = 0, countB = 0; 32 | static Random rand = new Random(); 33 | static BufferedImage bi, biTmp, biNew, biOut; 34 | static Graphics g; 35 | static BoundingBox b; 36 | 37 | public static void main(String[] args) throws IOException { 38 | long l = System.currentTimeMillis(); 39 | out = new BufferedWriter(new FileWriter(boxRegressionLable)); 40 | out.write("id | ground truth box normalized movement | foreground box | ground truth box | label\n"); 41 | List images = Files.walk(Paths.get("dataset/pictures"), 1).map(p -> p.toFile()).filter(File::isFile) 42 | .sorted((f1, f2) -> Integer.parseInt(f1.getName().toString().replace(".jpg", "")) - Integer.parseInt// 43 | (f2.getName().toString().replace(".jpg", ""))).collect(Collectors.toList()); 44 | BufferedReader br = new BufferedReader(new FileReader(gtBoxLabel)); 45 | BoundingBox boxGT, boxRand; 46 | List lineList = br.lines().collect(Collectors.toList()); 47 | br.close(); 48 | List foreground, background; 49 | for (int i = 1; i < 
lineList.size(); i++) { 50 | String[] parts = lineList.get(i).split("\\|"), box = parts[1].split(","); 51 | int id = Integer.parseInt(parts[0]), xL1 = Integer.parseInt(box[0]), yT1 = Integer.parseInt(box[1]), 52 | w1 = Integer.parseInt(box[2]), h1 = Integer.parseInt(box[3]); 53 | String label = parts[2]; // "1" for the ground truth box 54 | boxGT = new BoundingBox(xL1, yT1, w1, h1); 55 | bi = ImageIO.read(images.get(id)); 56 | biTmp = new BufferedImage(320, 240, BufferedImage.TYPE_INT_RGB); 57 | biTmp.setData(bi.getData()); 58 | g = bi.createGraphics(); 59 | foreground = new ArrayList<>(); 60 | background = new ArrayList<>(); 61 | for (int x = 0; x < resolutionW; x += girdLen) 62 | for (int y = 0; y < resolutionH; y += girdLen) 63 | for (int s = 0; s < scalars.length; s += 2) { // "2" for the bounding boxes with different ratios 64 | int w2 = scalars[s], h2 = scalars[s + 1], xL2 = x - w2 / 2, yT2 = y - h2 / 2; 65 | boxRand = new BoundingBox(xL2, yT2, w2, h2); 66 | double overlapPercent = CreatePictureForObjectDetectionFromMNIST.IOU(boxGT, boxRand); 67 | if (overlapPercent > overlapThreshhold) 68 | foreground.add(boxRand); 69 | else if (0.05 < overlapPercent && overlapPercent < 0.5) 70 | background.add(boxRand); 71 | } 72 | g.setColor(Color.BLUE); 73 | createSamplePictures(foreground, label, numPositiveSample, "foreground", id, w1, h1, xL1, yT1); 74 | g.setColor(Color.RED); 75 | createSamplePictures(background, label, numNegativeSample - (int) Math.round(rand.nextDouble()), 76 | "background", id, w1, h1, xL1, yT1); 77 | g.setColor(Color.GREEN); 78 | g.drawRect(xL1, yT1, w1, h1); 79 | g.drawString("ground truth", Math.max(0, xL1), Math.max(0, yT1)); 80 | ImageIO.write(bi, "jpg", new File("dataset/groundTruthExamples/" + id + ".jpg")); 81 | System.out.println(id + 1 + "/" + images.size() + ", count: foreground=" + countF + ",boreground=" + countB 82 | + "," + (System.currentTimeMillis() - l) / 1000.0); 83 | } 84 | System.out.println((System.currentTimeMillis() - l) / 
1000.0); 85 | out.close(); 86 | } 87 | 88 | public static void createSamplePictures(List sample, String label, int num, String dir, int exampleID, 89 | int w, int h, int xL, int yT) throws IOException { 90 | int size = Math.min(num, sample.size()); 91 | biNew = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB); 92 | for (int n = 0; n < size; n++) { 93 | b = sample.get(size == num ? rand.nextInt(sample.size()) : n); 94 | g.drawRect(b.getX(), b.getY(), b.getW(), b.getH()); 95 | g.drawString(dir, Math.max(0, b.getX()), Math.max(0, b.getY())); 96 | // above is for drawing pos or neg rect in pictures in ../groundTruthExamples 97 | double wS = W / b.getW(), hS = H / b.getH(); 98 | biOut = CreatePictureForObjectDetectionFromMNIST.scaleNearest(biTmp, wS, hS); 99 | for (int x = 0, x1 = (int) (b.getX() * wS); x < 28; x++, x1++) 100 | for (int y = 0, y1 = (int) (b.getY() * hS); y < 28; y++, y1++) 101 | biNew.setRGB(x, y, (x1 < 0 || y1 < 0 || x1 > (int) (320 * wS) - 1 || y1 > (int) (240 * hS) - 1) ? // 102 | -1 : biOut.getRGB(x1, y1)); 103 | ImageIO.write(biNew, "jpg", 104 | new File("dataset/" + dir + "/" + (dir.equals("foreground") ? 
countF : countB++) + ".jpg")); 105 | if (dir.equals("foreground")) 106 | out.write(countF++ + "|" + exampleID + "|" + sigmoid((b.getX() - xL) / (double) b.getW(), 1) + "," 107 | + sigmoid((b.getY() - yT) / (double) b.getH(), 1) + "," + sigmoid((double) w / b.getW(), 1e4) 108 | + "," + sigmoid((double) h / b.getH(), 1e4) + "|" + b.getX() + "," + b.getY() + "," + b.getW() 109 | + "," + b.getH() + "|" + xL + "," + yT + "," + w + "," + h + "|" + label + "\n"); 110 | } 111 | } 112 | 113 | public static double sigmoid(Double sum, double leftOrRight) { 114 | return 1 / (1 + Math.exp(-sum * 10) * leftOrRight); 115 | } 116 | 117 | } 118 | -------------------------------------------------------------------------------- /src/tool/CreatePictureForObjectDetectionFromMNIST.java: -------------------------------------------------------------------------------- 1 | package tool; 2 | 3 | import java.awt.geom.AffineTransform; 4 | import java.awt.image.AffineTransformOp; 5 | import java.awt.image.BufferedImage; 6 | import java.io.*; 7 | import java.util.List; 8 | import java.util.Random; 9 | 10 | import javax.imageio.ImageIO; 11 | 12 | import detection.BoundingBox; 13 | import neuralNetwork.DataSet; 14 | 15 | public class CreatePictureForObjectDetectionFromMNIST { 16 | static Random rand = new Random(); 17 | static File groundTruthBoxAndLabel = new File("dataset/standardOutput/pictures_groundTruthBoxAndLabel.txt"); // 18 | 19 | static int mnistW = 28, mnistH = 28, outputW = 320, outputH = 240; 20 | 21 | public static void main(String[] args) throws IOException { 22 | long l = System.currentTimeMillis(); 23 | new DataInfo("mnist"); 24 | DataInfo.getTrainSets().addAll(DataInfo.getTestSets()); 25 | createPicturesAndLabels(DataInfo.getTrainSets()); 26 | System.out.println((System.currentTimeMillis() - l) / 1000.0); 27 | } 28 | 29 | public static void createPicturesAndLabels(List dataSet) throws IOException { 30 | BufferedWriter out = new BufferedWriter(new 
FileWriter(groundTruthBoxAndLabel)); 31 | out.write("id | left edge X, top edge Y, W, H | Label\n"); 32 | BufferedImage bi, biNew; 33 | for (int i = 0; i < dataSet.size(); i++) { 34 | double[][] map = dataSet.get(i).getInputs(); 35 | bi = new BufferedImage(mnistW, mnistH, BufferedImage.TYPE_INT_RGB); 36 | biNew = new BufferedImage(outputW, outputH, BufferedImage.TYPE_INT_RGB); 37 | biNew.getGraphics().fillRect(0, 0, outputW, outputH); 38 | for (int x = 0; x < mnistW; x++) 39 | for (int y = 0; y < mnistH; y++) 40 | bi.setRGB(x, y, (int) ((1 - map[y][x]) * 255) << 16 | (int) ((1 - map[y][x]) * 255) << 8 41 | | (int) ((1 - map[y][x]) * 255)); 42 | double wScalar = rand.nextDouble() * 10 + 1, 43 | hScalar = Math.max(1, wScalar + rand.nextDouble() * (rand.nextDouble() > 0.5 ? 3 : -3)); 44 | int w = (int) (mnistW * wScalar), h = (int) (mnistH * hScalar), xL = (int) (rand.nextInt(outputW) - w / 2), 45 | yT = (int) (rand.nextInt(outputH) - h / 2); 46 | if (IOU(new BoundingBox(0, 0, outputW, outputH), new BoundingBox(xL, yT, w, h)) < 0.7 47 | && !(xL > 0 && yT > 0 && xL + w < outputW && yT + h < outputH)) { 48 | i--; 49 | continue; 50 | } 51 | biNew.getGraphics().drawImage(scaleNearest(bi, wScalar, hScalar), xL, yT, null); 52 | ImageIO.write(biNew, "jpg", new File("dataset/pictures/" + i + ".jpg")); 53 | out.write(i + "|" + xL + "," + yT + "," + w + "," + h + "|" + dataSet.get(i).getLabel() + "\n"); 54 | System.out.println(i + "/" + dataSet.size()); 55 | 56 | } 57 | out.close(); 58 | 59 | } 60 | 61 | public static BufferedImage scaleNearest(BufferedImage img, double wScalar, double hScalar) { 62 | BufferedImage newImage = new BufferedImage((int) (img.getWidth() * wScalar), (int) (img.getHeight() * hScalar), 63 | img.getType()); 64 | AffineTransform scaleInstance = AffineTransform.getScaleInstance(wScalar, hScalar); 65 | AffineTransformOp scaleOp = new AffineTransformOp(scaleInstance, AffineTransformOp.TYPE_NEAREST_NEIGHBOR); 66 | scaleOp.filter(img, newImage); 67 | return 
newImage; 68 | } 69 | 70 | public static double IOU(BoundingBox a, BoundingBox b) { 71 | int areaA = a.getH() * a.getW(), areaB = b.getH() * b.getW(); 72 | int wTotal = Math.max(a.getX() + a.getW(), b.getX() + b.getW()) - Math.min(a.getX(), b.getX()), 73 | hTotal = Math.max(a.getY() + a.getH(), b.getY() + b.getH()) - Math.min(a.getY(), b.getY()), 74 | wOverlap = wTotal - a.getW() - b.getW(), hOverlap = hTotal - a.getH() - b.getH(), 75 | areaOverlap = (wOverlap >= 0 || hOverlap >= 0) ? 0 : wOverlap * hOverlap; 76 | return (double) areaOverlap / (areaA + areaB - areaOverlap); 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/tool/DataInfo.java: -------------------------------------------------------------------------------- 1 | package tool; 2 | 3 | import java.io.*; 4 | import java.util.ArrayList; 5 | import java.util.Arrays; 6 | import java.util.Collections; 7 | import java.util.List; 8 | import java.util.Map; 9 | import java.util.TreeMap; 10 | import java.util.stream.Collectors; 11 | 12 | import neuralNetwork.DataSet; 13 | 14 | public class DataInfo { 15 | private List trainWordNames = new ArrayList<>(); 16 | private static Map> trueOutputs, box, trueOutputsdenormalized, exampleID; 17 | private static List trainSets, testSets; 18 | 19 | public DataInfo(String type) throws IOException { 20 | switch (type) { 21 | case "mnist": 22 | trueOutputs = vectorsParser("dataset/standardOutput/digitLabels.txt", 1); 23 | trainSets = ReadFile.readFromSingleTxt("MNIST TrainSet.txt"); 24 | testSets = ReadFile.readFromSingleTxt("MNIST TestSet.txt"); 25 | break; 26 | case "foreOrBackground": 27 | exampleID = vectorsParser("dataset/standardOutput/foreground_bBoxRegressionLabels.txt", 1); 28 | trueOutputs = vectorsParser("dataset/standardOutput/foreground_bBoxRegressionLabels.txt", 2); 29 | box = vectorsParser("dataset/standardOutput/foreground_bBoxRegressionLabels.txt", 3); 30 | trueOutputsdenormalized = 
vectorsParser("dataset/standardOutput/foreground_bBoxRegressionLabels.txt", 4); 31 | trainSets = ReadFile.readFromPictures("foreground", true); 32 | Collections.shuffle(trainSets); 33 | testSets = trainSets.subList(trainSets.size() * 6 / 7, trainSets.size()); 34 | trainSets = trainSets.subList(0, trainSets.size() * 6 / 7); 35 | } 36 | 37 | } 38 | 39 | public static List getTrainSets() { 40 | return trainSets; 41 | } 42 | 43 | public static List getTestSets() { 44 | return testSets; 45 | } 46 | 47 | public List getTrainWordNames() { 48 | return trainWordNames; 49 | } 50 | 51 | public String[] getTrainWordNamesArray() { 52 | return trainWordNames.toArray(new String[getTrainWordNames().size()]); 53 | } 54 | 55 | public static List getTrueOutput(String input) { 56 | return trueOutputs.get(input); 57 | } 58 | 59 | public static List getExampleID(String input) { 60 | return exampleID.get(input); 61 | } 62 | 63 | public static List getTrueOutputDenormalized(String input) { 64 | return trueOutputsdenormalized.get(input); 65 | } 66 | 67 | public static List getBox(String input) { 68 | return box.get(input); 69 | } 70 | 71 | public Map> vectorsParser(String file, int type) throws IOException { 72 | Map> map = new TreeMap<>(); 73 | BufferedReader br = new BufferedReader(new FileReader(new File(file))); 74 | for (String line = br.readLine(); (line = br.readLine()) != null;) { 75 | String[] parts = line.split("\\|"); 76 | map.put(parts[0], Arrays.asList(parts[type].split(",")).stream().map(Double::parseDouble) 77 | .collect(Collectors.toList())); 78 | } 79 | br.close(); 80 | return map; 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/tool/ReadFile.java: -------------------------------------------------------------------------------- 1 | package tool; 2 | 3 | import java.awt.image.BufferedImage; 4 | import java.awt.image.Raster; 5 | import java.io.*; 6 | import java.nio.file.Files; 7 | import java.nio.file.Paths; 8 | 
import java.util.ArrayList; 9 | import java.util.Arrays; 10 | import java.util.List; 11 | import java.util.Scanner; 12 | import java.util.stream.Collectors; 13 | import javax.imageio.ImageIO; 14 | import neuralNetwork.*; 15 | 16 | public class ReadFile { 17 | public static int i = 0, resolutionW = 28, resolutionH = 28, dataSetAmount = 10000; 18 | 19 | public static List readFromSingleTxt(String filename) throws IOException { 20 | long t = System.currentTimeMillis(); 21 | Convolution conv = new Convolution(); 22 | List trainingSets = new ArrayList<>(); 23 | for (Scanner s = new Scanner(new FileInputStream(new File("dataset/" + filename))); s.hasNextLine();) { 24 | String[] line = s.nextLine().split("\\|"); 25 | String word = String.valueOf(line[line.length - 1]); 26 | double[][] trainSetInputs = new double[28][28]; 27 | for (int y = 0; y < 28; y++) 28 | for (int x = 0; x < 28; x++) 29 | trainSetInputs[x][y] = Integer.parseInt(line[y * 28 + x]) / 255.0; 30 | trainingSets.add(new DataSet(trainSetInputs, DataInfo.getTrueOutput(word), word, conv)); 31 | 32 | System.out.println(i++); 33 | } 34 | System.out.println((System.currentTimeMillis() - t) / 1000.0 + " seconds"); 35 | return trainingSets; 36 | } 37 | 38 | public static List readFromPictures(String dir, boolean isBBoxRegression) throws IOException { 39 | Convolution conv = new Convolution(); 40 | List trainingSets = new ArrayList<>(); 41 | List files = Files.walk(Paths.get("dataset/" + dir), 1).map(p -> p.toFile()).filter(File::isFile) 42 | .limit(dataSetAmount).collect(Collectors.toList()); 43 | 44 | for (File f : files) { 45 | System.out.println(i++); 46 | // System.out.println(f); 47 | BufferedImage bi = ImageIO.read(f); 48 | double[][] input = new double[bi.getHeight()][bi.getWidth()]; 49 | 50 | Raster r = bi.getData(); 51 | for (int y = 0; y < bi.getHeight(); y++) 52 | for (int x = 0; x < bi.getWidth(); x++) { 53 | int[] tmp = r.getPixel(x, y, (int[]) null); 54 | input[y][x] = 1.0 - (tmp[0] + tmp[1] + tmp[2]) 
/ 765.0; 55 | } 56 | String id = f.getName().replace(".jpg", ""), exampleID = ", example ID: " 57 | + DataInfo.getExampleID(id).get(0).intValue() + "(in ../groundTruthExamples/)"; 58 | List output = new ArrayList<>(isBBoxRegression ? DataInfo.getTrueOutput(f.getName().replace(// 59 | ".jpg", "")) : dir.contains("foreground") ? Arrays.asList(1.0, 0.0) : Arrays.asList(0.0, 1.0)); 60 | trainingSets.add(new DataSet(input, output, id + exampleID, conv, isBBoxRegression ? // 61 | DataInfo.getBox(id) : null, isBBoxRegression ? DataInfo.getTrueOutputDenormalized(id) : null)); 62 | } 63 | return trainingSets; 64 | } 65 | 66 | } 67 | --------------------------------------------------------------------------------