├── .gitignore ├── README.md ├── requirements.yml └── script ├── DeepLineDP_model.py ├── export_data_for_line_level_baseline.py ├── file-level-baseline ├── Bi-LSTM-baseline.py ├── BoW-baseline.py ├── CNN-baseline.py ├── DBN-baseline.py ├── baseline_util.py └── dbn │ ├── activations.py │ ├── models.py │ └── utils.py ├── generate_prediction.py ├── generate_prediction_cross_projects.py ├── get_evaluation_result.R ├── line-level-baseline ├── ErrorProne │ ├── dataflow-shaded-3.1.2.jar │ ├── error_prone_core-2.4.0-with-dependencies.jar │ ├── jFormatString-3.0.0.jar │ ├── javac-9+181-r4173-1.jar │ └── run_ErrorProne.ipynb ├── RF-line-level.py └── ngram │ ├── commons-io-2.8.0.jar │ ├── n_gram.java │ └── slp-core.jar ├── my_util.py ├── preprocess_data.py ├── train_model.py └── train_word2vec.py /.gitignore: -------------------------------------------------------------------------------- 1 | /datasets 2 | /output -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Supplementary Materials for "DeepLineDP: Towards a Deep Learning Approach for Line-Level Defect Prediction" 3 | 4 | 5 | ## Datasets 6 | 7 | The datasets are obtained from Wattanakriengkrai et al. The datasets contain 32 software releases across 9 software projects. The datasets that we used in our experiment can be found in this [GitHub repository](https://github.com/awsm-research/line-level-defect-prediction). 8 | 9 | The file-level datasets (in the File-level directory) contain the following columns: 10 | 11 | - `File`: The name of the source code file 12 | - `Bug`: A label indicating whether the file is clean or defective 13 | - `SRC`: The content of the source code file 14 | 15 | The line-level datasets (in the Line-level directory) contain the following columns: 16 | - `File`: The name of the source code file 17 | - `Commit`: The id of the bug-fixing commit of the file 18 | - `Line_number`: The line number where the source code is modified 19 | - `SRC`: The actual source code that is modified 20 | 21 | For each software project, we use the oldest release to train DeepLineDP models. The subsequent release is used as the validation set. The other releases are used as test sets. 22 | 23 | For example, if there are 5 releases in ActiveMQ (e.g., R1, R2, R3, R4, R5), R1 is used as the training set, R2 is used as the validation set, and R3 - R5 are used as test sets.
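As a quick sanity check after downloading the datasets, the file-level CSVs can be inspected with pandas. The snippet below is a minimal sketch, not part of the replication scripts: the directory layout and the `<release>_ground-truth-files_dataset.csv` naming follow what `export_data_for_line_level_baseline.py` reads, and `activemq-5.0.0` is only an assumed example release name.

    import pandas as pd

    # Assumed example release; use any release available in the downloaded datasets.
    release = 'activemq-5.0.0'

    # File-level data: one row per file, with the full source code in the `SRC` column.
    file_df = pd.read_csv('datasets/original/File-level/' + release + '_ground-truth-files_dataset.csv',
                          encoding='latin')

    print(file_df.columns.tolist())       # expected: ['File', 'Bug', 'SRC']
    print(file_df['Bug'].value_counts())  # number of clean vs. defective files in this release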
24 | 25 | ## Repository Structure 26 | 27 | Our repository contains the following directory 28 | 29 | - `output`: This directory contains the following sub-directories: 30 | - `loss`: This directory stores training and validation loss 31 | - `model`: This directory stores trained models 32 | - `prediction`: This directory stores prediction (in CSV files) obtained from the trained models 33 | - `Word2Vec_model`: This directory stores word2vec models of each software project 34 | - `script`: This directory contains the following directories and files: 35 | - `preprocess_data.py`: The source code used to preprocess datasets for file-level model training and evaluation 36 | - `export_data_for_line_level_baseline.py`: The source code used to prepare data for line-level baseline 37 | - `my_util.py`: The source code used to store utility functions 38 | - `train_word2vec.py`: The source code used to train word2vec models 39 | - `DeepLineDP_model.py`: The source code that stores DeepLineDP architecture 40 | - `train_model.py`: The source code used to train DeepLineDP models 41 | - `generate_prediction.py`: The source code used to generate prediction (for RQ1-RQ3) 42 | - `generate_prediction_cross_projects.py`: The source code used to generate prediction (for RQ4) 43 | - `get_evaluation_result.R`: The source code used to generate figures for RQ1-RQ3, and show RQ4 result 44 | - `file-level-baseline`: The directory that stores implementation of the file-level baselines, and `baseline_util.py` that stores utility function of the baselines 45 | - `line-level-baseline`: The directory that stores implementation of the line-level baselines 46 | 47 | ## Environment Setup 48 | 49 | 50 | ### Python Environment Setup 51 | 1. clone the github repository by using the following command: 52 | 53 | git clone https://github.com/awsm-research/DeepLineDP.git 54 | 55 | 2. download the dataset from this [github](https://github.com/awsm-research/line-level-defect-prediction) and keep it in `./datasets/original/` 56 | 57 | 3. use the following command to install required libraries in conda environment 58 | 59 | conda env create -f requirements.yml 60 | conda activate DeepLineDP_env 61 | 62 | 4. install PyTorch library by following the instruction from this [link](https://pytorch.org/) (the installation instruction may vary based on OS and CUDA version) 63 | 64 | 65 | ### R Environment Setup 66 | 67 | Download the following package: `tidyverse`, `gridExtra`, `ModelMetrics`, `caret`, `reshape2`, `pROC`, `effsize`, `ScottKnottESD` 68 | 69 | 70 | ## Experiment 71 | 72 | ### Experimental Setup 73 | 74 | We use the following hyper-parameters to train our DeepLineDP model 75 | 76 | - `batch_size` = 32 77 | - `num_epochs` = 10 78 | - `embed_dim (word embedding size)` = 50 79 | - `word_gru_hidden_dim` = 64 80 | - `sent_gru_hidden_dim` = 64 81 | - `word_gru_num_layers` = 1 82 | - `sent_gru_num_layers` = 1 83 | - `dropout` = 0.2 84 | - `lr (learning rate)` = 0.001 85 | 86 | ### Data Preprocessing 87 | 88 | 1. run the command to prepare data for file-level model training. The output will be stored in `./datasets/preprocessed_data` 89 | 90 | python preprocess_data.py 91 | 92 | 2. run the command to prepare data for line-level baseline. 
The output will be stored in `./datasets/ErrorProne_data/` (for ErrorProne), and `./datasets/n_gram_data/` (for n-gram) 93 | 94 | python export_data_for_line_level_baseline.py 95 | 96 | ### Word2Vec Model Training 97 | 98 | To train Word2Vec models, run the following command: 99 | 100 | python train_word2vec.py <DATASET_NAME> 101 | 102 | Where `<DATASET_NAME>` is one of the following: `activemq`, `camel`, `derby`, `groovy`, `hbase`, `hive`, `jruby`, `lucene`, `wicket` 103 | 104 | ### DeepLineDP Model Training and Prediction Generation 105 | 106 | To train DeepLineDP models, run the following command: 107 | 108 | python train_model.py -dataset <DATASET_NAME> 109 | 110 | 111 | The trained models will be saved in `./output/model/DeepLineDP/<DATASET_NAME>/`, and the loss will be saved in `../output/loss/DeepLineDP/<DATASET_NAME>-loss_record.csv` 112 | 113 | To make a prediction for each software release, run the following command: 114 | 115 | python generate_prediction.py -dataset <DATASET_NAME> 116 | 117 | The generated output is a CSV file which contains the following information: 118 | 119 | - `project`: A software project, as specified by `<DATASET_NAME>` 120 | - `train`: The software release that is used to train DeepLineDP models 121 | - `test`: The software release that is used to make a prediction 122 | - `filename`: The file name of the source code file 123 | - `file-level-ground-truth`: A label indicating whether the file is clean or defective 124 | - `prediction-prob`: The probability that the file is defective 125 | - `prediction-label`: A prediction indicating whether the file is clean or defective 126 | - `line-number`: The line number in the source code file 127 | - `line-level-ground-truth`: A label indicating whether the line is modified 128 | - `is-comment-line`: A flag indicating whether the line is a comment 129 | - `token`: A token in a code line 130 | - `token-attention-score`: The attention score of the token 131 | 132 | The generated output is stored in `./output/prediction/DeepLineDP/within-release/` 133 | 134 | 135 | To make predictions across software projects, run the following command: 136 | 137 | python generate_prediction_cross_projects.py -dataset <DATASET_NAME> 138 | 139 | The generated output is a CSV file which has the same information as above, and is stored in `./output/prediction/DeepLineDP/cross-project/` 140 | 141 | ### File-level Baseline Implementation 142 | 143 | There are 4 baselines in the experiment (i.e., `Bi-LSTM`, `CNN`, `DBN` and `BoW`). To train the file-level baselines, go to `./script/file-level-baseline/` then run the following commands: 144 | 145 | - `python Bi-LSTM-baseline.py -data <DATASET_NAME> -train` 146 | - `python CNN-baseline.py -data <DATASET_NAME> -train` 147 | - `python DBN-baseline.py -data <DATASET_NAME> -train` 148 | - `python BoW-baseline.py -data <DATASET_NAME> -train` 149 | 150 | The trained models will be saved in `./output/model/<BASELINE>/<DATASET_NAME>/`, and the loss will be saved in `../output/loss/<BASELINE>/<DATASET_NAME>-loss_record.csv` 151 | 152 | where `<BASELINE>` is one of the following: `Bi-LSTM`, `CNN`, `DBN` or `BoW`.
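Note that the baseline scripts themselves define a `-dataset` flag (the shorter `-data` shown above works because argparse accepts unambiguous prefixes), and the deep-learning baselines expose a few extra options (e.g., `-epochs` for Bi-LSTM and CNN, `-exp_name` for Bi-LSTM, CNN and DBN). As an illustrative example, a training run for ActiveMQ might look like this:

    cd script/file-level-baseline

    # Bi-LSTM baseline with its default budget of 40 epochs (a checkpoint is saved every epoch)
    python Bi-LSTM-baseline.py -dataset activemq -train

    # the same pattern applies to the other baselines, e.g. CNN with an explicit epoch budget
    python CNN-baseline.py -dataset activemq -epochs 30 -train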
153 | 154 | To make a prediction, run the following command: 155 | - `python Bi-LSTM-baseline.py -data -predict -target_epochs 6` 156 | - `python CNN-baseline.py -data -predict -target_epochs 6` 157 | - `python DBN-baseline.py -data -predict` 158 | - `python BoW-baseline.py -data -predict` 159 | 160 | The generated output is a csv file which contains the following information: 161 | 162 | - `project`: A software project, as specified by \ 163 | - `train`: A software release that is used to train DeepLineDP models 164 | - `test`: A software release that is used to make a prediction 165 | - `filename`: A file name of source code 166 | - `file-level-ground-truth`: A label indicating whether source code is clean or defective 167 | - `prediction-prob`: A probability of being a defective file 168 | - `prediction-label`: A prediction indicating whether source code is clean or defective 169 | 170 | The generated output is stored in `./output/prediction//` 171 | 172 | ### Line-level Baseline Implementation 173 | 174 | There are 2 baselines in this experiment (i.e., `N-gram` and `ErrorProne`). 175 | 176 | To obtain the result from `N-gram`, go to `/script/line-level-baseline/ngram/` and run code in `n_gram.java`. The result will be stored in `/n_gram_result/` directory. After all results are obtained, copy the `/n_gram_result/` directory to the `/output/` directory. 177 | 178 | To obtain the result from `ErrorProne`, go to `/script/line-level-baseline/ErrorProne/` and run code in `run_ErrorProne.ipynb`. The result will be stored in `/ErrorProne_result/` directory. After all results are obtained, copy the `/ErrorProne_result/` directory to the `/output/` directory. 179 | 180 | ### Obtaining the Evaluation Result 181 | 182 | Run `get_evaluation_result.R` to get the result of RQ1-RQ4 (may run in IDE or by the following command) 183 | 184 | Rscript get_evaluation_result.R 185 | 186 | The results are figures that are stored in `./output/figures/` 187 | -------------------------------------------------------------------------------- /requirements.yml: -------------------------------------------------------------------------------- 1 | name: DeepLineDP_env 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=4.5=1_gnu 7 | - ca-certificates=2021.7.5=h06a4308_1 8 | - certifi=2021.5.30=py37h06a4308_0 9 | - ld_impl_linux-64=2.35.1=h7274673_9 10 | - libffi=3.3=he6710b0_2 11 | - libgcc-ng=9.3.0=h5101ec6_17 12 | - libgomp=9.3.0=h5101ec6_17 13 | - libstdcxx-ng=9.3.0=hd4cf53a_17 14 | - ncurses=6.2=he6710b0_1 15 | - openssl=1.1.1l=h7f8727e_0 16 | - pip=21.0.1=py37h06a4308_0 17 | - python=3.7.11=h12debd9_0 18 | - readline=8.1=h27cfd23_0 19 | - setuptools=58.0.4=py37h06a4308_0 20 | - sqlite=3.36.0=hc218d9a_0 21 | - tk=8.6.11=h1ccaba5_0 22 | - wheel=0.37.0=pyhd3eb1b0_1 23 | - xz=5.2.5=h7b6447c_0 24 | - zlib=1.2.11=h7b6447c_3 25 | - pip: 26 | - gensim==3.8.3 27 | - joblib==1.0.1 28 | - more-itertools==8.10.0 29 | - numpy==1.21.2 30 | - pandas==1.3.3 31 | - pillow==8.3.2 32 | - python-dateutil==2.8.2 33 | - pytz==2021.3 34 | - scikit-learn==1.0 35 | - scipy==1.7.1 36 | - six==1.16.0 37 | - smart-open==5.2.1 38 | - threadpoolctl==3.0.0 39 | - tqdm==4.62.3 40 | - typing-extensions==3.10.0.2 41 | prefix: /home/oathaha/.conda/envs/DeepLineDP_env 42 | -------------------------------------------------------------------------------- /script/DeepLineDP_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 
| from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence 4 | 5 | 6 | # Model structure 7 | class HierarchicalAttentionNetwork(nn.Module): 8 | def __init__(self, vocab_size, embed_dim, word_gru_hidden_dim, sent_gru_hidden_dim, word_gru_num_layers, sent_gru_num_layers, word_att_dim, sent_att_dim, use_layer_norm, dropout): 9 | """ 10 | vocab_size: number of words in the vocabulary of the model 11 | embed_dim: dimension of word embeddings 12 | word_gru_hidden_dim: dimension of word-level GRU; biGRU output is double this size 13 | sent_gru_hidden_dim: dimension of sentence-level GRU; biGRU output is double this size 14 | word_gru_num_layers: number of layers in word-level GRU 15 | sent_gru_num_layers: number of layers in sentence-level GRU 16 | word_att_dim: dimension of word-level attention layer 17 | sent_att_dim: dimension of sentence-level attention layer 18 | use_layer_norm: whether to use layer normalization 19 | dropout: dropout rate; 0 to not use dropout 20 | """ 21 | super(HierarchicalAttentionNetwork, self).__init__() 22 | 23 | self.sent_attention = SentenceAttention( 24 | vocab_size, embed_dim, word_gru_hidden_dim, sent_gru_hidden_dim, 25 | word_gru_num_layers, sent_gru_num_layers, word_att_dim, sent_att_dim, use_layer_norm, dropout) 26 | 27 | self.fc = nn.Linear(2 * sent_gru_hidden_dim, 1) 28 | self.sig = nn.Sigmoid() 29 | 30 | self.use_layer_nome = use_layer_norm 31 | self.dropout = dropout 32 | 33 | def forward(self, code_tensor): 34 | 35 | code_lengths = [] 36 | sent_lengths = [] 37 | 38 | for file in code_tensor: 39 | code_line = [] 40 | code_lengths.append(len(file)) 41 | for line in file: 42 | code_line.append(len(line)) 43 | sent_lengths.append(code_line) 44 | 45 | code_tensor = code_tensor.type(torch.LongTensor) 46 | code_lengths = torch.tensor(code_lengths).type(torch.LongTensor).cuda() 47 | sent_lengths = torch.tensor(sent_lengths).type(torch.LongTensor).cuda() 48 | 49 | code_embeds, word_att_weights, sent_att_weights, sents = self.sent_attention(code_tensor, code_lengths, sent_lengths) 50 | 51 | scores = self.fc(code_embeds) 52 | final_scrs = self.sig(scores) 53 | 54 | return final_scrs, word_att_weights, sent_att_weights, sents 55 | 56 | class SentenceAttention(nn.Module): 57 | """ 58 | Sentence-level attention module. Contains a word-level attention module. 
59 | """ 60 | def __init__(self, vocab_size, embed_dim, word_gru_hidden_dim, sent_gru_hidden_dim, 61 | word_gru_num_layers, sent_gru_num_layers, word_att_dim, sent_att_dim, use_layer_norm, dropout): 62 | super(SentenceAttention, self).__init__() 63 | 64 | # Word-level attention module 65 | self.word_attention = WordAttention(vocab_size, embed_dim, word_gru_hidden_dim, word_gru_num_layers, word_att_dim, use_layer_norm, dropout) 66 | 67 | # Bidirectional sentence-level GRU 68 | self.gru = nn.GRU(2 * word_gru_hidden_dim, sent_gru_hidden_dim, num_layers=sent_gru_num_layers, 69 | batch_first=True, bidirectional=True, dropout=dropout) 70 | 71 | self.use_layer_norm = use_layer_norm 72 | if use_layer_norm: 73 | self.layer_norm = nn.LayerNorm(2 * sent_gru_hidden_dim, elementwise_affine=True) 74 | self.dropout = nn.Dropout(dropout) 75 | 76 | # Sentence-level attention 77 | self.sent_attention = nn.Linear(2 * sent_gru_hidden_dim, sent_att_dim) 78 | 79 | # Sentence context vector u_s to take dot product with 80 | # This is equivalent to taking that dot product (Eq.10 in the paper), 81 | # as u_s is the linear layer's 1D parameter vector here 82 | self.sentence_context_vector = nn.Linear(sent_att_dim, 1, bias=False) 83 | 84 | def forward(self, code_tensor, code_lengths, sent_lengths): 85 | 86 | # Sort code_tensor by decreasing order in length 87 | code_lengths, code_perm_idx = code_lengths.sort(dim=0, descending=True) 88 | code_tensor = code_tensor[code_perm_idx] 89 | sent_lengths = sent_lengths[code_perm_idx] 90 | 91 | # Make a long batch of sentences by removing pad-sentences 92 | # i.e. `code_tensor` was of size (num_code_tensor, padded_code_lengths, padded_sent_length) 93 | # -> `packed_sents.data` is now of size (num_sents, padded_sent_length) 94 | packed_sents = pack_padded_sequence(code_tensor, lengths=code_lengths.tolist(), batch_first=True) 95 | 96 | # effective batch size at each timestep 97 | valid_bsz = packed_sents.batch_sizes 98 | 99 | # Make a long batch of sentence lengths by removing pad-sentences 100 | # i.e. 
`sent_lengths` was of size (num_code_tensor, padded_code_lengths) 101 | # -> `packed_sent_lengths.data` is now of size (num_sents) 102 | packed_sent_lengths = pack_padded_sequence(sent_lengths, lengths=code_lengths.tolist(), batch_first=True) 103 | 104 | 105 | 106 | # Word attention module 107 | sents, word_att_weights = self.word_attention(packed_sents.data, packed_sent_lengths.data) 108 | 109 | sents = self.dropout(sents) 110 | 111 | # Sentence-level GRU over sentence embeddings 112 | packed_sents, _ = self.gru(PackedSequence(sents, valid_bsz)) 113 | 114 | if self.use_layer_norm: 115 | normed_sents = self.layer_norm(packed_sents.data) 116 | else: 117 | normed_sents = packed_sents 118 | 119 | # Sentence attention 120 | att = torch.tanh(self.sent_attention(normed_sents)) 121 | att = self.sentence_context_vector(att).squeeze(1) 122 | 123 | val = att.max() 124 | att = torch.exp(att - val) 125 | 126 | # Restore as documents by repadding 127 | att, _ = pad_packed_sequence(PackedSequence(att, valid_bsz), batch_first=True) 128 | 129 | sent_att_weights = att / torch.sum(att, dim=1, keepdim=True) 130 | 131 | # Restore as documents by repadding 132 | code_tensor, _ = pad_packed_sequence(packed_sents, batch_first=True) 133 | 134 | # Compute document vectors 135 | code_tensor = code_tensor * sent_att_weights.unsqueeze(2) 136 | code_tensor = code_tensor.sum(dim=1) 137 | 138 | # Restore as documents by repadding 139 | word_att_weights, _ = pad_packed_sequence(PackedSequence(word_att_weights, valid_bsz), batch_first=True) 140 | 141 | # Restore the original order of documents (undo the first sorting) 142 | _, code_tensor_unperm_idx = code_perm_idx.sort(dim=0, descending=False) 143 | code_tensor = code_tensor[code_tensor_unperm_idx] 144 | 145 | word_att_weights = word_att_weights[code_tensor_unperm_idx] 146 | sent_att_weights = sent_att_weights[code_tensor_unperm_idx] 147 | 148 | return code_tensor, word_att_weights, sent_att_weights, sents 149 | 150 | 151 | class WordAttention(nn.Module): 152 | """ 153 | Word-level attention module. 154 | """ 155 | 156 | def __init__(self, vocab_size, embed_dim, gru_hidden_dim, gru_num_layers, att_dim, use_layer_norm, dropout): 157 | super(WordAttention, self).__init__() 158 | 159 | self.embeddings = nn.Embedding(vocab_size, embed_dim) 160 | 161 | # output (batch, hidden_size) 162 | self.gru = nn.GRU(embed_dim, gru_hidden_dim, num_layers=gru_num_layers, batch_first=True, bidirectional=True, dropout=dropout) 163 | 164 | self.use_layer_norm = use_layer_norm 165 | if use_layer_norm: 166 | self.layer_norm = nn.LayerNorm(2 * gru_hidden_dim, elementwise_affine=True) 167 | self.dropout = nn.Dropout(dropout) 168 | 169 | # Maps gru output to `att_dim` sized tensor 170 | self.attention = nn.Linear(2 * gru_hidden_dim, att_dim) 171 | 172 | # Word context vector (u_w) to take dot-product with 173 | self.context_vector = nn.Linear(att_dim, 1, bias=False) 174 | 175 | def init_embeddings(self, embeddings): 176 | """ 177 | Initialized embedding layer with pretrained embeddings. 178 | embeddings: embeddings to init with 179 | """ 180 | self.embeddings.weight = nn.Parameter(embeddings) 181 | 182 | def freeze_embeddings(self, freeze=False): 183 | """ 184 | Set whether to freeze pretrained embeddings. 
185 | """ 186 | self.embeddings.weight.requires_grad = freeze 187 | 188 | def forward(self, sents, sent_lengths): 189 | """ 190 | sents: encoded sentence-level data; LongTensor (num_sents, pad_len, embed_dim) 191 | return: sentence embeddings, attention weights of words 192 | """ 193 | # Sort sents by decreasing order in sentence lengths 194 | sent_lengths, sent_perm_idx = sent_lengths.sort(dim=0, descending=True) 195 | sents = sents[sent_perm_idx] 196 | 197 | sents = self.embeddings(sents.cuda()) 198 | 199 | packed_words = pack_padded_sequence(sents, lengths=sent_lengths.tolist(), batch_first=True) 200 | 201 | # effective batch size at each timestep 202 | valid_bsz = packed_words.batch_sizes 203 | 204 | # Apply word-level GRU over word embeddings 205 | packed_words, _ = self.gru(packed_words) 206 | 207 | if self.use_layer_norm: 208 | normed_words = self.layer_norm(packed_words.data) 209 | else: 210 | normed_words = packed_words 211 | 212 | # Word Attenton 213 | att = torch.tanh(self.attention(normed_words.data)) 214 | att = self.context_vector(att).squeeze(1) 215 | 216 | val = att.max() 217 | att = torch.exp(att - val) # att.size: (n_words) 218 | 219 | # Restore as sentences by repadding 220 | att, _ = pad_packed_sequence(PackedSequence(att, valid_bsz), batch_first=True) 221 | 222 | att_weights = att / torch.sum(att, dim=1, keepdim=True) 223 | 224 | # Restore as sentences by repadding 225 | sents, _ = pad_packed_sequence(packed_words, batch_first=True) 226 | 227 | # Compute sentence vectors 228 | sents = sents * att_weights.unsqueeze(2) 229 | sents = sents.sum(dim=1) 230 | 231 | # Restore the original order of sentences (undo the first sorting) 232 | _, sent_unperm_idx = sent_perm_idx.sort(dim=0, descending=False) 233 | sents = sents[sent_unperm_idx] 234 | 235 | att_weights = att_weights[sent_unperm_idx] 236 | 237 | return sents, att_weights -------------------------------------------------------------------------------- /script/export_data_for_line_level_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | from tqdm import tqdm 5 | 6 | from my_util import * 7 | 8 | base_data_dir = '../datasets/preprocessed_data/' 9 | base_original_data_dir = '../datasets/original/File-level/' 10 | 11 | data_for_ngram_dir = '../datasets/n_gram_data/' 12 | data_for_error_prone_dir = '../datasets/ErrorProne_data/' 13 | 14 | proj_names = list(all_train_releases.keys()) 15 | 16 | def export_df_to_files(data_df, code_file_dir, line_file_dir): 17 | 18 | for filename, df in tqdm(data_df.groupby('filename')): 19 | 20 | code_lines = list(df['code_line']) 21 | code_str = '\n'.join(code_lines) 22 | code_str = code_str.lower() 23 | line_num = list(df['line_number']) 24 | line_num = [str(l) for l in line_num] 25 | 26 | code_filename = filename.replace('/','_').replace('.java','')+'.txt' 27 | line_filename = filename.replace('/','_').replace('.java','')+'_line_num.txt' 28 | 29 | with open(code_file_dir+code_filename,'w') as f: 30 | f.write(code_str) 31 | 32 | with open(line_file_dir+line_filename, 'w') as f: 33 | f.write('\n'.join(line_num)) 34 | 35 | def export_ngram_data_each_release(release, is_train = False): 36 | 37 | file_dir = data_for_ngram_dir+release+'/' 38 | file_src_dir = file_dir+'src/' 39 | file_line_num_dir = file_dir+'line_num/' 40 | 41 | if not os.path.exists(file_src_dir): 42 | os.makedirs(file_src_dir) 43 | 44 | if not os.path.exists(file_line_num_dir): 45 | os.makedirs(file_line_num_dir) 46 | 47 | data_df = 
pd.read_csv(base_data_dir+release+'.csv', encoding='latin') 48 | 49 | # get clean files for training only 50 | if is_train: 51 | data_df = data_df[(data_df['is_test_file']==False) & (data_df['is_blank']==False) & (data_df['file-label']==False)] 52 | # get defective files for prediction only 53 | else: 54 | data_df = data_df[(data_df['is_test_file']==False) & (data_df['is_blank']==False) & (data_df['file-label']==True)] 55 | 56 | data_df = data_df.fillna('') 57 | 58 | export_df_to_files(data_df, file_src_dir, file_line_num_dir) 59 | 60 | def export_data_all_releases(proj_name): 61 | train_rel = all_train_releases[proj_name] 62 | eval_rels = all_eval_releases[proj_name] 63 | 64 | export_ngram_data_each_release(train_rel, True) 65 | 66 | for rel in eval_rels: 67 | export_ngram_data_each_release(rel, False) 68 | # break 69 | 70 | def export_ngram_data_all_projs(): 71 | for proj in proj_names: 72 | export_data_all_releases(proj) 73 | print('finish',proj) 74 | 75 | def export_errorprone_data(proj_name): 76 | cur_eval_rels = all_eval_releases[proj_name][1:] 77 | 78 | for rel in cur_eval_rels: 79 | 80 | save_dir = data_for_error_prone_dir+rel+'/' 81 | 82 | if not os.path.exists(save_dir): 83 | os.makedirs(save_dir) 84 | 85 | data_df = pd.read_csv(base_original_data_dir+rel+'_ground-truth-files_dataset.csv', encoding='latin') 86 | 87 | data_df = data_df[data_df['Bug']==True] 88 | 89 | for filename, df in data_df.groupby('File'): 90 | 91 | if 'test' in filename or '.java' not in filename: 92 | continue 93 | 94 | filename = filename.replace('/','_') 95 | 96 | code = list(df['SRC'])[0].strip() 97 | 98 | with open(save_dir+filename,'w') as f: 99 | f.write(code) 100 | 101 | print('finish release',rel) 102 | 103 | 104 | def export_error_prone_data_all_projs(): 105 | for proj in proj_names: 106 | export_errorprone_data(proj) 107 | print('finish',proj) 108 | 109 | export_ngram_data_all_projs() 110 | export_error_prone_data_all_projs() -------------------------------------------------------------------------------- /script/file-level-baseline/Bi-LSTM-baseline.py: -------------------------------------------------------------------------------- 1 | import os, sys, argparse, re 2 | 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch 6 | from torch.utils.data import DataLoader, TensorDataset 7 | from torch.autograd import Variable 8 | 9 | import numpy as np 10 | import pandas as pd 11 | from gensim.models import Word2Vec 12 | 13 | from tqdm import tqdm 14 | 15 | from baseline_util import * 16 | 17 | # for importing file from previous directory 18 | sys.path.append('../') 19 | 20 | from my_util import * 21 | 22 | arg = argparse.ArgumentParser() 23 | arg.add_argument('-dataset',type=str, default='activemq', help='software project name (lowercase)') 24 | arg.add_argument('-epochs', type=int, default=40) 25 | arg.add_argument('-target_epochs', type=str, default='6', help='which epoch of model to load') 26 | arg.add_argument('-exp_name',type=str,default='') 27 | arg.add_argument('-train',action='store_true') 28 | arg.add_argument('-predict',action='store_true') 29 | 30 | args = arg.parse_args() 31 | 32 | torch.manual_seed(0) 33 | 34 | # model parameters 35 | 36 | batch_size = 32 37 | embed_dim = 50 38 | hidden_dim = 64 39 | lr = 0.001 40 | epochs = args.epochs 41 | 42 | exp_name = args.exp_name 43 | 44 | save_every_epochs = 1 45 | 46 | max_seq_len = 50 47 | 48 | 49 | save_model_dir = '../../output/model/Bi-LSTM/' 50 | save_prediction_dir = '../../output/prediction/Bi-LSTM/' 51 | 52 | if not 
os.path.exists(save_prediction_dir): 53 | os.makedirs(save_prediction_dir) 54 | 55 | class LSTMClassifier(nn.Module): 56 | def __init__(self, batch_size, hidden_size, vocab_size, embedding_length): 57 | super(LSTMClassifier, self).__init__() 58 | 59 | """ 60 | Arguments 61 | --------- 62 | batch_size : Size of the batch 63 | hidden_sie : Size of the hidden_state of the LSTM 64 | vocab_size : Size of the vocabulary containing unique words 65 | embedding_length : Embeddding dimension word embeddings 66 | """ 67 | 68 | self.batch_size = batch_size 69 | self.hidden_size = hidden_size 70 | self.vocab_size = vocab_size 71 | self.embedding_length = embedding_length 72 | 73 | # Initializing the look-up table. 74 | self.word_embeddings = nn.Embedding(vocab_size, embedding_length) 75 | self.lstm = nn.LSTM(embedding_length, hidden_size,bidirectional=True) 76 | 77 | # dropout layer 78 | self.dropout = nn.Dropout(0.2) 79 | 80 | # linear and sigmoid layer 81 | self.fc = nn.Linear(hidden_dim, 1) 82 | self.sig = nn.Sigmoid() 83 | 84 | def forward(self, input_tensor): 85 | 86 | """ 87 | Parameters 88 | ---------- 89 | input_sentence: input_sentence of shape = (batch_size, num_sequences) 90 | 91 | Returns 92 | ------- 93 | Output of the linear layer containing logits for positive & negative class which receives its input as the final_hidden_state of the LSTM 94 | 95 | """ 96 | 97 | # embedded input of shape = (batch_size, num_sequences, embedding_length) 98 | input = self.word_embeddings(input_tensor.type(torch.LongTensor).cuda()) 99 | 100 | # input.size() = (num_sequences, batch_size, embedding_length) 101 | input = input.permute(1, 0, 2) 102 | h_0 = Variable(torch.zeros(2, self.batch_size, self.hidden_size).cuda()) # Initialize hidden state of the LSTM 103 | c_0 = Variable(torch.zeros(2, self.batch_size, self.hidden_size).cuda()) # Initialize cell state of the LSTM 104 | 105 | lstm_out, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0)) 106 | 107 | output = self.fc(final_hidden_state[-1]) # the last hidden state is output of lstm model 108 | 109 | sig_out = self.sig(output) 110 | 111 | return sig_out 112 | 113 | def train_model(dataset_name): 114 | 115 | loss_dir = '../../output/loss/Bi-LSTM/' 116 | actual_save_model_dir = save_model_dir+dataset_name+'/' 117 | 118 | if not exp_name == '': 119 | actual_save_model_dir = actual_save_model_dir+exp_name+'/' 120 | loss_dir = loss_dir + exp_name 121 | 122 | if not os.path.exists(actual_save_model_dir): 123 | os.makedirs(actual_save_model_dir) 124 | 125 | if not os.path.exists(loss_dir): 126 | os.makedirs(loss_dir) 127 | 128 | w2v_dir = get_w2v_path() 129 | w2v_dir = os.path.join('../'+w2v_dir,dataset_name+'-'+str(embed_dim)+'dim.bin') 130 | 131 | train_rel = all_train_releases[dataset_name] 132 | valid_rel = all_eval_releases[dataset_name][0] 133 | 134 | train_df = get_df(train_rel, is_baseline=True) 135 | 136 | valid_df = get_df(valid_rel, is_baseline=True) 137 | 138 | train_code, train_label = prepare_data(train_df, to_lowercase = True) 139 | valid_code, valid_label = prepare_data(valid_df, to_lowercase = True) 140 | 141 | word2vec_model = Word2Vec.load(w2v_dir) 142 | 143 | padding_idx = word2vec_model.wv.vocab[''].index 144 | 145 | vocab_size = len(word2vec_model.wv.vocab)+1 146 | 147 | train_dl = get_dataloader(word2vec_model, train_code,train_label, padding_idx, batch_size) 148 | valid_dl = get_dataloader(word2vec_model, valid_code,valid_label, padding_idx, batch_size) 149 | 150 | net = LSTMClassifier(batch_size, hidden_dim, vocab_size, 
embed_dim) 151 | 152 | net = net.cuda() 153 | 154 | optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, net.parameters()), lr=lr) 155 | criterion = nn.BCELoss() 156 | 157 | checkpoint_files = os.listdir(actual_save_model_dir) 158 | 159 | if '.ipynb_checkpoints' in checkpoint_files: 160 | checkpoint_files.remove('.ipynb_checkpoints') 161 | 162 | total_checkpoints = len(checkpoint_files) 163 | 164 | # no model is trained 165 | if total_checkpoints == 0: 166 | 167 | current_checkpoint_num = 1 168 | 169 | train_loss_all_epochs = [] 170 | val_loss_all_epochs = [] 171 | 172 | else: 173 | checkpoint_nums = [int(re.findall('\d+',s)[0]) for s in checkpoint_files] 174 | current_checkpoint_num = max(checkpoint_nums) 175 | 176 | checkpoint = torch.load(actual_save_model_dir+'checkpoint_'+str(current_checkpoint_num)+'epochs.pth') 177 | net.load_state_dict(checkpoint['model_state_dict']) 178 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 179 | 180 | loss_df = pd.read_csv(loss_dir+dataset_name+'-Bi-LSTM-loss_record.csv') 181 | train_loss_all_epochs = list(loss_df['train_loss']) 182 | val_loss_all_epochs = list(loss_df['valid_loss']) 183 | 184 | current_checkpoint_num = current_checkpoint_num+1 # go to next epoch 185 | 186 | print('train model from epoch',current_checkpoint_num) 187 | 188 | clip=5 # gradient clipping 189 | 190 | print('training model of',dataset_name) 191 | 192 | for e in tqdm(range(current_checkpoint_num,epochs+1)): 193 | 194 | train_losses = [] 195 | val_losses = [] 196 | 197 | net.train() 198 | 199 | for inputs, labels in train_dl: 200 | 201 | inputs, labels = inputs.cuda(), labels.cuda() 202 | 203 | net.zero_grad() 204 | 205 | output = net(inputs) 206 | 207 | # calculate the loss and perform backprop 208 | loss = criterion(output, labels.reshape(-1,1).float()) 209 | 210 | train_losses.append(loss.item()) 211 | 212 | loss.backward() 213 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 
214 | nn.utils.clip_grad_norm_(net.parameters(), clip) 215 | optimizer.step() 216 | 217 | train_loss_all_epochs.append(np.mean(train_losses)) 218 | 219 | with torch.no_grad(): 220 | 221 | net.eval() 222 | 223 | for inputs, labels in valid_dl: 224 | 225 | inputs, labels = inputs.cuda(), labels.cuda() 226 | output = net(inputs) 227 | 228 | val_loss = criterion(output, labels.reshape(batch_size,1).float()) 229 | 230 | val_losses.append(val_loss.item()) 231 | 232 | val_loss_all_epochs.append(np.mean(val_losses)) 233 | 234 | if e % save_every_epochs == 0: 235 | torch.save({ 236 | 'epoch': e, 237 | 'model_state_dict': net.state_dict(), 238 | 'optimizer_state_dict': optimizer.state_dict() 239 | }, 240 | actual_save_model_dir+'checkpoint_'+str(e)+'epochs.pth') 241 | 242 | loss_df = pd.DataFrame() 243 | loss_df['epoch'] = np.arange(1,len(train_loss_all_epochs)+1) 244 | loss_df['train_loss'] = train_loss_all_epochs 245 | loss_df['valid_loss'] = val_loss_all_epochs 246 | 247 | loss_df.to_csv(loss_dir+dataset_name+'-Bi-LSTM-loss_record.csv',index=False) 248 | 249 | print('finished training model of',dataset_name) 250 | 251 | 252 | # target_epochs (int): which epoch to load model 253 | def predict_defective_files_in_releases(dataset_name, target_epochs = 6): 254 | actual_save_model_dir = save_model_dir+dataset_name+'/' 255 | 256 | w2v_dir = get_w2v_path() 257 | w2v_dir = os.path.join('../'+w2v_dir,dataset_name+'-'+str(embed_dim)+'dim.bin') 258 | 259 | train_rel = all_train_releases[dataset_name] 260 | eval_rels = all_eval_releases[dataset_name][1:] 261 | 262 | word2vec_model = Word2Vec.load(w2v_dir) 263 | 264 | vocab_size = len(word2vec_model.wv.vocab) + 1 265 | 266 | net = LSTMClassifier(1, hidden_dim, vocab_size, embed_dim) 267 | 268 | checkpoint = torch.load(actual_save_model_dir+'checkpoint_'+target_epochs+'epochs.pth') 269 | 270 | net.load_state_dict(checkpoint['model_state_dict']) 271 | 272 | net = net.cuda() 273 | 274 | net.eval() 275 | 276 | 277 | for rel in eval_rels: 278 | row_list = [] 279 | 280 | test_df = get_df(rel, is_baseline=True) 281 | 282 | for filename, df in tqdm(test_df.groupby('filename')): 283 | 284 | file_label = bool(df['file-label'].unique()) 285 | 286 | code = list(df['code_line']) 287 | 288 | code_str = get_code_str(code, True) 289 | code_list = [code_str] 290 | 291 | code_vec = get_code_vec(code_list, word2vec_model) 292 | 293 | code_tensor = torch.tensor(code_vec) 294 | 295 | output = net(code_tensor) 296 | file_prob = output.item() 297 | prediction = bool(round(file_prob)) 298 | 299 | row_dict = { 300 | 'project': dataset_name, 301 | 'train': train_rel, 302 | 'test': rel, 303 | 'filename': filename, 304 | 'file-level-ground-truth': file_label, 305 | 'prediction-prob': file_prob, 306 | 'prediction-label': prediction 307 | } 308 | 309 | row_list.append(row_dict) 310 | 311 | 312 | df = pd.DataFrame(row_list) 313 | df.to_csv(save_prediction_dir+rel+'-'+target_epochs+'-epochs.csv', index=False) 314 | 315 | print('finished',rel) 316 | 317 | proj_name = args.dataset 318 | 319 | if args.train: 320 | train_model(proj_name) 321 | 322 | if args.predict: 323 | target_epochs = args.target_epochs 324 | predict_defective_files_in_releases(proj_name, target_epochs) 325 | -------------------------------------------------------------------------------- /script/file-level-baseline/BoW-baseline.py: -------------------------------------------------------------------------------- 1 | import re, os, pickle, warnings, sys, argparse 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from 
sklearn.feature_extraction.text import CountVectorizer 7 | from sklearn.linear_model import LogisticRegression 8 | 9 | from baseline_util import * 10 | 11 | sys.path.append('../') 12 | 13 | from my_util import * 14 | 15 | from imblearn.over_sampling import SMOTE 16 | 17 | warnings.filterwarnings('ignore') 18 | 19 | arg = argparse.ArgumentParser() 20 | arg.add_argument('-dataset',type=str, default='activemq', help='software project name (lowercase)') 21 | arg.add_argument('-train',action='store_true') 22 | arg.add_argument('-predict',action='store_true') 23 | 24 | args = arg.parse_args() 25 | 26 | save_model_dir = '../../output/model/BoW/' 27 | save_prediction_dir = '../../output/prediction/BoW/' 28 | 29 | if not os.path.exists(save_model_dir): 30 | os.makedirs(save_model_dir) 31 | 32 | if not os.path.exists(save_prediction_dir): 33 | os.makedirs(save_prediction_dir) 34 | 35 | 36 | # train_release is str 37 | def train_model(dataset_name): 38 | train_rel = all_train_releases[dataset_name] 39 | train_df = get_df(train_rel, is_baseline=True) 40 | 41 | train_code, train_label = prepare_data(train_df, True) 42 | 43 | vectorizer = CountVectorizer() 44 | vectorizer.fit(train_code) 45 | X = vectorizer.transform(train_code).toarray() 46 | train_feature = pd.DataFrame(X) 47 | Y = np.array([1 if label == True else 0 for label in train_label]) 48 | 49 | sm = SMOTE(random_state=42) 50 | X_res, y_res = sm.fit_resample(train_feature, Y) 51 | 52 | clf = LogisticRegression(solver='liblinear') 53 | clf.fit(X_res, y_res) 54 | 55 | pickle.dump(clf,open(save_model_dir+re.sub('-.*','',train_rel)+"-BoW-model.bin",'wb')) 56 | pickle.dump(vectorizer,open(save_model_dir+re.sub('-.*','',train_rel)+"-vectorizer.bin",'wb')) 57 | 58 | print('finished training model for',dataset_name) 59 | 60 | count_vec_df = pd.DataFrame(X) 61 | count_vec_df.columns = vectorizer.get_feature_names() 62 | 63 | 64 | # test_release is str 65 | def predict_defective_files_in_releases(dataset_name): 66 | train_release = all_train_releases[dataset_name] 67 | eval_releases = all_eval_releases[dataset_name][1:] 68 | 69 | clf = pickle.load(open(save_model_dir+re.sub('-.*','',train_release)+"-BoW-model.bin",'rb')) 70 | vectorizer = pickle.load(open(save_model_dir+re.sub('-.*','',train_release)+"-vectorizer.bin",'rb')) 71 | 72 | for rel in eval_releases: 73 | 74 | test_df = get_df(rel,is_baseline=True) 75 | 76 | test_code, train_label = prepare_data(test_df, True) 77 | 78 | X = vectorizer.transform(test_code).toarray() 79 | 80 | Y_pred = list(map(bool,list(clf.predict(X)))) 81 | Y_prob = clf.predict_proba(X) 82 | Y_prob = list(Y_prob[:,1]) 83 | 84 | result_df = pd.DataFrame() 85 | result_df['project'] = [dataset_name]*len(Y_pred) 86 | result_df['train'] = [train_release]*len(Y_pred) 87 | result_df['test'] = [rel]*len(Y_pred) 88 | result_df['file-level-ground-truth'] = train_label 89 | result_df['prediction-prob'] = Y_prob 90 | result_df['prediction-label'] = Y_pred 91 | 92 | result_df.to_csv(save_prediction_dir+rel+'.csv', index=False) 93 | 94 | print('finish',rel) 95 | 96 | proj_name = args.dataset 97 | 98 | if args.train: 99 | train_model(proj_name) 100 | 101 | if args.predict: 102 | predict_defective_files_in_releases(proj_name) 103 | -------------------------------------------------------------------------------- /script/file-level-baseline/CNN-baseline.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.optim as optim 4 | from torch.nn import 
functional as F 5 | 6 | import pandas as pd 7 | import os, sys ,argparse, re 8 | 9 | from gensim.models import Word2Vec 10 | 11 | from tqdm import tqdm 12 | 13 | from baseline_util import * 14 | 15 | sys.path.append('../') 16 | 17 | from my_util import * 18 | 19 | arg = argparse.ArgumentParser() 20 | arg.add_argument('-dataset',type=str, default='activemq', help='software project name (lowercase)') 21 | arg.add_argument('-epochs', type=int, default=30) 22 | arg.add_argument('-target_epochs', type=str, default='6', help='which epoch of model to load') 23 | arg.add_argument('-exp_name',type=str,default='') 24 | arg.add_argument('-train',action='store_true') 25 | arg.add_argument('-predict',action='store_true') 26 | 27 | args = arg.parse_args() 28 | 29 | 30 | # model parameters 31 | 32 | batch_size = 32 33 | embed_dim = 50 34 | n_filters = 100 # default is 100 35 | lr = 0.001 36 | 37 | epochs = args.epochs 38 | save_every_epochs = 1 39 | max_seq_len = 50 # number of tokens of each line in a file 40 | max_train_LOC = 900 # max length of all code in the whole dataset 41 | 42 | exp_name = args.exp_name 43 | 44 | save_model_dir = '../../output/model/CNN/' 45 | save_prediction_dir = '../../output/prediction/CNN/' 46 | 47 | if not os.path.exists(save_prediction_dir): 48 | os.makedirs(save_prediction_dir) 49 | 50 | class CNN(nn.Module): 51 | def __init__(self, batch_size, in_channels, out_channels, keep_probab, vocab_size, embedding_dim): 52 | super(CNN, self).__init__() 53 | ''' 54 | Arguments 55 | --------- 56 | batch_size : Size of each batch 57 | in_channels : Number of input channels. Here it is 1 as the input data has dimension = (batch_size, num_seq, embedding_length) 58 | out_channels : Number of output channels after convolution operation performed on the input matrix 59 | keep_probab : Probability of retaining an activation node during dropout operation 60 | vocab_size : Size of the vocabulary containing unique words 61 | ''' 62 | self.batch_size = batch_size 63 | self.in_channels = in_channels 64 | self.out_channels = out_channels 65 | self.vocab_size = vocab_size 66 | self.embedding_length = embedding_dim 67 | 68 | self.word_embeddings = nn.Embedding(vocab_size, embedding_dim) 69 | 70 | self.conv = nn.Conv2d(1, 100, (5, embedding_dim)) 71 | 72 | 73 | self.dropout = nn.Dropout(keep_probab) 74 | self.fc = nn.Linear(100, 1) 75 | 76 | self.sig = nn.Sigmoid() 77 | 78 | def conv_block(self, input, conv_layer): 79 | 80 | conv_out = F.relu(conv_layer(input)) 81 | conv_out = torch.squeeze(conv_out,-1) 82 | max_out = F.max_pool1d(conv_out, conv_out.size()[2]) 83 | 84 | return max_out 85 | 86 | def forward(self, input_tensor): 87 | ''' 88 | Parameters 89 | ---------- 90 | input_tensor: input_tensor of shape = (batch_size, num_sequences) 91 | batch_size : default = None. Used only for prediction on a single sentence after training (batch_size = 1) 92 | 93 | Returns 94 | ------- 95 | Output of the linear layer containing logits for pos & neg class. 
96 | 97 | ''' 98 | # input.size() = (batch_size, num_seq, embedding_length) 99 | input = self.word_embeddings(input_tensor.type(torch.LongTensor).cuda()) 100 | 101 | # input.size() = (batch_size, 1, num_seq, embedding_length) 102 | input = input.unsqueeze(1) 103 | 104 | 105 | max_out = self.conv_block(input, self.conv) 106 | max_out = max_out.view(max_out.size(0),-1) 107 | 108 | fc_in = self.dropout(max_out) 109 | 110 | logits = self.fc(fc_in) 111 | sig_out = self.sig(logits) 112 | 113 | return sig_out 114 | 115 | 116 | def train_model(dataset_name): 117 | 118 | loss_dir = '../../output/loss/CNN/' 119 | actual_save_model_dir = save_model_dir+dataset_name+'/' 120 | 121 | if not exp_name == '': 122 | actual_save_model_dir = actual_save_model_dir+exp_name+'/' 123 | loss_dir = loss_dir + exp_name 124 | 125 | if not os.path.exists(actual_save_model_dir): 126 | os.makedirs(actual_save_model_dir) 127 | 128 | if not os.path.exists(loss_dir): 129 | os.makedirs(loss_dir) 130 | 131 | w2v_dir = get_w2v_path() 132 | 133 | w2v_dir = os.path.join('../'+w2v_dir,dataset_name+'-'+str(embed_dim)+'dim.bin') 134 | 135 | train_rel = all_train_releases[dataset_name] 136 | valid_rel = all_eval_releases[dataset_name][0] 137 | 138 | train_df = get_df(train_rel,is_baseline=True) 139 | 140 | valid_df = get_df(valid_rel,is_baseline=True) 141 | 142 | word2vec_model = Word2Vec.load(w2v_dir) 143 | 144 | vocab_size = len(word2vec_model.wv.vocab) + 1 # for unknown tokens 145 | 146 | train_code, train_label = prepare_data(train_df, to_lowercase = True) 147 | valid_code, valid_label = prepare_data(valid_df, to_lowercase = True) 148 | 149 | word2vec_model = Word2Vec.load(w2v_dir) 150 | 151 | padding_idx = word2vec_model.wv.vocab[''].index 152 | 153 | vocab_size = len(word2vec_model.wv.vocab)+1 154 | 155 | train_dl = get_dataloader(word2vec_model, train_code,train_label, padding_idx) 156 | valid_dl = get_dataloader(word2vec_model, valid_code,valid_label, padding_idx) 157 | 158 | net = CNN(batch_size, 1, n_filters, 0.5, vocab_size, embed_dim) 159 | 160 | net = net.cuda() 161 | 162 | optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, net.parameters()), lr=lr) 163 | criterion = nn.BCELoss() 164 | 165 | checkpoint_files = os.listdir(actual_save_model_dir) 166 | 167 | if '.ipynb_checkpoints' in checkpoint_files: 168 | checkpoint_files.remove('.ipynb_checkpoints') 169 | 170 | total_checkpoints = len(checkpoint_files) 171 | 172 | # no model is trained 173 | if total_checkpoints == 0: 174 | 175 | current_checkpoint_num = 1 176 | 177 | train_loss_all_epochs = [] 178 | val_loss_all_epochs = [] 179 | 180 | 181 | else: 182 | checkpoint_nums = [int(re.findall('\d+',s)[0]) for s in checkpoint_files] 183 | current_checkpoint_num = max(checkpoint_nums) 184 | 185 | checkpoint = torch.load(actual_save_model_dir+'checkpoint_'+str(current_checkpoint_num)+'epochs.pth') 186 | net.load_state_dict(checkpoint['model_state_dict']) 187 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 188 | 189 | loss_df = pd.read_csv(loss_dir+dataset_name+'-Bi-LSTM-loss_record.csv') 190 | train_loss_all_epochs = list(loss_df['train_loss']) 191 | val_loss_all_epochs = list(loss_df['valid_loss']) 192 | 193 | current_checkpoint_num = current_checkpoint_num+1 # go to next epoch 194 | 195 | print('train model from epoch',current_checkpoint_num) 196 | 197 | clip=5 # gradient clipping 198 | 199 | print('training model of',dataset_name) 200 | 201 | for e in tqdm(range(current_checkpoint_num,epochs+1)): 202 | train_losses = [] 203 | val_losses = [] 
204 | 205 | net.train() 206 | 207 | # batch loop 208 | for inputs, labels in train_dl: 209 | 210 | inputs, labels = inputs.cuda(), labels.cuda() 211 | net.zero_grad() 212 | 213 | # get the output from the model 214 | output = net(inputs) 215 | 216 | # calculate the loss and perform backprop 217 | loss = criterion(output, labels.reshape(-1,1).float()) 218 | train_losses.append(loss.item()) 219 | 220 | loss.backward() 221 | 222 | # `clip_grad_norm` helps prevent the exploding gradient problem 223 | nn.utils.clip_grad_norm_(net.parameters(), clip) 224 | optimizer.step() 225 | 226 | 227 | train_loss_all_epochs.append(np.mean(train_losses)) 228 | 229 | with torch.no_grad(): 230 | 231 | net.eval() 232 | 233 | for inputs, labels in valid_dl: 234 | 235 | inputs, labels = inputs.cuda(), labels.cuda() 236 | output = net(inputs) 237 | 238 | val_loss = criterion(output, labels.reshape(batch_size,1).float()) 239 | 240 | val_losses.append(val_loss.item()) 241 | 242 | val_loss_all_epochs.append(np.mean(val_losses)) 243 | 244 | if e % save_every_epochs == 0: 245 | torch.save({ 246 | 'epoch': e, 247 | 'model_state_dict': net.state_dict(), 248 | 'optimizer_state_dict': optimizer.state_dict() 249 | }, 250 | actual_save_model_dir+'checkpoint_'+str(e)+'epochs.pth') 251 | 252 | 253 | loss_df = pd.DataFrame() 254 | loss_df['epoch'] = np.arange(1,len(train_loss_all_epochs)+1) 255 | loss_df['train_loss'] = train_loss_all_epochs 256 | loss_df['valid_loss'] = val_loss_all_epochs 257 | 258 | loss_df.to_csv(loss_dir+dataset_name+'-CNN-loss_record.csv',index=False) 259 | 260 | print('finished training model of',dataset_name) 261 | 262 | 263 | # target_epochs (int): which epoch to load model 264 | def predict_defective_files_in_releases(dataset_name, target_epochs = 100): 265 | actual_save_model_dir = save_model_dir+dataset_name+'/' 266 | 267 | w2v_dir = get_w2v_path() 268 | 269 | w2v_dir = os.path.join('../'+w2v_dir,dataset_name+'-'+str(embed_dim)+'dim.bin') 270 | 271 | train_rel = all_train_releases[dataset_name] 272 | eval_rels = all_eval_releases[dataset_name][1:] 273 | 274 | word2vec_model = Word2Vec.load(w2v_dir) 275 | 276 | vocab_size = len(word2vec_model.wv.vocab) + 1 277 | 278 | net = CNN(batch_size, 1, n_filters, 0.5, vocab_size, embed_dim) 279 | 280 | checkpoint = torch.load(actual_save_model_dir+'checkpoint_'+target_epochs+'epochs.pth') 281 | 282 | net.load_state_dict(checkpoint['model_state_dict']) 283 | 284 | net = net.cuda() 285 | 286 | net.eval() 287 | 288 | for rel in eval_rels: 289 | row_list = [] 290 | 291 | test_df = get_df(rel, is_baseline=True) 292 | 293 | for filename, df in tqdm(test_df.groupby('filename')): 294 | 295 | file_label = bool(df['file-label'].unique()) 296 | 297 | code = list(df['code_line']) 298 | 299 | code_str = get_code_str(code, True) 300 | code_list = [code_str] 301 | 302 | code_vec = get_code_vec(code_list, word2vec_model) 303 | 304 | code_tensor = torch.tensor(code_vec) 305 | 306 | output = net(code_tensor) 307 | file_prob = output.item() 308 | prediction = bool(round(file_prob)) 309 | 310 | row_dict = { 311 | 'project': dataset_name, 312 | 'train': train_rel, 313 | 'test': rel, 314 | 'filename': filename, 315 | 'file-level-ground-truth': file_label, 316 | 'prediction-prob': file_prob, 317 | 'prediction-label': prediction 318 | } 319 | 320 | row_list.append(row_dict) 321 | 322 | df = pd.DataFrame(row_list) 323 | df.to_csv(save_prediction_dir+rel+'-'+target_epochs+'-epochs.csv', index=False) 324 | 325 | proj_name = args.dataset 326 | if args.train: 327 | train_model(proj_name) 
328 | 329 | if args.predict: 330 | target_epochs = args.target_epochs 331 | predict_defective_files_in_releases(proj_name, target_epochs) 332 | -------------------------------------------------------------------------------- /script/file-level-baseline/DBN-baseline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os, pickle, sys, argparse 4 | 5 | from sklearn.ensemble import RandomForestClassifier 6 | from sklearn.preprocessing import MinMaxScaler 7 | 8 | 9 | from gensim.models import Word2Vec 10 | 11 | from dbn.models import SupervisedDBNClassification 12 | 13 | from baseline_util import * 14 | 15 | sys.path.append('../') 16 | 17 | from tqdm import tqdm 18 | 19 | from my_util import * 20 | 21 | arg = argparse.ArgumentParser() 22 | arg.add_argument('-dataset',type=str, default='activemq', help='software project name (lowercase)') 23 | arg.add_argument('-exp_name',type=str, default='') 24 | arg.add_argument('-train',action='store_true') 25 | arg.add_argument('-predict',action='store_true') 26 | 27 | args = arg.parse_args() 28 | 29 | 30 | ''' 31 | The model setting is based on the paper "Automatically Learning Semantic Features for Defect Prediction" 32 | code from https://github.com/albertbup/deep-belief-network 33 | ''' 34 | 35 | # model parameter 36 | batch_size = 30 37 | hidden_layers_structure = [100]*10 38 | embed_dim = 50 39 | exp_name = arg.exp_name 40 | 41 | save_model_dir = '../../output/model/DBN/' 42 | save_prediction_dir = '../../output/prediction/DBN/' 43 | 44 | if not os.path.exists(save_prediction_dir): 45 | os.makedirs(save_prediction_dir) 46 | 47 | def convert_to_token_index(w2v_model, code, padding_idx, max_seq_len = None): 48 | codevec = get_code_vec(code, w2v_model) 49 | 50 | if max_seq_len is None: 51 | max_seq_len = min(max([len(cv) for cv in codevec]),45000) 52 | 53 | features = pad_features(codevec, padding_idx, seq_length=max_seq_len) 54 | 55 | return features 56 | 57 | 58 | def train_model(dataset_name): 59 | actual_save_model_dir = save_model_dir+dataset_name+'/' 60 | 61 | if not exp_name == '': 62 | actual_save_model_dir = actual_save_model_dir+exp_name+'/' 63 | 64 | if not os.path.exists(actual_save_model_dir): 65 | os.makedirs(actual_save_model_dir) 66 | 67 | w2v_dir = get_w2v_path() 68 | w2v_dir = os.path.join('../'+w2v_dir,dataset_name+'-'+str(embed_dim)+'dim.bin') 69 | 70 | train_rel = all_train_releases[dataset_name] 71 | 72 | train_df = get_df(train_rel, is_baseline=True) 73 | 74 | train_code, train_label = prepare_data(train_df, to_lowercase = True) 75 | 76 | word2vec_model = Word2Vec.load(w2v_dir) 77 | 78 | padding_idx = word2vec_model.wv.vocab[''].index 79 | 80 | token_idx = convert_to_token_index(word2vec_model, train_code, padding_idx) 81 | 82 | scaler = MinMaxScaler(feature_range=(0,1)) 83 | features = scaler.fit_transform(token_idx) 84 | 85 | dbn_clf = SupervisedDBNClassification(hidden_layers_structure=hidden_layers_structure, 86 | learning_rate_rbm=0.01, 87 | learning_rate=0.1, 88 | n_epochs_rbm=200, # default is 200 89 | n_iter_backprop=200, # default is 200 90 | batch_size=batch_size, 91 | activation_function='sigmoid') 92 | 93 | dbn_clf.fit(features,train_label) 94 | 95 | dbn_features = dbn_clf.transform(features) 96 | 97 | rf_clf = RandomForestClassifier(n_jobs=24) 98 | 99 | rf_clf.fit(dbn_features,train_label) 100 | 101 | pickle.dump(dbn_clf,open(save_model_dir+dataset_name+'-DBN.pkl','wb')) 102 | 
pickle.dump(rf_clf,open(save_model_dir+dataset_name+'-RF.pkl','wb')) 103 | 104 | print('finished training model of',dataset_name) 105 | 106 | # epoch (int): which epoch to load model 107 | def predict_defective_files_in_releases(dataset_name): 108 | 109 | w2v_dir = get_w2v_path() 110 | w2v_dir = os.path.join('../'+w2v_dir,dataset_name+'-'+str(embed_dim)+'dim.bin') 111 | 112 | train_rel = all_train_releases[dataset_name] 113 | eval_rel = all_eval_releases[dataset_name][1:] 114 | 115 | train_df = get_df(train_rel, is_baseline=True) 116 | 117 | train_code, _ = prepare_data(train_df, to_lowercase = True) 118 | 119 | word2vec_model = Word2Vec.load(w2v_dir) 120 | 121 | train_codevec = get_code_vec(train_code, word2vec_model) 122 | 123 | # find max sequence from training data (for later padding) 124 | max_seq_len = min(max([len(cv) for cv in train_codevec]),45000) 125 | 126 | padding_idx = word2vec_model.wv.vocab[''].index 127 | 128 | token_idx = convert_to_token_index(word2vec_model, train_code, padding_idx) 129 | 130 | dbn_clf = pickle.load(open(save_model_dir+dataset_name+'-DBN.pkl','rb')) 131 | rf_clf = pickle.load(open(save_model_dir+dataset_name+'-RF.pkl','rb')) 132 | 133 | scaler = MinMaxScaler(feature_range=(0,1)) 134 | 135 | 136 | scaler.fit(token_idx) 137 | 138 | for rel in eval_rel: 139 | all_rows = [] 140 | 141 | test_df = get_df(rel, is_baseline=True) 142 | 143 | for filename, df in tqdm(test_df.groupby('filename')): 144 | 145 | file_label = bool(df['file-label'].unique()) 146 | 147 | code = list(df['code_line']) 148 | 149 | code_str = get_code_str(code, True) 150 | code_list = [code_str] 151 | 152 | code_vec = get_code_vec(code_list, word2vec_model) 153 | code_vec = pad_features(code_vec,padding_idx, max_seq_len) 154 | 155 | features = scaler.transform(np.array(code_vec[0]).reshape(1,-1)) 156 | 157 | dbn_features = dbn_clf.transform(features) 158 | 159 | y_pred = bool(rf_clf.predict(dbn_features)) 160 | y_prob = rf_clf.predict_proba(dbn_features) 161 | y_prob = float(y_prob[:,1]) 162 | 163 | row_dict = { 164 | 'project': dataset_name, 165 | 'train': train_rel, 166 | 'test': rel, 167 | 'filename': filename, 168 | 'file-level-ground-truth': file_label, 169 | 'prediction-prob': y_prob, 170 | 'prediction-label': y_pred 171 | } 172 | all_rows.append(row_dict) 173 | 174 | df = pd.DataFrame(all_rows) 175 | df.to_csv(save_prediction_dir+rel+'.csv', index=False) 176 | 177 | print('finished release',rel) 178 | 179 | 180 | proj_name = args.dataset 181 | 182 | if args.train: 183 | train_model(proj_name) 184 | 185 | if args.predict: 186 | predict_defective_files_in_releases(proj_name) 187 | -------------------------------------------------------------------------------- /script/file-level-baseline/baseline_util.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from torch.utils.data.dataloader import DataLoader 5 | from torch.utils.data.dataset import TensorDataset 6 | 7 | def get_code_str(code, to_lowercase): 8 | ''' 9 | input 10 | code (list): a list of code lines from dataset 11 | to_lowercase (bool) 12 | output 13 | code_str: a code in string format 14 | ''' 15 | 16 | code_str = '\n'.join(code) 17 | 18 | if to_lowercase: 19 | code_str = code_str.lower() 20 | 21 | return code_str 22 | 23 | def prepare_data(df, to_lowercase = False): 24 | ''' 25 | input 26 | df (DataFrame): input data from get_df() function 27 | output 28 | all_code_str (list): a list of source code in string format 29 | all_file_label (list): a 
list of label 30 | ''' 31 | all_code_str = [] 32 | all_file_label = [] 33 | 34 | for filename, group_df in df.groupby('filename'): 35 | 36 | file_label = bool(group_df['file-label'].unique()) 37 | 38 | code = list(group_df['code_line']) 39 | 40 | code_str = get_code_str(code,to_lowercase) 41 | 42 | all_code_str.append(code_str) 43 | 44 | all_file_label.append(file_label) 45 | 46 | return all_code_str, all_file_label 47 | 48 | def get_code_vec(code, w2v_model): 49 | ''' 50 | input 51 | code (list): a list of code string (from prepare_data_for_LSTM()) 52 | w2v_model (Word2Vec) 53 | output 54 | codevec (list): a list of token index of each file 55 | ''' 56 | codevec = [] 57 | 58 | for c in code: 59 | codevec.append([w2v_model.wv.vocab[word].index if word in w2v_model.wv.vocab else len(w2v_model.wv.vocab) for word in c.split()]) 60 | 61 | return codevec 62 | 63 | def pad_features(codevec, padding_idx, seq_length): 64 | ''' 65 | input 66 | codevec (list): a list from get_code_vec() 67 | padding_idx (int): value used for padding 68 | seq_length (int): max sequence length of each code line 69 | ''' 70 | features = np.zeros((len(codevec), seq_length), dtype=int) 71 | 72 | for i, row in enumerate(codevec): 73 | if len(row) > seq_length: 74 | features[i,:] = row[:seq_length] 75 | else: 76 | features[i, :] = row + [padding_idx]* (seq_length - len(row)) 77 | 78 | return features 79 | 80 | def get_dataloader(w2v_model, code,encoded_labels, padding_idx, batch_size): 81 | ''' 82 | input 83 | w2v_model (Word2Vec) 84 | code (list of string) 85 | encoded_labels (list) 86 | output 87 | dataloader object 88 | 89 | ''' 90 | codevec = get_code_vec(code, w2v_model) 91 | 92 | # to prevent out of memory error 93 | max_seq_len = min(max([len(cv) for cv in codevec]),45000) 94 | 95 | features = pad_features(codevec, padding_idx, seq_length=max_seq_len) 96 | tensor_data = TensorDataset(torch.from_numpy(features), torch.from_numpy(np.array(encoded_labels).astype(int))) 97 | dl = DataLoader(tensor_data, shuffle=True, batch_size=batch_size,drop_last=True) 98 | 99 | return dl -------------------------------------------------------------------------------- /script/file-level-baseline/dbn/activations.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import numpy as np 4 | 5 | 6 | class ActivationFunction(object): 7 | """ 8 | Class for abstract activation function. 9 | """ 10 | __metaclass__ = ABCMeta 11 | 12 | @abstractmethod 13 | def function(self, x): 14 | return 15 | 16 | @abstractmethod 17 | def prime(self, x): 18 | return 19 | 20 | 21 | class SigmoidActivationFunction(ActivationFunction): 22 | @classmethod 23 | def function(cls, x): 24 | """ 25 | Sigmoid function. 26 | :param x: array-like, shape = (n_features, ) 27 | :return: 28 | """ 29 | return 1 / (1.0 + np.exp(-x)) 30 | 31 | @classmethod 32 | def prime(cls, x): 33 | """ 34 | Compute sigmoid first derivative. 35 | :param x: array-like, shape = (n_features, ) 36 | :return: 37 | """ 38 | return x * (1 - x) 39 | 40 | 41 | class ReLUActivationFunction(ActivationFunction): 42 | @classmethod 43 | def function(cls, x): 44 | """ 45 | Rectified linear function. 46 | :param x: array-like, shape = (n_features, ) 47 | :return: 48 | """ 49 | return np.maximum(np.zeros(x.shape), x) 50 | 51 | @classmethod 52 | def prime(cls, x): 53 | """ 54 | Rectified linear first derivative. 
55 | :param x: array-like, shape = (n_features, ) 56 | :return: 57 | """ 58 | return (x > 0).astype(int) 59 | 60 | 61 | class TanhActivationFunction(ActivationFunction): 62 | @classmethod 63 | def function(cls, x): 64 | """ 65 | Hyperbolic tangent function. 66 | :param x: array-like, shape = (n_features, ) 67 | :return: 68 | """ 69 | return np.tanh(x) 70 | 71 | @classmethod 72 | def prime(cls, x): 73 | """ 74 | Hyperbolic tangent first derivative. 75 | :param x: array-like, shape = (n_features, ) 76 | :return: 77 | """ 78 | return 1 - x * x -------------------------------------------------------------------------------- /script/file-level-baseline/dbn/models.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import numpy as np 4 | from scipy.stats import truncnorm 5 | from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin, RegressorMixin 6 | 7 | from .activations import SigmoidActivationFunction, ReLUActivationFunction 8 | from .utils import batch_generator 9 | 10 | 11 | class BaseModel(object): 12 | def save(self, save_path): 13 | import pickle 14 | 15 | with open(save_path, 'wb') as fp: 16 | pickle.dump(self, fp) 17 | 18 | @classmethod 19 | def load(cls, load_path): 20 | import pickle 21 | 22 | with open(load_path, 'rb') as fp: 23 | return pickle.load(fp) 24 | 25 | 26 | class BinaryRBM(BaseEstimator, TransformerMixin, BaseModel): 27 | """ 28 | This class implements a Binary Restricted Boltzmann machine. 29 | """ 30 | 31 | def __init__(self, 32 | n_hidden_units=100, 33 | activation_function='sigmoid', 34 | optimization_algorithm='sgd', 35 | learning_rate=1e-3, 36 | n_epochs=10, 37 | contrastive_divergence_iter=1, 38 | batch_size=32, 39 | verbose=True): 40 | self.n_hidden_units = n_hidden_units 41 | self.activation_function = activation_function 42 | self.optimization_algorithm = optimization_algorithm 43 | self.learning_rate = learning_rate 44 | self.n_epochs = n_epochs 45 | self.contrastive_divergence_iter = contrastive_divergence_iter 46 | self.batch_size = batch_size 47 | self.verbose = verbose 48 | 49 | def fit(self, X): 50 | """ 51 | Fit a model given data. 52 | :param X: array-like, shape = (n_samples, n_features) 53 | :return: 54 | """ 55 | # Initialize RBM parameters 56 | self.n_visible_units = X.shape[1] 57 | if self.activation_function == 'sigmoid': 58 | self.W = np.random.randn(self.n_hidden_units, self.n_visible_units) / np.sqrt(self.n_visible_units) 59 | self.c = np.random.randn(self.n_hidden_units) / np.sqrt(self.n_visible_units) 60 | self.b = np.random.randn(self.n_visible_units) / np.sqrt(self.n_visible_units) 61 | self._activation_function_class = SigmoidActivationFunction 62 | elif self.activation_function == 'relu': 63 | self.W = truncnorm.rvs(-0.2, 0.2, size=[self.n_hidden_units, self.n_visible_units]) / np.sqrt( 64 | self.n_visible_units) 65 | self.c = np.full(self.n_hidden_units, 0.1) / np.sqrt(self.n_visible_units) 66 | self.b = np.full(self.n_visible_units, 0.1) / np.sqrt(self.n_visible_units) 67 | self._activation_function_class = ReLUActivationFunction 68 | else: 69 | raise ValueError("Invalid activation function.") 70 | 71 | if self.optimization_algorithm == 'sgd': 72 | self._stochastic_gradient_descent(X) 73 | else: 74 | raise ValueError("Invalid optimization algorithm.") 75 | return self 76 | 77 | def transform(self, X): 78 | """ 79 | Transforms data using the fitted model. 
80 | :param X: array-like, shape = (n_samples, n_features) 81 | :return: 82 | """ 83 | if len(X.shape) == 1: # It is a single sample 84 | return self._compute_hidden_units(X) 85 | transformed_data = self._compute_hidden_units_matrix(X) 86 | return transformed_data 87 | 88 | def _reconstruct(self, transformed_data): 89 | """ 90 | Reconstruct visible units given the hidden layer output. 91 | :param transformed_data: array-like, shape = (n_samples, n_features) 92 | :return: 93 | """ 94 | return self._compute_visible_units_matrix(transformed_data) 95 | 96 | def _stochastic_gradient_descent(self, _data): 97 | """ 98 | Performs stochastic gradient descend optimization algorithm. 99 | :param _data: array-like, shape = (n_samples, n_features) 100 | :return: 101 | """ 102 | accum_delta_W = np.zeros(self.W.shape) 103 | accum_delta_b = np.zeros(self.b.shape) 104 | accum_delta_c = np.zeros(self.c.shape) 105 | for iteration in range(1, self.n_epochs + 1): 106 | idx = np.random.permutation(len(_data)) 107 | data = _data[idx] 108 | for batch in batch_generator(self.batch_size, data): 109 | accum_delta_W[:] = .0 110 | accum_delta_b[:] = .0 111 | accum_delta_c[:] = .0 112 | for sample in batch: 113 | delta_W, delta_b, delta_c = self._contrastive_divergence(sample) 114 | accum_delta_W += delta_W 115 | accum_delta_b += delta_b 116 | accum_delta_c += delta_c 117 | self.W += self.learning_rate * (accum_delta_W / self.batch_size) 118 | self.b += self.learning_rate * (accum_delta_b / self.batch_size) 119 | self.c += self.learning_rate * (accum_delta_c / self.batch_size) 120 | if self.verbose: 121 | error = self._compute_reconstruction_error(data) 122 | print(">> Epoch %d finished \tRBM Reconstruction error %f" % (iteration, error)) 123 | 124 | def _contrastive_divergence(self, vector_visible_units): 125 | """ 126 | Computes gradients using Contrastive Divergence method. 127 | :param vector_visible_units: array-like, shape = (n_features, ) 128 | :return: 129 | """ 130 | v_0 = vector_visible_units 131 | v_t = np.array(v_0) 132 | 133 | # Sampling 134 | for t in range(self.contrastive_divergence_iter): 135 | h_t = self._sample_hidden_units(v_t) 136 | v_t = self._compute_visible_units(h_t) 137 | 138 | # Computing deltas 139 | v_k = v_t 140 | h_0 = self._compute_hidden_units(v_0) 141 | h_k = self._compute_hidden_units(v_k) 142 | delta_W = np.outer(h_0, v_0) - np.outer(h_k, v_k) 143 | delta_b = v_0 - v_k 144 | delta_c = h_0 - h_k 145 | 146 | return delta_W, delta_b, delta_c 147 | 148 | def _sample_hidden_units(self, vector_visible_units): 149 | """ 150 | Computes hidden unit activations by sampling from a binomial distribution. 151 | :param vector_visible_units: array-like, shape = (n_features, ) 152 | :return: 153 | """ 154 | hidden_units = self._compute_hidden_units(vector_visible_units) 155 | return (np.random.random_sample(len(hidden_units)) < hidden_units).astype(np.int64) 156 | 157 | def _sample_visible_units(self, vector_hidden_units): 158 | """ 159 | Computes visible unit activations by sampling from a binomial distribution. 160 | :param vector_hidden_units: array-like, shape = (n_features, ) 161 | :return: 162 | """ 163 | visible_units = self._compute_visible_units(vector_hidden_units) 164 | return (np.random.random_sample(len(visible_units)) < visible_units).astype(np.int64) 165 | 166 | def _compute_hidden_units(self, vector_visible_units): 167 | """ 168 | Computes hidden unit outputs. 
169 | :param vector_visible_units: array-like, shape = (n_features, ) 170 | :return: 171 | """ 172 | v = np.expand_dims(vector_visible_units, 0) 173 | h = np.squeeze(self._compute_hidden_units_matrix(v)) 174 | return np.array([h]) if not h.shape else h 175 | 176 | def _compute_hidden_units_matrix(self, matrix_visible_units): 177 | """ 178 | Computes hidden unit outputs. 179 | :param matrix_visible_units: array-like, shape = (n_samples, n_features) 180 | :return: 181 | """ 182 | return np.transpose(self._activation_function_class.function( 183 | np.dot(self.W, np.transpose(matrix_visible_units)) + self.c[:, np.newaxis])) 184 | 185 | def _compute_visible_units(self, vector_hidden_units): 186 | """ 187 | Computes visible (or input) unit outputs. 188 | :param vector_hidden_units: array-like, shape = (n_features, ) 189 | :return: 190 | """ 191 | h = np.expand_dims(vector_hidden_units, 0) 192 | v = np.squeeze(self._compute_visible_units_matrix(h)) 193 | return np.array([v]) if not v.shape else v 194 | 195 | def _compute_visible_units_matrix(self, matrix_hidden_units): 196 | """ 197 | Computes visible (or input) unit outputs. 198 | :param matrix_hidden_units: array-like, shape = (n_samples, n_features) 199 | :return: 200 | """ 201 | return self._activation_function_class.function(np.dot(matrix_hidden_units, self.W) + self.b[np.newaxis, :]) 202 | 203 | def _compute_free_energy(self, vector_visible_units): 204 | """ 205 | Computes the RBM free energy. 206 | :param vector_visible_units: array-like, shape = (n_features, ) 207 | :return: 208 | """ 209 | v = vector_visible_units 210 | return - np.dot(self.b, v) - np.sum(np.log(1 + np.exp(np.dot(self.W, v) + self.c))) 211 | 212 | def _compute_reconstruction_error(self, data): 213 | """ 214 | Computes the reconstruction error of the data. 215 | :param data: array-like, shape = (n_samples, n_features) 216 | :return: 217 | """ 218 | data_transformed = self.transform(data) 219 | data_reconstructed = self._reconstruct(data_transformed) 220 | return np.mean(np.sum((data_reconstructed - data) ** 2, 1)) 221 | 222 | 223 | class UnsupervisedDBN(BaseEstimator, TransformerMixin, BaseModel): 224 | """ 225 | This class implements a unsupervised Deep Belief Network. 226 | """ 227 | 228 | def __init__(self, 229 | hidden_layers_structure=[100, 100], 230 | activation_function='sigmoid', 231 | optimization_algorithm='sgd', 232 | learning_rate_rbm=1e-3, 233 | n_epochs_rbm=10, 234 | contrastive_divergence_iter=1, 235 | batch_size=32, 236 | verbose=True): 237 | self.hidden_layers_structure = hidden_layers_structure 238 | self.activation_function = activation_function 239 | self.optimization_algorithm = optimization_algorithm 240 | self.learning_rate_rbm = learning_rate_rbm 241 | self.n_epochs_rbm = n_epochs_rbm 242 | self.contrastive_divergence_iter = contrastive_divergence_iter 243 | self.batch_size = batch_size 244 | self.rbm_layers = None 245 | self.verbose = verbose 246 | self.rbm_class = BinaryRBM 247 | 248 | def fit(self, X, y=None): 249 | """ 250 | Fits a model given data. 
251 | :param X: array-like, shape = (n_samples, n_features) 252 | :return: 253 | """ 254 | # Initialize rbm layers 255 | self.rbm_layers = list() 256 | for n_hidden_units in self.hidden_layers_structure: 257 | rbm = self.rbm_class(n_hidden_units=n_hidden_units, 258 | activation_function=self.activation_function, 259 | optimization_algorithm=self.optimization_algorithm, 260 | learning_rate=self.learning_rate_rbm, 261 | n_epochs=self.n_epochs_rbm, 262 | contrastive_divergence_iter=self.contrastive_divergence_iter, 263 | batch_size=self.batch_size, 264 | verbose=self.verbose) 265 | self.rbm_layers.append(rbm) 266 | 267 | # Fit RBM 268 | if self.verbose: 269 | print("[START] Pre-training step:") 270 | input_data = X 271 | for rbm in self.rbm_layers: 272 | rbm.fit(input_data) 273 | input_data = rbm.transform(input_data) 274 | if self.verbose: 275 | print("[END] Pre-training step") 276 | return self 277 | 278 | def transform(self, X): 279 | """ 280 | Transforms data using the fitted model. 281 | :param X: array-like, shape = (n_samples, n_features) 282 | :return: 283 | """ 284 | input_data = X 285 | for rbm in self.rbm_layers: 286 | input_data = rbm.transform(input_data) 287 | return input_data 288 | 289 | 290 | class AbstractSupervisedDBN(BaseEstimator, BaseModel): 291 | """ 292 | Abstract class for supervised Deep Belief Network. 293 | """ 294 | __metaclass__ = ABCMeta 295 | 296 | def __init__(self, 297 | unsupervised_dbn_class, 298 | hidden_layers_structure=[100, 100], 299 | activation_function='sigmoid', 300 | optimization_algorithm='sgd', 301 | learning_rate=1e-3, 302 | learning_rate_rbm=1e-3, 303 | n_iter_backprop=100, 304 | l2_regularization=1.0, 305 | n_epochs_rbm=10, 306 | contrastive_divergence_iter=1, 307 | batch_size=32, 308 | dropout_p=0, # float between 0 and 1. Fraction of the input units to drop 309 | verbose=True): 310 | self.unsupervised_dbn = unsupervised_dbn_class(hidden_layers_structure=hidden_layers_structure, 311 | activation_function=activation_function, 312 | optimization_algorithm=optimization_algorithm, 313 | learning_rate_rbm=learning_rate_rbm, 314 | n_epochs_rbm=n_epochs_rbm, 315 | contrastive_divergence_iter=contrastive_divergence_iter, 316 | batch_size=batch_size, 317 | verbose=verbose) 318 | self.unsupervised_dbn_class = unsupervised_dbn_class 319 | self.n_iter_backprop = n_iter_backprop 320 | self.l2_regularization = l2_regularization 321 | self.learning_rate = learning_rate 322 | self.batch_size = batch_size 323 | self.dropout_p = dropout_p 324 | self.p = 1 - self.dropout_p 325 | self.verbose = verbose 326 | 327 | def fit(self, X, y=None, pre_train=True): 328 | """ 329 | Fits a model given data. 330 | :param X: array-like, shape = (n_samples, n_features) 331 | :param y : array-like, shape = (n_samples, ) 332 | :param pre_train: bool 333 | :return: 334 | """ 335 | if pre_train: 336 | self.pre_train(X) 337 | self._fine_tuning(X, y) 338 | return self 339 | 340 | def predict(self, X): 341 | """ 342 | Predicts the target given data. 343 | :param X: array-like, shape = (n_samples, n_features) 344 | :return: 345 | """ 346 | if len(X.shape) == 1: # It is a single sample 347 | X = np.expand_dims(X, 0) 348 | transformed_data = self.transform(X) 349 | predicted_data = self._compute_output_units_matrix(transformed_data) 350 | return predicted_data 351 | 352 | def pre_train(self, X): 353 | """ 354 | Apply unsupervised network pre-training. 
355 | :param X: array-like, shape = (n_samples, n_features) 356 | :return: 357 | """ 358 | self.unsupervised_dbn.fit(X) 359 | return self 360 | 361 | def transform(self, *args): 362 | return self.unsupervised_dbn.transform(*args) 363 | 364 | @abstractmethod 365 | def _transform_labels_to_network_format(self, labels): 366 | return 367 | 368 | @abstractmethod 369 | def _compute_output_units_matrix(self, matrix_visible_units): 370 | return 371 | 372 | @abstractmethod 373 | def _determine_num_output_neurons(self, labels): 374 | return 375 | 376 | @abstractmethod 377 | def _stochastic_gradient_descent(self, data, labels): 378 | return 379 | 380 | @abstractmethod 381 | def _fine_tuning(self, data, _labels): 382 | return 383 | 384 | 385 | class NumPyAbstractSupervisedDBN(AbstractSupervisedDBN): 386 | """ 387 | Abstract class for supervised Deep Belief Network in NumPy 388 | """ 389 | __metaclass__ = ABCMeta 390 | 391 | def __init__(self, **kwargs): 392 | super(NumPyAbstractSupervisedDBN, self).__init__(UnsupervisedDBN, **kwargs) 393 | 394 | def _compute_activations(self, sample): 395 | """ 396 | Compute output values of all layers. 397 | :param sample: array-like, shape = (n_features, ) 398 | :return: 399 | """ 400 | input_data = sample 401 | if self.dropout_p > 0: 402 | r = np.random.binomial(1, self.p, len(input_data)) 403 | input_data *= r 404 | layers_activation = list() 405 | 406 | for rbm in self.unsupervised_dbn.rbm_layers: 407 | input_data = rbm.transform(input_data) 408 | if self.dropout_p > 0: 409 | r = np.random.binomial(1, self.p, len(input_data)) 410 | input_data *= r 411 | layers_activation.append(input_data) 412 | 413 | # Computing activation of output layer 414 | input_data = self._compute_output_units(input_data) 415 | layers_activation.append(input_data) 416 | 417 | return layers_activation 418 | 419 | def _stochastic_gradient_descent(self, _data, _labels): 420 | """ 421 | Performs stochastic gradient descend optimization algorithm. 
422 | :param _data: array-like, shape = (n_samples, n_features) 423 | :param _labels: array-like, shape = (n_samples, targets) 424 | :return: 425 | """ 426 | if self.verbose: 427 | matrix_error = np.zeros([len(_data), self.num_classes]) 428 | num_samples = len(_data) 429 | accum_delta_W = [np.zeros(rbm.W.shape) for rbm in self.unsupervised_dbn.rbm_layers] 430 | accum_delta_W.append(np.zeros(self.W.shape)) 431 | accum_delta_bias = [np.zeros(rbm.c.shape) for rbm in self.unsupervised_dbn.rbm_layers] 432 | accum_delta_bias.append(np.zeros(self.b.shape)) 433 | 434 | for iteration in range(1, self.n_iter_backprop + 1): 435 | idx = np.random.permutation(len(_data)) 436 | data = _data[idx] 437 | labels = _labels[idx] 438 | i = 0 439 | for batch_data, batch_labels in batch_generator(self.batch_size, data, labels): 440 | # Clear arrays 441 | for arr1, arr2 in zip(accum_delta_W, accum_delta_bias): 442 | arr1[:], arr2[:] = .0, .0 443 | for sample, label in zip(batch_data, batch_labels): 444 | delta_W, delta_bias, predicted = self._backpropagation(sample, label) 445 | for layer in range(len(self.unsupervised_dbn.rbm_layers) + 1): 446 | accum_delta_W[layer] += delta_W[layer] 447 | accum_delta_bias[layer] += delta_bias[layer] 448 | if self.verbose: 449 | loss = self._compute_loss(predicted, label) 450 | matrix_error[i, :] = loss 451 | i += 1 452 | 453 | layer = 0 454 | for rbm in self.unsupervised_dbn.rbm_layers: 455 | # Updating parameters of hidden layers 456 | rbm.W = (1 - ( 457 | self.learning_rate * self.l2_regularization) / num_samples) * rbm.W - self.learning_rate * ( 458 | accum_delta_W[layer] / self.batch_size) 459 | rbm.c -= self.learning_rate * (accum_delta_bias[layer] / self.batch_size) 460 | layer += 1 461 | # Updating parameters of output layer 462 | self.W = (1 - ( 463 | self.learning_rate * self.l2_regularization) / num_samples) * self.W - self.learning_rate * ( 464 | accum_delta_W[layer] / self.batch_size) 465 | self.b -= self.learning_rate * (accum_delta_bias[layer] / self.batch_size) 466 | 467 | if self.verbose: 468 | error = np.mean(np.sum(matrix_error, 1)) 469 | print(">> Epoch %d finished \tANN training loss %f" % (iteration, error)) 470 | 471 | def _backpropagation(self, input_vector, label): 472 | """ 473 | Performs Backpropagation algorithm for computing gradients. 
474 | :param input_vector: array-like, shape = (n_features, ) 475 | :param label: array-like, shape = (n_targets, ) 476 | :return: 477 | """ 478 | x, y = input_vector, label 479 | deltas = list() 480 | list_layer_weights = list() 481 | for rbm in self.unsupervised_dbn.rbm_layers: 482 | list_layer_weights.append(rbm.W) 483 | list_layer_weights.append(self.W) 484 | 485 | # Forward pass 486 | layers_activation = self._compute_activations(input_vector) 487 | 488 | # Backward pass: computing deltas 489 | activation_output_layer = layers_activation[-1] 490 | delta_output_layer = self._compute_output_layer_delta(y, activation_output_layer) 491 | deltas.append(delta_output_layer) 492 | layer_idx = list(range(len(self.unsupervised_dbn.rbm_layers))) 493 | layer_idx.reverse() 494 | delta_previous_layer = delta_output_layer 495 | for layer in layer_idx: 496 | neuron_activations = layers_activation[layer] 497 | W = list_layer_weights[layer + 1] 498 | delta = np.dot(delta_previous_layer, W) * self.unsupervised_dbn.rbm_layers[ 499 | layer]._activation_function_class.prime(neuron_activations) 500 | deltas.append(delta) 501 | delta_previous_layer = delta 502 | deltas.reverse() 503 | 504 | # Computing gradients 505 | layers_activation.pop() 506 | layers_activation.insert(0, input_vector) 507 | layer_gradient_weights, layer_gradient_bias = list(), list() 508 | for layer in range(len(list_layer_weights)): 509 | neuron_activations = layers_activation[layer] 510 | delta = deltas[layer] 511 | gradient_W = np.outer(delta, neuron_activations) 512 | layer_gradient_weights.append(gradient_W) 513 | layer_gradient_bias.append(delta) 514 | 515 | return layer_gradient_weights, layer_gradient_bias, activation_output_layer 516 | 517 | def _fine_tuning(self, data, _labels): 518 | """ 519 | Entry point of the fine tuning procedure. 520 | :param data: array-like, shape = (n_samples, n_features) 521 | :param _labels: array-like, shape = (n_samples, targets) 522 | :return: 523 | """ 524 | self.num_classes = self._determine_num_output_neurons(_labels) 525 | n_hidden_units_previous_layer = self.unsupervised_dbn.rbm_layers[-1].n_hidden_units 526 | self.W = np.random.randn(self.num_classes, n_hidden_units_previous_layer) / np.sqrt( 527 | n_hidden_units_previous_layer) 528 | self.b = np.random.randn(self.num_classes) / np.sqrt(n_hidden_units_previous_layer) 529 | 530 | labels = self._transform_labels_to_network_format(_labels) 531 | 532 | # Scaling up weights obtained from pretraining 533 | for rbm in self.unsupervised_dbn.rbm_layers: 534 | rbm.W /= self.p 535 | rbm.c /= self.p 536 | 537 | if self.verbose: 538 | print("[START] Fine tuning step:") 539 | 540 | if self.unsupervised_dbn.optimization_algorithm == 'sgd': 541 | self._stochastic_gradient_descent(data, labels) 542 | else: 543 | raise ValueError("Invalid optimization algorithm.") 544 | 545 | # Scaling down weights obtained from pretraining 546 | for rbm in self.unsupervised_dbn.rbm_layers: 547 | rbm.W *= self.p 548 | rbm.c *= self.p 549 | 550 | if self.verbose: 551 | print("[END] Fine tuning step") 552 | 553 | @abstractmethod 554 | def _compute_loss(self, predicted, label): 555 | return 556 | 557 | @abstractmethod 558 | def _compute_output_layer_delta(self, label, predicted): 559 | return 560 | 561 | 562 | class SupervisedDBNClassification(NumPyAbstractSupervisedDBN, ClassifierMixin): 563 | """ 564 | This class implements a Deep Belief Network for classification problems. 565 | It appends a Softmax Linear Classifier as output layer. 
566 | """ 567 | 568 | def _transform_labels_to_network_format(self, labels): 569 | """ 570 | Converts labels as single integer to row vectors. For instance, given a three class problem, labels would be 571 | mapped as label_1: [1 0 0], label_2: [0 1 0], label_3: [0, 0, 1] where labels can be either int or string. 572 | :param labels: array-like, shape = (n_samples, ) 573 | :return: 574 | """ 575 | new_labels = np.zeros([len(labels), self.num_classes]) 576 | self.label_to_idx_map, self.idx_to_label_map = dict(), dict() 577 | idx = 0 578 | for i, label in enumerate(labels): 579 | if label not in self.label_to_idx_map: 580 | self.label_to_idx_map[label] = idx 581 | self.idx_to_label_map[idx] = label 582 | idx += 1 583 | new_labels[i][self.label_to_idx_map[label]] = 1 584 | return new_labels 585 | 586 | def _transform_network_format_to_labels(self, indexes): 587 | """ 588 | Converts network output to original labels. 589 | :param indexes: array-like, shape = (n_samples, ) 590 | :return: 591 | """ 592 | return list(map(lambda idx: self.idx_to_label_map[idx], indexes)) 593 | 594 | def _compute_output_units(self, vector_visible_units): 595 | """ 596 | Compute activations of output units. 597 | :param vector_visible_units: array-like, shape = (n_features, ) 598 | :return: 599 | """ 600 | v = vector_visible_units 601 | scores = np.dot(self.W, v) + self.b 602 | # get unnormalized probabilities 603 | exp_scores = np.exp(scores) 604 | # normalize them for each example 605 | return exp_scores / np.sum(exp_scores) 606 | 607 | def _compute_output_units_matrix(self, matrix_visible_units): 608 | """ 609 | Compute activations of output units. 610 | :param matrix_visible_units: shape = (n_samples, n_features) 611 | :return: 612 | """ 613 | matrix_scores = np.transpose(np.dot(self.W, np.transpose(matrix_visible_units)) + self.b[:, np.newaxis]) 614 | exp_scores = np.exp(matrix_scores) 615 | return exp_scores / np.expand_dims(np.sum(exp_scores, axis=1), 1) 616 | 617 | def _compute_output_layer_delta(self, label, predicted): 618 | """ 619 | Compute deltas of the output layer, using cross-entropy cost function. 620 | :param label: array-like, shape = (n_features, ) 621 | :param predicted: array-like, shape = (n_features, ) 622 | :return: 623 | """ 624 | dscores = np.array(predicted) 625 | dscores[np.where(label == 1)] -= 1 626 | return dscores 627 | 628 | def predict_proba(self, X): 629 | """ 630 | Predicts probability distribution of classes for each sample in the given data. 631 | :param X: array-like, shape = (n_samples, n_features) 632 | :return: 633 | """ 634 | return super(SupervisedDBNClassification, self).predict(X) 635 | 636 | def predict_proba_dict(self, X): 637 | """ 638 | Predicts probability distribution of classes for each sample in the given data. 639 | Returns a list of dictionaries, one per sample. 
Each dict contains {label_1: prob_1, ..., label_j: prob_j} 640 | :param X: array-like, shape = (n_samples, n_features) 641 | :return: 642 | """ 643 | if len(X.shape) == 1: # It is a single sample 644 | X = np.expand_dims(X, 0) 645 | 646 | predicted_probs = self.predict_proba(X) 647 | 648 | result = [] 649 | num_of_data, num_of_labels = predicted_probs.shape 650 | for i in range(num_of_data): 651 | # key : label 652 | # value : predicted probability 653 | dict_prob = {} 654 | for j in range(num_of_labels): 655 | dict_prob[self.idx_to_label_map[j]] = predicted_probs[i][j] 656 | result.append(dict_prob) 657 | 658 | return result 659 | 660 | def predict(self, X): 661 | probs = self.predict_proba(X) 662 | indexes = np.argmax(probs, axis=1) 663 | return self._transform_network_format_to_labels(indexes) 664 | 665 | def _determine_num_output_neurons(self, labels): 666 | """ 667 | Given labels, compute the needed number of output units. 668 | :param labels: shape = (n_samples, ) 669 | :return: 670 | """ 671 | return len(np.unique(labels)) 672 | 673 | def _compute_loss(self, probs, label): 674 | """ 675 | Computes categorical cross-entropy loss 676 | :param probs: 677 | :param label: 678 | :return: 679 | """ 680 | return -np.log(probs[np.where(label == 1)]) 681 | 682 | 683 | class SupervisedDBNRegression(NumPyAbstractSupervisedDBN, RegressorMixin): 684 | """ 685 | This class implements a Deep Belief Network for regression problems. 686 | """ 687 | 688 | def _transform_labels_to_network_format(self, labels): 689 | """ 690 | Returns the same labels since regression case does not need to convert anything. 691 | :param labels: array-like, shape = (n_samples, targets) 692 | :return: 693 | """ 694 | return labels 695 | 696 | def _compute_output_units(self, vector_visible_units): 697 | """ 698 | Compute activations of output units. 699 | :param vector_visible_units: array-like, shape = (n_features, ) 700 | :return: 701 | """ 702 | v = vector_visible_units 703 | return np.dot(self.W, v) + self.b 704 | 705 | def _compute_output_units_matrix(self, matrix_visible_units): 706 | """ 707 | Compute activations of output units. 708 | :param matrix_visible_units: shape = (n_samples, n_features) 709 | :return: 710 | """ 711 | return np.transpose(np.dot(self.W, np.transpose(matrix_visible_units)) + self.b[:, np.newaxis]) 712 | 713 | def _compute_output_layer_delta(self, label, predicted): 714 | """ 715 | Compute deltas of the output layer for the regression case, using common (one-half) squared-error cost function. 716 | :param label: array-like, shape = (n_features, ) 717 | :param predicted: array-like, shape = (n_features, ) 718 | :return: 719 | """ 720 | return -(label - predicted) 721 | 722 | def _determine_num_output_neurons(self, labels): 723 | """ 724 | Given labels, compute the needed number of output units. 725 | :param labels: shape = (n_samples, n_targets) 726 | :return: 727 | """ 728 | if len(labels.shape) == 1: 729 | return 1 730 | else: 731 | return labels.shape[1] 732 | 733 | def _compute_loss(self, predicted, label): 734 | """ 735 | Computes Mean squared error loss. 
736 | :param predicted: 737 | :param label: 738 | :return: 739 | """ 740 | error = predicted - label 741 | return error * error -------------------------------------------------------------------------------- /script/file-level-baseline/dbn/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def batch_generator(batch_size, data, labels=None): 5 | """ 6 | Generates batches of samples 7 | :param data: array-like, shape = (n_samples, n_features) 8 | :param labels: array-like, shape = (n_samples, ) 9 | :return: 10 | """ 11 | n_batches = int(np.ceil(len(data) / float(batch_size))) 12 | idx = np.random.permutation(len(data)) 13 | data_shuffled = data[idx] 14 | if labels is not None: 15 | labels_shuffled = labels[idx] 16 | for i in range(n_batches): 17 | start = i * batch_size 18 | end = start + batch_size 19 | if labels is not None: 20 | yield data_shuffled[start:end, :], labels_shuffled[start:end] 21 | else: 22 | yield data_shuffled[start:end, :] 23 | 24 | 25 | def to_categorical(labels, num_classes): 26 | """ 27 | Converts labels as single integer to row vectors. For instance, given a three class problem, labels would be 28 | mapped as label_1: [1 0 0], label_2: [0 1 0], label_3: [0, 0, 1] where labels can be either int or string. 29 | :param labels: array-like, shape = (n_samples, ) 30 | :return: 31 | """ 32 | new_labels = np.zeros([len(labels), num_classes]) 33 | label_to_idx_map, idx_to_label_map = dict(), dict() 34 | idx = 0 35 | for i, label in enumerate(labels): 36 | if label not in label_to_idx_map: 37 | label_to_idx_map[label] = idx 38 | idx_to_label_map[idx] = label 39 | idx += 1 40 | new_labels[i][label_to_idx_map[label]] = 1 41 | return new_labels, label_to_idx_map, idx_to_label_map -------------------------------------------------------------------------------- /script/generate_prediction.py: -------------------------------------------------------------------------------- 1 | import os, argparse, pickle 2 | 3 | import pandas as pd 4 | 5 | from gensim.models import Word2Vec 6 | 7 | from tqdm import tqdm 8 | 9 | from DeepLineDP_model import * 10 | from my_util import * 11 | 12 | torch.manual_seed(0) 13 | 14 | arg = argparse.ArgumentParser() 15 | 16 | arg.add_argument('-dataset',type=str, default='activemq', help='software project name (lowercase)') 17 | arg.add_argument('-embed_dim', type=int, default=50, help='word embedding size') 18 | arg.add_argument('-word_gru_hidden_dim', type=int, default=64, help='word attention hidden size') 19 | arg.add_argument('-sent_gru_hidden_dim', type=int, default=64, help='sentence attention hidden size') 20 | arg.add_argument('-word_gru_num_layers', type=int, default=1, help='number of GRU layer at word level') 21 | arg.add_argument('-sent_gru_num_layers', type=int, default=1, help='number of GRU layer at sentence level') 22 | arg.add_argument('-exp_name',type=str,default='') 23 | arg.add_argument('-target_epochs',type=str,default='7', help='the epoch to load model') 24 | arg.add_argument('-dropout', type=float, default=0.2, help='dropout rate') 25 | 26 | args = arg.parse_args() 27 | 28 | weight_dict = {} 29 | 30 | # model setting 31 | max_grad_norm = 5 32 | embed_dim = args.embed_dim 33 | word_gru_hidden_dim = args.word_gru_hidden_dim 34 | sent_gru_hidden_dim = args.sent_gru_hidden_dim 35 | word_gru_num_layers = args.word_gru_num_layers 36 | sent_gru_num_layers = args.sent_gru_num_layers 37 | word_att_dim = 64 38 | sent_att_dim = 64 39 | use_layer_norm = True 40 | dropout = 
args.dropout 41 | 42 | save_every_epochs = 5 43 | exp_name = args.exp_name 44 | 45 | save_model_dir = '../output/model/DeepLineDP/' 46 | intermediate_output_dir = '../output/intermediate_output/DeepLineDP/within-release/' 47 | prediction_dir = '../output/prediction/DeepLineDP/within-release/' 48 | 49 | file_lvl_gt = '../datasets/preprocessed_data/' 50 | 51 | 52 | if not os.path.exists(prediction_dir): 53 | os.makedirs(prediction_dir) 54 | 55 | def predict_defective_files_in_releases(dataset_name, target_epochs): 56 | 57 | actual_save_model_dir = save_model_dir+dataset_name+'/' 58 | 59 | train_rel = all_train_releases[dataset_name] 60 | test_rel = all_eval_releases[dataset_name][1:] 61 | 62 | w2v_dir = get_w2v_path() 63 | 64 | word2vec_file_dir = os.path.join(w2v_dir,dataset_name+'-'+str(embed_dim)+'dim.bin') 65 | 66 | word2vec = Word2Vec.load(word2vec_file_dir) 67 | print('load Word2Vec for',dataset_name,'finished') 68 | 69 | total_vocab = len(word2vec.wv.vocab) 70 | 71 | vocab_size = total_vocab +1 # for unknown tokens 72 | 73 | model = HierarchicalAttentionNetwork( 74 | vocab_size=vocab_size, 75 | embed_dim=embed_dim, 76 | word_gru_hidden_dim=word_gru_hidden_dim, 77 | sent_gru_hidden_dim=sent_gru_hidden_dim, 78 | word_gru_num_layers=word_gru_num_layers, 79 | sent_gru_num_layers=sent_gru_num_layers, 80 | word_att_dim=word_att_dim, 81 | sent_att_dim=sent_att_dim, 82 | use_layer_norm=use_layer_norm, 83 | dropout=dropout) 84 | 85 | if exp_name == '': 86 | checkpoint = torch.load(actual_save_model_dir+'checkpoint_'+target_epochs+'epochs.pth') 87 | 88 | else: 89 | checkpoint = torch.load(actual_save_model_dir+exp_name+'/checkpoint_'+target_epochs+'epochs.pth') 90 | 91 | model.load_state_dict(checkpoint['model_state_dict']) 92 | 93 | model.sent_attention.word_attention.freeze_embeddings(True) 94 | 95 | model = model.cuda() 96 | model.eval() 97 | 98 | for rel in test_rel: 99 | print('generating prediction of release:', rel) 100 | 101 | actual_intermediate_output_dir = intermediate_output_dir+rel+'/' 102 | 103 | if not os.path.exists(actual_intermediate_output_dir): 104 | os.makedirs(actual_intermediate_output_dir) 105 | 106 | test_df = get_df(rel) 107 | 108 | row_list = [] # for creating dataframe later... 
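        # For each file in this test release: tokenize its lines into a
        # (file x line x token) index tensor, run the hierarchical attention
        # network once, and cache the file-level probability together with the
        # word- and line-attention weights as a pickle, so that re-running the
        # script reuses the saved intermediate output instead of re-predicting.
        # Token- and line-level attention scores are then written to the
        # per-release CSV (at most 50 tokens per line).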
109 | 110 | for filename, df in tqdm(test_df.groupby('filename')): 111 | 112 | file_label = bool(df['file-label'].unique()) 113 | line_label = df['line-label'].tolist() 114 | line_number = df['line_number'].tolist() 115 | is_comments = df['is_comment'].tolist() 116 | 117 | code = df['code_line'].tolist() 118 | 119 | code2d = prepare_code2d(code, True) 120 | 121 | code3d = [code2d] 122 | 123 | codevec = get_x_vec(code3d, word2vec) 124 | 125 | save_file_path = actual_intermediate_output_dir+filename.replace('/','_').replace('.java','')+'_'+target_epochs+'_epochs.pkl' 126 | 127 | if not os.path.exists(save_file_path): 128 | with torch.no_grad(): 129 | codevec_padded_tensor = torch.tensor(codevec) 130 | output, word_att_weights, line_att_weight, _ = model(codevec_padded_tensor) 131 | file_prob = output.item() 132 | prediction = bool(round(output.item())) 133 | 134 | torch.cuda.empty_cache() 135 | 136 | output_dict = { 137 | 'filename': filename, 138 | 'file-label': file_label, 139 | 'prob': file_prob, 140 | 'pred': prediction, 141 | 'word_attention_mat': word_att_weights, 142 | 'line_attention_mat': line_att_weight, 143 | 'line-label': line_label, 144 | 'line-number': line_number 145 | } 146 | 147 | pickle.dump(output_dict, open(save_file_path, 'wb')) 148 | 149 | else: 150 | output_dict = pickle.load(open(save_file_path, 'rb')) 151 | file_prob = output_dict['prob'] 152 | prediction = output_dict['pred'] 153 | word_att_weights = output_dict['word_attention_mat'] 154 | line_att_weight = output_dict['line_attention_mat'] 155 | 156 | numpy_word_attn = word_att_weights[0].cpu().detach().numpy() 157 | numpy_line_attn = line_att_weight[0].cpu().detach().numpy() 158 | 159 | # for each line in source code 160 | for i in range(0,len(code)): 161 | cur_line = code[i] 162 | cur_line_label = line_label[i] 163 | cur_line_number = line_number[i] 164 | cur_is_comment = is_comments[i] 165 | cur_line_attn = numpy_line_attn[i] 166 | 167 | token_list = cur_line.strip().split() 168 | 169 | max_len = min(len(token_list),50) # limit max token each line 170 | 171 | # for each token in a line 172 | for j in range(0,max_len): 173 | tok = token_list[j] 174 | word_attn = numpy_word_attn[i][j] 175 | 176 | row_dict = { 177 | 'project': dataset_name, 178 | 'train': train_rel, 179 | 'test': rel, 180 | 'filename': filename, 181 | 'file-level-ground-truth': file_label, 182 | 'prediction-prob': file_prob, 183 | 'prediction-label': prediction, 184 | 'line-number': cur_line_number, 185 | 'line-level-ground-truth': cur_line_label, 186 | 'is-comment-line': cur_is_comment, 187 | 'token': tok, 188 | 'token-attention-score': word_attn, 189 | 'line-attention-score': cur_line_attn 190 | } 191 | 192 | row_list.append(row_dict) 193 | 194 | df = pd.DataFrame(row_list) 195 | 196 | df.to_csv(prediction_dir+rel+'.csv', index=False) 197 | 198 | print('finished release', rel) 199 | 200 | dataset_name = args.dataset 201 | target_epochs = args.target_epochs 202 | 203 | predict_defective_files_in_releases(dataset_name, target_epochs) -------------------------------------------------------------------------------- /script/generate_prediction_cross_projects.py: -------------------------------------------------------------------------------- 1 | import os, re, argparse 2 | 3 | import pandas as pd 4 | 5 | from gensim.models import Word2Vec 6 | 7 | from tqdm import tqdm 8 | 9 | from DeepLineDP_model import * 10 | from my_util import * 11 | 12 | torch.manual_seed(0) 13 | 14 | 15 | all_eval_rels_cross_projects = { 16 | 'activemq': ['camel-2.10.0', 
'camel-2.11.0', 'derby-10.5.1.1', 'groovy-1_6_BETA_2', 'hbase-0.95.2', 'hive-0.12.0', 'jruby-1.5.0', 'jruby-1.7.0.preview1', 'lucene-3.0.0', 'lucene-3.1', 'wicket-1.5.3'], 17 | 'camel': ['activemq-5.2.0', 'activemq-5.3.0', 'activemq-5.8.0', 'derby-10.5.1.1', 'groovy-1_6_BETA_2', 'hbase-0.95.2', 'hive-0.12.0', 'jruby-1.5.0', 'jruby-1.7.0.preview1', 'lucene-3.0.0', 'lucene-3.1', 'wicket-1.5.3'], 18 | 'derby': ['activemq-5.2.0', 'activemq-5.3.0', 'activemq-5.8.0', 'camel-2.10.0', 'camel-2.11.0', 'groovy-1_6_BETA_2', 'hbase-0.95.2', 'hive-0.12.0', 'jruby-1.5.0', 'jruby-1.7.0.preview1', 'lucene-3.0.0', 'lucene-3.1', 'wicket-1.5.3'], 19 | 'groovy': ['activemq-5.2.0', 'activemq-5.3.0', 'activemq-5.8.0', 'camel-2.10.0', 'camel-2.11.0', 'derby-10.5.1.1', 'hbase-0.95.2', 'hive-0.12.0', 'jruby-1.5.0', 'jruby-1.7.0.preview1', 'lucene-3.0.0', 'lucene-3.1', 'wicket-1.5.3'], 20 | 'hbase': ['activemq-5.2.0', 'activemq-5.3.0', 'activemq-5.8.0', 'camel-2.10.0', 'camel-2.11.0', 'derby-10.5.1.1', 'groovy-1_6_BETA_2', 'hive-0.12.0', 'jruby-1.5.0', 'jruby-1.7.0.preview1', 'lucene-3.0.0', 'lucene-3.1', 'wicket-1.5.3'], 21 | 'hive': ['activemq-5.2.0', 'activemq-5.3.0', 'activemq-5.8.0', 'camel-2.10.0', 'camel-2.11.0', 'derby-10.5.1.1', 'groovy-1_6_BETA_2', 'hbase-0.95.2', 'jruby-1.5.0', 'jruby-1.7.0.preview1', 'lucene-3.0.0', 'lucene-3.1', 'wicket-1.5.3'], 22 | 'jruby': ['activemq-5.2.0', 'activemq-5.3.0', 'activemq-5.8.0', 'camel-2.10.0', 'camel-2.11.0', 'derby-10.5.1.1', 'groovy-1_6_BETA_2', 'hbase-0.95.2', 'hive-0.12.0', 'lucene-3.0.0', 'lucene-3.1', 'wicket-1.5.3'], 23 | 'lucene': ['activemq-5.2.0', 'activemq-5.3.0', 'activemq-5.8.0', 'camel-2.10.0', 'camel-2.11.0', 'derby-10.5.1.1', 'groovy-1_6_BETA_2', 'hbase-0.95.2', 'hive-0.12.0', 'jruby-1.5.0', 'jruby-1.7.0.preview1', 'wicket-1.5.3'], 24 | 'wicket': ['activemq-5.2.0', 'activemq-5.3.0', 'activemq-5.8.0', 'camel-2.10.0', 'camel-2.11.0', 'derby-10.5.1.1', 'groovy-1_6_BETA_2', 'hbase-0.95.2', 'hive-0.12.0', 'jruby-1.5.0', 'jruby-1.7.0.preview1', 'lucene-3.0.0', 'lucene-3.1']} 25 | 26 | 27 | arg = argparse.ArgumentParser() 28 | 29 | arg.add_argument('-dataset',type=str, default='activemq', help='software project name (lowercase)') 30 | arg.add_argument('-embed_dim', type=int, default=50, help='word embedding size') 31 | arg.add_argument('-word_gru_hidden_dim', type=int, default=64, help='word attention hidden size') 32 | arg.add_argument('-sent_gru_hidden_dim', type=int, default=64, help='sentence attention hidden size') 33 | arg.add_argument('-word_gru_num_layers', type=int, default=1, help='number of GRU layer at word level') 34 | arg.add_argument('-sent_gru_num_layers', type=int, default=1, help='number of GRU layer at sentence level') 35 | arg.add_argument('-exp_name',type=str,default='') 36 | arg.add_argument('-target_epochs',type=str,default='7') 37 | arg.add_argument('-dropout', type=float, default=0.2, help='dropout rate') 38 | 39 | args = arg.parse_args() 40 | 41 | weight_dict = {} 42 | 43 | # model setting 44 | max_grad_norm = 5 45 | embed_dim = args.embed_dim 46 | word_gru_hidden_dim = args.word_gru_hidden_dim 47 | sent_gru_hidden_dim = args.sent_gru_hidden_dim 48 | word_gru_num_layers = args.word_gru_num_layers 49 | sent_gru_num_layers = args.sent_gru_num_layers 50 | word_att_dim = 64 51 | sent_att_dim = 64 52 | use_layer_norm = True 53 | dropout = args.dropout 54 | 55 | save_every_epochs = 5 56 | exp_name = args.exp_name 57 | 58 | save_model_dir = '../output/model/DeepLineDP/' 59 | 60 | intermediate_output_dir = 
'../output/intermediate_output/DeepLineDP/cross-project/' 61 | prediction_dir = '../output/prediction/DeepLineDP/cross-project/' 62 | 63 | file_lvl_gt = '../datasets/preprocessed_data/' 64 | 65 | 66 | if not os.path.exists(prediction_dir): 67 | os.makedirs(prediction_dir) 68 | 69 | def predict_defective_files_in_releases(dataset_name, target_epochs): 70 | 71 | 72 | actual_save_model_dir = save_model_dir+dataset_name+'/' 73 | actual_prediction_dir = prediction_dir+dataset_name+'/' 74 | 75 | if not os.path.exists(actual_prediction_dir): 76 | os.makedirs(actual_prediction_dir) 77 | 78 | train_rel = all_train_releases[dataset_name] 79 | test_rel = all_eval_rels_cross_projects[dataset_name] 80 | 81 | w2v_dir = get_w2v_path() 82 | 83 | word2vec_file_dir = os.path.join(w2v_dir,dataset_name+'-'+str(embed_dim)+'dim.bin') 84 | 85 | word2vec = Word2Vec.load(word2vec_file_dir) 86 | print('load Word2Vec for',dataset_name,'finished') 87 | 88 | total_vocab = len(word2vec.wv.vocab) 89 | 90 | vocab_size = total_vocab +1 # for unknown tokens 91 | 92 | 93 | max_sent_len = 999999 94 | 95 | 96 | model = HierarchicalAttentionNetwork( 97 | vocab_size=vocab_size, 98 | embed_dim=embed_dim, 99 | word_gru_hidden_dim=word_gru_hidden_dim, 100 | sent_gru_hidden_dim=sent_gru_hidden_dim, 101 | word_gru_num_layers=word_gru_num_layers, 102 | sent_gru_num_layers=sent_gru_num_layers, 103 | word_att_dim=word_att_dim, 104 | sent_att_dim=sent_att_dim, 105 | use_layer_norm=use_layer_norm, 106 | dropout=dropout) 107 | 108 | if exp_name == '': 109 | checkpoint = torch.load(actual_save_model_dir+'checkpoint_'+target_epochs+'epochs.pth') 110 | 111 | else: 112 | checkpoint = torch.load(actual_save_model_dir+exp_name+'/checkpoint_'+target_epochs+'epochs.pth') 113 | 114 | model.load_state_dict(checkpoint['model_state_dict']) 115 | 116 | model.sent_attention.word_attention.freeze_embeddings(True) 117 | 118 | model = model.cuda() 119 | model.eval() 120 | 121 | for rel in test_rel: 122 | print('using model from {} to generate prediction of {}'.format(train_rel,rel)) 123 | 124 | actual_intermediate_output_dir = intermediate_output_dir+dataset_name+'/'+train_rel+'-'+rel+'/' 125 | 126 | if not os.path.exists(actual_intermediate_output_dir): 127 | os.makedirs(actual_intermediate_output_dir) 128 | 129 | test_df = get_df(rel) 130 | 131 | row_list = [] # for creating dataframe later... 
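        # Cross-project setting: the model trained on this project's training
        # release scores every file of a release from a different project.
        # Unlike the within-release script, code vectors are padded with
        # pad_code() before inference, and predictions are recomputed on every
        # run (the intermediate output directory is created but no pickle is
        # written). Each output CSV is named after the train-test release pair.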
132 | 133 | for filename, df in tqdm(test_df.groupby('filename')): 134 | 135 | file_label = bool(df['file-label'].unique()) 136 | line_label = df['line-label'].tolist() 137 | line_number = df['line_number'].tolist() 138 | is_comments = df['is_comment'].tolist() 139 | 140 | code = df['code_line'].tolist() 141 | 142 | code2d = prepare_code2d(code, True) 143 | 144 | code3d = [code2d] 145 | 146 | codevec = get_x_vec(code3d, word2vec) 147 | codevec_padded = pad_code(codevec,max_sent_len,limit_sent_len=False, mode='test') 148 | 149 | with torch.no_grad(): 150 | codevec_padded_tensor = torch.tensor(codevec_padded) 151 | output, word_att_weights, line_att_weight, _ = model(codevec_padded_tensor) 152 | file_prob = output.item() 153 | prediction = bool(round(output.item())) 154 | 155 | numpy_word_attn = word_att_weights[0].cpu().detach().numpy() 156 | numpy_line_attn = line_att_weight[0].cpu().detach().numpy() 157 | 158 | for i in range(0,len(code)): 159 | cur_line = code[i] 160 | cur_line_label = line_label[i] 161 | cur_line_number = line_number[i] 162 | cur_is_comment = is_comments[i] 163 | cur_line_attn = numpy_line_attn[i] 164 | 165 | token_list = cur_line.strip().split() 166 | 167 | max_len = min(len(token_list),50) # limit max token each line 168 | 169 | for j in range(0,max_len): 170 | tok = token_list[j] 171 | word_attn = numpy_word_attn[i][j] 172 | 173 | row_dict = { 174 | 'project': dataset_name, 175 | 'train': train_rel, 176 | 'test': rel, 177 | 'filename': filename, 178 | 'file-level-ground-truth': file_label, 179 | 'prediction-prob': file_prob, 180 | 'prediction-label': prediction, 181 | 'line-number': cur_line_number, 182 | 'line-level-ground-truth': cur_line_label, 183 | 'is-comment-line': cur_is_comment, 184 | 'token': tok, 185 | 'token-attention-score': word_attn, 186 | 'line-attention-score': cur_line_attn 187 | } 188 | 189 | row_list.append(row_dict) 190 | 191 | df = pd.DataFrame(row_list) 192 | 193 | df.to_csv(actual_prediction_dir+train_rel+'-'+rel+'.csv', index=False) 194 | 195 | print('finished release', rel) 196 | 197 | dataset_name = args.dataset 198 | target_epochs = args.target_epochs 199 | 200 | predict_defective_files_in_releases(dataset_name, target_epochs) -------------------------------------------------------------------------------- /script/get_evaluation_result.R: -------------------------------------------------------------------------------- 1 | library(tidyverse) 2 | library(gridExtra) 3 | 4 | library(ModelMetrics) 5 | 6 | library(caret) 7 | 8 | library(reshape2) 9 | library(pROC) 10 | 11 | library(effsize) 12 | library(ScottKnottESD) 13 | 14 | save.fig.dir = '../output/figure/' 15 | 16 | dir.create(file.path(save.fig.dir), showWarnings = FALSE) 17 | 18 | preprocess <- function(x, reverse){ 19 | colnames(x) <- c("variable","value") 20 | tmp <- do.call(cbind, split(x, x$variable)) 21 | tmp <- tmp[, grep("value", names(tmp))] 22 | names(tmp) <- gsub(".value", "", names(tmp)) 23 | df <- tmp 24 | ranking <- NULL 25 | 26 | if(reverse == TRUE) 27 | { 28 | ranking <- (max(sk_esd(df)$group)-sk_esd(df)$group) +1 29 | } 30 | else 31 | { 32 | ranking <- sk_esd(df)$group 33 | } 34 | 35 | x$rank <- paste("Rank",ranking[as.character(x$variable)]) 36 | return(x) 37 | } 38 | 39 | get.top.k.tokens = function(df, k) 40 | { 41 | top.k <- df %>% filter( is.comment.line=="False" & file.level.ground.truth=="True" & prediction.label=="True" ) %>% 42 | group_by(test, filename) %>% top_n(k, token.attention.score) %>% select("project","train","test","filename","token") %>% distinct() 43 
| 44 | top.k$flag = 'topk' 45 | 46 | return(top.k) 47 | } 48 | 49 | 50 | prediction_dir = '../output/prediction/DeepLineDP/within-release/' 51 | 52 | all_files = list.files(prediction_dir) 53 | 54 | df_all <- NULL 55 | 56 | for(f in all_files) 57 | { 58 | df <- read.csv(paste0(prediction_dir, f)) 59 | df_all <- rbind(df_all, df) 60 | } 61 | 62 | # ---------------- Code for RQ1 -----------------------# 63 | 64 | #RQ1-1 65 | df.to.plot = df_all %>% filter(is.comment.line=="False" & file.level.ground.truth=="True" & prediction.label=="True") %>% group_by(test, filename,token) %>% 66 | summarise(Range=max(token.attention.score)-min(token.attention.score), SD=sd(token.attention.score)) %>% 67 | melt() 68 | 69 | df.to.plot %>% ggplot(aes(x=variable, y=value)) + geom_boxplot() + scale_y_continuous(breaks=0:4*0.25) + xlab("") + ylab("") 70 | 71 | ggsave(paste0(save.fig.dir,"rq1-1.pdf"),width=2.5,height=2.5) 72 | 73 | 74 | #RQ1-2 75 | 76 | df_all_copy = data.frame(df_all) 77 | 78 | df_all_copy = filter(df_all_copy, is.comment.line=="False" & file.level.ground.truth=="True" & prediction.label=="True") 79 | 80 | clean.lines.df = filter(df_all_copy, line.level.ground.truth=="False") 81 | buggy.line.df = filter(df_all_copy, line.level.ground.truth=="True") 82 | 83 | clean.lines.token.score = clean.lines.df %>% group_by(test, filename, token) %>% summarise(score = min(token.attention.score)) 84 | clean.lines.token.score$class = "Clean Lines" 85 | 86 | buggy.lines.token.score = buggy.line.df %>% group_by(test, filename, token) %>% summarise(score = max(token.attention.score)) 87 | buggy.lines.token.score$class = "Defective Lines" 88 | 89 | all.lines.token.score = rbind(buggy.lines.token.score, clean.lines.token.score) 90 | all.lines.token.score$class = factor(all.lines.token.score$class, levels = c('Defective Lines', 'Clean Lines')) 91 | 92 | all.lines.token.score %>% ggplot(aes(x=class, y=score)) + geom_boxplot() + xlab("") + ylab("Riskiness Score") 93 | ggsave(paste0(save.fig.dir,"rq1-2.pdf"),width=2.5,height=2.5) 94 | 95 | res = cliff.delta(buggy.lines.token.score$score, clean.lines.token.score$score) 96 | 97 | 98 | # ---------------- Code for RQ2 -----------------------# 99 | 100 | get.file.level.metrics = function(df.file) 101 | { 102 | all.gt = df.file$file.level.ground.truth 103 | all.prob = df.file$prediction.prob 104 | all.pred = df.file$prediction.label 105 | 106 | confusion.mat = confusionMatrix(all.pred, reference = all.gt) 107 | 108 | bal.acc = confusion.mat$byClass["Balanced Accuracy"] 109 | AUC = pROC::auc(all.gt, all.prob) 110 | 111 | levels(all.pred)[levels(all.pred)=="False"] = 0 112 | levels(all.pred)[levels(all.pred)=="True"] = 1 113 | levels(all.gt)[levels(all.gt)=="False"] = 0 114 | levels(all.gt)[levels(all.gt)=="True"] = 1 115 | 116 | all.gt = as.numeric_version(all.gt) 117 | all.gt = as.numeric(all.gt) 118 | 119 | all.pred = as.numeric_version(all.pred) 120 | all.pred = as.numeric(all.pred) 121 | 122 | MCC = mcc(all.gt, all.pred, cutoff = 0.5) 123 | 124 | if(is.nan(MCC)) 125 | { 126 | MCC = 0 127 | } 128 | 129 | eval.result = c(AUC, MCC, bal.acc) 130 | 131 | return(eval.result) 132 | } 133 | 134 | get.file.level.eval.result = function(prediction.dir, method.name) 135 | { 136 | all_files = list.files(prediction.dir) 137 | 138 | all.auc = c() 139 | all.mcc = c() 140 | all.bal.acc = c() 141 | all.test.rels = c() 142 | 143 | for(f in all_files) # for looping through files 144 | { 145 | df = read.csv(paste0(prediction.dir, f)) 146 | 147 | if(method.name == "DeepLineDP") 148 | { 149 | 
df = as_tibble(df) 150 | df = select(df, c(train, test, filename, file.level.ground.truth, prediction.prob, prediction.label)) 151 | 152 | df = distinct(df) 153 | } 154 | 155 | file.level.result = get.file.level.metrics(df) 156 | 157 | AUC = file.level.result[1] 158 | MCC = file.level.result[2] 159 | bal.acc = file.level.result[3] 160 | 161 | all.auc = append(all.auc,AUC) 162 | all.mcc = append(all.mcc,MCC) 163 | all.bal.acc = append(all.bal.acc,bal.acc) 164 | all.test.rels = append(all.test.rels,f) 165 | 166 | } 167 | 168 | result.df = data.frame(all.auc,all.mcc,all.bal.acc) 169 | 170 | 171 | all.test.rels = str_replace(all.test.rels, ".csv", "") 172 | 173 | result.df$release = all.test.rels 174 | result.df$technique = method.name 175 | 176 | return(result.df) 177 | } 178 | 179 | bi.lstm.prediction.dir = "../output/prediction/Bi-LSTM/" 180 | cnn.prediction.dir = "../output/prediction/CNN/" 181 | 182 | dbn.prediction.dir = "../output/prediction/DBN/" 183 | lr.prediction.dir = "../output/prediction/LR/" 184 | 185 | bi.lstm.result = get.file.level.eval.result(bi.lstm.prediction.dir, "Bi.LSTM") 186 | cnn.result = get.file.level.eval.result(cnn.prediction.dir, "CNN") 187 | dbn.result = get.file.level.eval.result(dbn.prediction.dir, "DBN") 188 | lr.result = get.file.level.eval.result(lr.prediction.dir, "LR") 189 | deepline.dp.result = get.file.level.eval.result(prediction_dir, "DeepLineDP") 190 | 191 | all.result = rbind(bi.lstm.result, cnn.result, dbn.result, lr.result, deepline.dp.result) 192 | 193 | names(all.result) = c("AUC","MCC","Balance.Accuracy","Release", "Technique") 194 | 195 | auc.result = select(all.result, c("Technique","AUC")) 196 | auc.result = preprocess(auc.result,FALSE) 197 | auc.result[auc.result$variable=="Bi.LSTM", "variable"] = "Bi-LSTM" 198 | 199 | mcc.result = select(all.result, c("Technique","MCC")) 200 | mcc.result = preprocess(mcc.result,FALSE) 201 | mcc.result[mcc.result$variable=="Bi.LSTM", "variable"] = "Bi-LSTM" 202 | 203 | bal.acc.result = select(all.result, c("Technique","Balance.Accuracy")) 204 | bal.acc.result = preprocess(bal.acc.result,FALSE) 205 | bal.acc.result[bal.acc.result$variable=="Bi.LSTM", "variable"] = "Bi-LSTM" 206 | 207 | ggplot(auc.result, aes(x=reorder(variable, -value, FUN=median), y=value)) + geom_boxplot() + facet_grid(~rank, drop=TRUE, scales = "free", space = "free") + ylab("AUC") + xlab("") 208 | ggsave(paste0(save.fig.dir,"file-AUC.pdf"),width=4,height=2.5) 209 | 210 | ggplot(bal.acc.result, aes(x=reorder(variable, value, FUN=median), y=value)) + geom_boxplot() + facet_grid(~rank, drop=TRUE, scales = "free", space = "free") + ylab("Balance Accuracy") + xlab("") 211 | ggsave(paste0(save.fig.dir,"file-Balance_Accuracy.pdf"),width=4,height=2.5) 212 | 213 | ggplot(mcc.result, aes(x=reorder(variable, value, FUN=median), y=value)) + geom_boxplot() + facet_grid(~rank, drop=TRUE, scales = "free", space = "free") + ylab("MCC") + xlab("") 214 | ggsave(paste0(save.fig.dir, "file-MCC.pdf"),width=4,height=2.5) 215 | 216 | 217 | # ---------------- Code for RQ3 -----------------------# 218 | 219 | ## prepare data for baseline 220 | line.ground.truth = select(df_all, project, train, test, filename, file.level.ground.truth, prediction.prob, line.number, line.level.ground.truth) 221 | line.ground.truth = filter(line.ground.truth, file.level.ground.truth == "True" & prediction.prob >= 0.5) 222 | line.ground.truth = distinct(line.ground.truth) 223 | 224 | get.line.metrics.result = function(baseline.df, cur.df.file) 225 | { 226 | 
baseline.df.with.ground.truth = merge(baseline.df, cur.df.file, by=c("filename", "line.number")) 227 | 228 | sorted = baseline.df.with.ground.truth %>% group_by(filename) %>% arrange(-line.score, .by_group = TRUE) %>% mutate(order = row_number()) 229 | 230 | #IFA 231 | IFA = sorted %>% filter(line.level.ground.truth == "True") %>% group_by(filename) %>% top_n(1, -order) 232 | 233 | ifa.list = IFA$order 234 | 235 | total_true = sorted %>% group_by(filename) %>% summarize(total_true = sum(line.level.ground.truth == "True")) 236 | 237 | #Recall20%LOC 238 | recall20LOC = sorted %>% group_by(filename) %>% mutate(effort = round(order/n(),digits = 2 )) %>% filter(effort <= 0.2) %>% 239 | summarize(correct_pred = sum(line.level.ground.truth == "True")) %>% 240 | merge(total_true) %>% mutate(recall20LOC = correct_pred/total_true) 241 | 242 | recall.list = recall20LOC$recall20LOC 243 | 244 | #Effort20%Recall 245 | effort20Recall = sorted %>% merge(total_true) %>% group_by(filename) %>% mutate(cummulative_correct_pred = cumsum(line.level.ground.truth == "True"), recall = round(cumsum(line.level.ground.truth == "True")/total_true, digits = 2)) %>% 246 | summarise(effort20Recall = sum(recall <= 0.2)/n()) 247 | 248 | effort.list = effort20Recall$effort20Recall 249 | 250 | result.df = data.frame(ifa.list, recall.list, effort.list) 251 | 252 | return(result.df) 253 | } 254 | 255 | all_eval_releases = c('activemq-5.2.0', 'activemq-5.3.0', 'activemq-5.8.0', 256 | 'camel-2.10.0', 'camel-2.11.0' , 257 | 'derby-10.5.1.1' , 'groovy-1_6_BETA_2' , 'hbase-0.95.2', 258 | 'hive-0.12.0', 'jruby-1.5.0', 'jruby-1.7.0.preview1', 259 | 'lucene-3.0.0', 'lucene-3.1', 'wicket-1.5.3') 260 | 261 | error.prone.result.dir = '../output/ErrorProne_result/' 262 | ngram.result.dir = '../output/n_gram_result/' 263 | 264 | rf.result.dir = '../output/RF-line-level-result/' 265 | 266 | n.gram.result.df = NULL 267 | error.prone.result.df = NULL 268 | rf.result.df = NULL 269 | 270 | ## get result from baseline 271 | for(rel in all_eval_releases) 272 | { 273 | error.prone.result = read.csv(paste0(error.prone.result.dir,rel,'-line-lvl-result.txt'),quote="") 274 | 275 | levels(error.prone.result$EP_prediction_result)[levels(error.prone.result$EP_prediction_result)=="False"] = 0 276 | levels(error.prone.result$EP_prediction_result)[levels(error.prone.result$EP_prediction_result)=="True"] = 1 277 | 278 | error.prone.result$EP_prediction_result = as.numeric(as.numeric_version(error.prone.result$EP_prediction_result)) 279 | 280 | names(error.prone.result) = c("filename","test","line.number","line.score") 281 | 282 | n.gram.result = read.csv(paste0(ngram.result.dir,rel,'-line-lvl-result.txt'), quote = "") 283 | n.gram.result = select(n.gram.result, "filename", "line.number", "line.score") 284 | n.gram.result = distinct(n.gram.result) 285 | names(n.gram.result) = c("filename", "line.number", "line.score") 286 | 287 | rf.result = read.csv(paste0(rf.result.dir,rel,'-line-lvl-result.csv')) 288 | rf.result = select(rf.result, "filename", "line_number","line.score.pred") 289 | names(rf.result) = c("filename", "line.number", "line.score") 290 | 291 | cur.df.file = filter(line.ground.truth, test==rel) 292 | cur.df.file = select(cur.df.file, filename, line.number, line.level.ground.truth) 293 | 294 | n.gram.eval.result = get.line.metrics.result(n.gram.result, cur.df.file) 295 | 296 | error.prone.eval.result = get.line.metrics.result(error.prone.result, cur.df.file) 297 | 298 | rf.eval.result = get.line.metrics.result(rf.result, cur.df.file) 299 | 300 | 
n.gram.result.df = rbind(n.gram.result.df, n.gram.eval.result) 301 | error.prone.result.df = rbind(error.prone.result.df, error.prone.eval.result) 302 | rf.result.df = rbind(rf.result.df, rf.eval.result) 303 | 304 | print(paste0('finished ', rel)) 305 | 306 | } 307 | 308 | #Force attention score of comment line is 0 309 | df_all[df_all$is.comment.line == "True",]$token.attention.score = 0 310 | 311 | tmp.top.k = get.top.k.tokens(df_all, 1500) 312 | 313 | merged_df_all = merge(df_all, tmp.top.k, by=c('project', 'train', 'test', 'filename', 'token'), all.x = TRUE) 314 | 315 | merged_df_all[is.na(merged_df_all$flag),]$token.attention.score = 0 316 | 317 | ## use top-k tokens 318 | sum_line_attn = merged_df_all %>% filter(file.level.ground.truth == "True" & prediction.label == "True") %>% group_by(test, filename,is.comment.line, file.level.ground.truth, prediction.label, line.number, line.level.ground.truth) %>% 319 | summarize(attention_score = sum(token.attention.score), num_tokens = n()) 320 | 321 | sorted = sum_line_attn %>% group_by(test, filename) %>% arrange(-attention_score, .by_group=TRUE) %>% mutate(order = row_number()) 322 | 323 | ## get result from DeepLineDP 324 | # calculate IFA 325 | IFA = sorted %>% filter(line.level.ground.truth == "True") %>% group_by(test, filename) %>% top_n(1, -order) 326 | 327 | total_true = sorted %>% group_by(test, filename) %>% summarize(total_true = sum(line.level.ground.truth == "True")) 328 | 329 | # calculate Recall20%LOC 330 | recall20LOC = sorted %>% group_by(test, filename) %>% mutate(effort = round(order/n(),digits = 2 )) %>% filter(effort <= 0.2) %>% 331 | summarize(correct_pred = sum(line.level.ground.truth == "True")) %>% 332 | merge(total_true) %>% mutate(recall20LOC = correct_pred/total_true) 333 | 334 | # calculate Effort20%Recall 335 | effort20Recall = sorted %>% merge(total_true) %>% group_by(test, filename) %>% mutate(cummulative_correct_pred = cumsum(line.level.ground.truth == "True"), recall = round(cumsum(line.level.ground.truth == "True")/total_true, digits = 2)) %>% 336 | summarise(effort20Recall = sum(recall <= 0.2)/n()) 337 | 338 | 339 | ## prepare data for plotting 340 | deeplinedp.ifa = IFA$order 341 | deeplinedp.recall = recall20LOC$recall20LOC 342 | deeplinedp.effort = effort20Recall$effort20Recall 343 | 344 | deepline.dp.line.result = data.frame(deeplinedp.ifa, deeplinedp.recall, deeplinedp.effort) 345 | 346 | names(rf.result.df) = c("IFA", "Recall20%LOC", "Effort@20%Recall") 347 | names(n.gram.result.df) = c("IFA", "Recall20%LOC", "Effort@20%Recall") 348 | names(error.prone.result.df) = c("IFA", "Recall20%LOC", "Effort@20%Recall") 349 | names(deepline.dp.line.result) = c("IFA", "Recall20%LOC", "Effort@20%Recall") 350 | 351 | rf.result.df$technique = 'RF' 352 | n.gram.result.df$technique = 'N.gram' 353 | error.prone.result.df$technique = 'ErrorProne' 354 | deepline.dp.line.result$technique = 'DeepLineDP' 355 | 356 | all.line.result = rbind(rf.result.df, n.gram.result.df, error.prone.result.df, deepline.dp.line.result) 357 | 358 | recall.result.df = select(all.line.result, c('technique', 'Recall20%LOC')) 359 | ifa.result.df = select(all.line.result, c('technique', 'IFA')) 360 | effort.result.df = select(all.line.result, c('technique', 'Effort@20%Recall')) 361 | 362 | recall.result.df = preprocess(recall.result.df, FALSE) 363 | ifa.result.df = preprocess(ifa.result.df, TRUE) 364 | effort.result.df = preprocess(effort.result.df, TRUE) 365 | 366 | ggplot(recall.result.df, aes(x=reorder(variable, -value, FUN=median), y=value)) 
+ geom_boxplot() + facet_grid(~rank, drop=TRUE, scales = "free", space = "free") + ylab("Recall@Top20%LOC") + xlab("") 367 | ggsave(paste0(save.fig.dir,"file-Recall@Top20LOC.pdf"),width=4,height=2.5) 368 | 369 | ggplot(effort.result.df, aes(x=reorder(variable, value, FUN=median), y=value)) + geom_boxplot() + facet_grid(~rank, drop=TRUE, scales = "free", space = "free") + ylab("Effort@Top20%Recall") + xlab("") 370 | ggsave(paste0(save.fig.dir,"file-Effort@Top20Recall.pdf"),width=4,height=2.5) 371 | 372 | ggplot(ifa.result.df, aes(x=reorder(variable, value, FUN=median), y=value)) + geom_boxplot() + coord_cartesian(ylim=c(0,175)) + facet_grid(~rank, drop=TRUE, scales = "free", space = "free") + ylab("IFA") + xlab("") 373 | ggsave(paste0(save.fig.dir, "file-IFA.pdf"),width=4,height=2.5) 374 | 375 | 376 | # ---------------- Code for RQ4 -----------------------# 377 | 378 | ## get within-project result 379 | deepline.dp.result$project = c("activemq", "activemq", "activemq", "camel", "camel", "derby", "groovy", "hbase", "hive", "jruby", "jruby", "lucene", "lucene", "wicket") 380 | 381 | file.level.by.project = deepline.dp.result %>% group_by(project) %>% summarise(mean.AUC = mean(all.auc), mean.MCC = mean(all.mcc), mean.bal.acc = mean(all.bal.acc)) 382 | 383 | names(file.level.by.project) = c("project", "AUC", "MCC", "Balanced Accuracy") 384 | 385 | IFA$project = str_replace(IFA$test, '-.*','') 386 | recall20LOC$project = str_replace(recall20LOC$test, '-.*','') 387 | recall20LOC$project = as.factor(recall20LOC$project) 388 | effort20Recall$project = str_replace(effort20Recall$test, '-.*','') 389 | 390 | ifa.each.project = IFA %>% group_by(project) %>% summarise(mean.by.project = mean(order)) 391 | recall.each.project = recall20LOC %>% group_by(project) %>% summarise(mean.by.project = mean(recall20LOC)) 392 | effort.each.project = effort20Recall %>% group_by(project) %>% summarise(mean.by.project = mean(effort20Recall)) 393 | 394 | line.level.all.mean.by.project = data.frame(ifa.each.project$project, ifa.each.project$mean.by.project, recall.each.project$mean.by.project, effort.each.project$mean.by.project) 395 | 396 | names(line.level.all.mean.by.project) = c("project", "IFA", "Recall20%LOC", "Effort@20%Recall") 397 | 398 | 399 | ## get cross-project result 400 | 401 | prediction.dir = '../output/prediction/DeepLineDP/cross-release/' 402 | 403 | projs = c('activemq', 'camel', 'derby', 'groovy', 'hbase', 'hive', 'jruby', 'lucene', 'wicket') 404 | 405 | 406 | get.line.level.metrics = function(df_all) 407 | { 408 | # Force attention score of comment lines to 0 409 | df_all[df_all$is.comment.line == "True",]$token.attention.score = 0 410 | 411 | sum_line_attn = df_all %>% filter(file.level.ground.truth == "True" & prediction.label == "True") %>% group_by(filename,is.comment.line, file.level.ground.truth, prediction.label, line.number, line.level.ground.truth) %>% 412 | summarize(attention_score = sum(token.attention.score), num_tokens = n()) 413 | sorted = sum_line_attn %>% group_by(filename) %>% arrange(-attention_score, .by_group=TRUE) %>% mutate(order = row_number()) 414 | 415 | # calculate IFA 416 | IFA = sorted %>% filter(line.level.ground.truth == "True") %>% group_by(filename) %>% top_n(1, -order) 417 | total_true = sorted %>% group_by(filename) %>% summarize(total_true = sum(line.level.ground.truth == "True")) 418 | 419 | # calculate Recall20%LOC 420 | recall20LOC = sorted %>% group_by(filename) %>% mutate(effort = round(order/n(),digits = 2 )) %>% filter(effort <= 0.2) %>% 421 |
summarize(correct_pred = sum(line.level.ground.truth == "True")) %>% 422 | merge(total_true) %>% mutate(recall20LOC = correct_pred/total_true) 423 | 424 | # calculate Effort20%Recall 425 | effort20Recall = sorted %>% merge(total_true) %>% group_by(filename) %>% mutate(cummulative_correct_pred = cumsum(line.level.ground.truth == "True"), recall = round(cumsum(line.level.ground.truth == "True")/total_true, digits = 2)) %>% 426 | summarise(effort20Recall = sum(recall <= 0.2)/n()) 427 | 428 | all.ifa = IFA$order 429 | all.recall = recall20LOC$recall20LOC 430 | all.effort = effort20Recall$effort20Recall 431 | 432 | result.df = data.frame(all.ifa, all.recall, all.effort) 433 | 434 | return(result.df) 435 | } 436 | 437 | 438 | all.line.result = NULL 439 | all.file.result = NULL 440 | 441 | 442 | for(p in projs) 443 | { 444 | actual.pred.dir = paste0(prediction.dir,p,'/') 445 | 446 | all.files = list.files(actual.pred.dir) 447 | 448 | all.auc = c() 449 | all.mcc = c() 450 | all.bal.acc = c() 451 | all.src.projs = c() 452 | all.tar.projs = c() 453 | 454 | for(f in all.files) 455 | { 456 | df = read.csv(paste0(actual.pred.dir,f)) 457 | 458 | f = str_replace(f,'.csv','') 459 | f.split = unlist(strsplit(f,'-')) 460 | target = tail(f.split,2)[1] 461 | 462 | df = as_tibble(df) 463 | 464 | df.file = select(df, c(train, test, filename, file.level.ground.truth, prediction.prob, prediction.label)) 465 | 466 | df.file = distinct(df.file) 467 | 468 | file.level.result = get.file.level.metrics(df.file) 469 | 470 | AUC = file.level.result[1] 471 | MCC = file.level.result[2] 472 | bal.acc = file.level.result[3] 473 | 474 | all.auc = append(all.auc, AUC) 475 | all.mcc = append(all.mcc, MCC) 476 | all.bal.acc = append(all.bal.acc, bal.acc) 477 | 478 | all.src.projs = append(all.src.projs, p) 479 | all.tar.projs = append(all.tar.projs,target) 480 | 481 | tmp.top.k = get.top.k.tokens(df, 1500) 482 | 483 | merged_df_all = merge(df, tmp.top.k, by=c('project', 'train', 'test', 'filename', 'token'), all.x = TRUE) 484 | 485 | merged_df_all[is.na(merged_df_all$flag),]$token.attention.score = 0 486 | 487 | line.level.result = get.line.level.metrics(merged_df_all) 488 | line.level.result$src = p 489 | line.level.result$target = target 490 | 491 | all.line.result = rbind(all.line.result, line.level.result) 492 | 493 | print(paste0('finished ',f)) 494 | 495 | } 496 | 497 | file.level.result = data.frame(all.auc,all.mcc,all.bal.acc) 498 | file.level.result$src = p 499 | file.level.result$target = all.tar.projs 500 | 501 | all.file.result = rbind(all.file.result, file.level.result) 502 | 503 | print(paste0('finished ',p)) 504 | 505 | } 506 | 507 | final.file.level.result = all.file.result %>% group_by(target) %>% summarize(auc = mean(all.auc), balance_acc = mean(all.bal.acc), mcc = mean(all.mcc)) 508 | 509 | final.line.level.result = all.line.result %>% group_by(target) %>% summarize(recall = mean(all.recall), effort = mean(all.effort), ifa = mean(all.ifa)) 510 | 511 | -------------------------------------------------------------------------------- /script/line-level-baseline/ErrorProne/dataflow-shaded-3.1.2.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awsm-research/DeepLineDP/517c7e230409220ef0d6f07f4cb19c569a5d05e5/script/line-level-baseline/ErrorProne/dataflow-shaded-3.1.2.jar -------------------------------------------------------------------------------- /script/line-level-baseline/ErrorProne/error_prone_core-2.4.0-with-dependencies.jar: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/awsm-research/DeepLineDP/517c7e230409220ef0d6f07f4cb19c569a5d05e5/script/line-level-baseline/ErrorProne/error_prone_core-2.4.0-with-dependencies.jar -------------------------------------------------------------------------------- /script/line-level-baseline/ErrorProne/jFormatString-3.0.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awsm-research/DeepLineDP/517c7e230409220ef0d6f07f4cb19c569a5d05e5/script/line-level-baseline/ErrorProne/jFormatString-3.0.0.jar -------------------------------------------------------------------------------- /script/line-level-baseline/ErrorProne/javac-9+181-r4173-1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awsm-research/DeepLineDP/517c7e230409220ef0d6f07f4cb19c569a5d05e5/script/line-level-baseline/ErrorProne/javac-9+181-r4173-1.jar -------------------------------------------------------------------------------- /script/line-level-baseline/ErrorProne/run_ErrorProne.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "source": [ 7 | "import pandas as pd\n", 8 | "import numpy as np\n", 9 | "import subprocess, re, os, time\n", 10 | "\n", 11 | "from multiprocessing import Pool\n", 12 | "\n", 13 | "from tqdm import tqdm" 14 | ], 15 | "outputs": [], 16 | "metadata": {} 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 7, 21 | "source": [ 22 | "all_eval_releases = ['activemq-5.2.0','activemq-5.3.0','activemq-5.8.0',\n", 23 | " 'camel-2.10.0','camel-2.11.0', \n", 24 | " 'derby-10.5.1.1',\n", 25 | " 'groovy-1_6_BETA_2', \n", 26 | " 'hbase-0.95.2',\n", 27 | " 'hive-0.12.0', \n", 28 | " 'jruby-1.5.0','jruby-1.7.0.preview1',\n", 29 | " 'lucene-3.0.0','lucene-3.1', \n", 30 | " 'wicket-1.5.3']\n", 31 | "\n", 32 | "all_dataset_name = ['activemq','camel','derby','groovy','hbase','hive','jruby','lucene','wicket']\n", 33 | "\n", 34 | "base_file_dir = './ErrorProne_data/'\n", 35 | "base_command = \"javac -J-Xbootclasspath/p:javac-9+181-r4173-1.jar -XDcompilePolicy=simple -processorpath error_prone_core-2.4.0-with-dependencies.jar:dataflow-shaded-3.1.2.jar:jFormatString-3.0.0.jar '-Xplugin:ErrorProne -XepDisableAllChecks -Xep:CollectionIncompatibleType:ERROR' \"\n", 36 | "\n", 37 | "result_dir = './ErrorProne_result/'\n", 38 | "\n", 39 | "if not os.path.exists(result_dir):\n", 40 | " os.makedirs(result_dir)\n", 41 | " " 42 | ], 43 | "outputs": [], 44 | "metadata": {} 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 10, 49 | "source": [ 50 | "def run_ErrorProne(rel):\n", 51 | " df_list = []\n", 52 | " java_file_dir = base_file_dir+rel+'/'\n", 53 | "\n", 54 | " file_list = os.listdir(java_file_dir)\n", 55 | " \n", 56 | " for java_filename in tqdm(file_list): \n", 57 | " f = open(java_file_dir+java_filename,'r',encoding='utf-8',errors='ignore')\n", 58 | " java_code = f.readlines()\n", 59 | "\n", 60 | " code_len = len(java_code)\n", 61 | "\n", 62 | " output = subprocess.getoutput(base_command+java_file_dir+java_filename)\n", 63 | "\n", 64 | " reported_lines = re.findall('\\d+: error:',output)\n", 65 | " reported_lines = [int(l.replace(':','').replace('error','')) for l in reported_lines]\n", 66 | " reported_lines = list(set(reported_lines))\n", 67 | "\n", 68 | " line_df = 
pd.DataFrame()\n", 69 | "\n", 70 | " line_df['filename'] = [java_filename.replace('_','/')]*code_len\n", 71 | " line_df['test-release'] = [rel]*len(line_df)\n", 72 | " line_df['line_number'] = np.arange(1,code_len+1)\n", 73 | " line_df['EP_prediction_result'] = line_df['line_number'].isin(reported_lines)\n", 74 | "\n", 75 | " df_list.append(line_df)\n", 76 | "\n", 77 | " final_df = pd.concat(df_list)\n", 78 | " final_df.to_csv(result_dir+rel+'-line-lvl-result.txt',index=False)\n", 79 | " print('finished',rel)" 80 | ], 81 | "outputs": [], 82 | "metadata": {} 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 11, 87 | "source": [ 88 | "agents = 5\n", 89 | "chunksize = 8\n", 90 | "\n", 91 | "with Pool(processes=agents) as pool:\n", 92 | " pool.map(run_ErrorProne, all_eval_releases, chunksize)" 93 | ], 94 | "outputs": [ 95 | { 96 | "output_type": "stream", 97 | "name": "stderr", 98 | "text": [ 99 | "100%|██████████| 25/25 [00:59<00:00, 2.36s/it]]\n", 100 | " 35%|███▍ | 27/78 [00:59<01:52, 2.21s/it]" 101 | ] 102 | }, 103 | { 104 | "output_type": "stream", 105 | "name": "stdout", 106 | "text": [ 107 | "finished jruby-1.5.0\n" 108 | ] 109 | }, 110 | { 111 | "output_type": "stream", 112 | "name": "stderr", 113 | "text": [ 114 | "100%|██████████| 38/38 [01:29<00:00, 2.34s/it]]\n", 115 | " 86%|████████▌ | 67/78 [02:28<00:24, 2.24s/it]" 116 | ] 117 | }, 118 | { 119 | "output_type": "stream", 120 | "name": "stdout", 121 | "text": [ 122 | "finished jruby-1.7.0.preview1\n" 123 | ] 124 | }, 125 | { 126 | "output_type": "stream", 127 | "name": "stderr", 128 | "text": [ 129 | "100%|██████████| 78/78 [02:52<00:00, 2.21s/it]]\n", 130 | " 58%|█████▊ | 76/130 [02:52<02:17, 2.54s/it]" 131 | ] 132 | }, 133 | { 134 | "output_type": "stream", 135 | "name": "stdout", 136 | "text": [ 137 | "finished activemq-5.1.0\n" 138 | ] 139 | }, 140 | { 141 | "output_type": "stream", 142 | "name": "stderr", 143 | "text": [ 144 | "100%|██████████| 41/41 [01:30<00:00, 2.21s/it]t]\n" 145 | ] 146 | }, 147 | { 148 | "output_type": "stream", 149 | "name": "stdout", 150 | "text": [ 151 | "finished activemq-5.2.0\n" 152 | ] 153 | }, 154 | { 155 | "output_type": "stream", 156 | "name": "stderr", 157 | "text": [ 158 | "100%|██████████| 62/62 [02:14<00:00, 2.17s/it]t]\n" 159 | ] 160 | }, 161 | { 162 | "output_type": "stream", 163 | "name": "stdout", 164 | "text": [ 165 | "finished lucene-2.9.0\n" 166 | ] 167 | }, 168 | { 169 | "output_type": "stream", 170 | "name": "stderr", 171 | "text": [ 172 | "100%|██████████| 130/130 [04:58<00:00, 2.29s/it]\n" 173 | ] 174 | }, 175 | { 176 | "output_type": "stream", 177 | "name": "stdout", 178 | "text": [ 179 | "finished derby-10.5.1.1\n" 180 | ] 181 | }, 182 | { 183 | "output_type": "stream", 184 | "name": "stderr", 185 | "text": [ 186 | "100%|██████████| 30/30 [01:11<00:00, 2.37s/it]]\n", 187 | " 46%|████▌ | 47/103 [01:47<02:09, 2.32s/it]" 188 | ] 189 | }, 190 | { 191 | "output_type": "stream", 192 | "name": "stdout", 193 | "text": [ 194 | "finished groovy-1_6_BETA_1\n" 195 | ] 196 | }, 197 | { 198 | "output_type": "stream", 199 | "name": "stderr", 200 | "text": [ 201 | "100%|██████████| 55/55 [02:00<00:00, 2.19s/it]]\n", 202 | " 45%|████▌ | 14/31 [00:33<00:38, 2.28s/it]" 203 | ] 204 | }, 205 | { 206 | "output_type": "stream", 207 | "name": "stdout", 208 | "text": [ 209 | "finished lucene-3.0.0\n" 210 | ] 211 | }, 212 | { 213 | "output_type": "stream", 214 | "name": "stderr", 215 | "text": [ 216 | "100%|██████████| 31/31 [01:12<00:00, 2.33s/it]]\n" 217 | ] 218 | }, 219 | { 220 | 
"output_type": "stream", 221 | "name": "stdout", 222 | "text": [ 223 | "finished groovy-1_6_BETA_2\n" 224 | ] 225 | }, 226 | { 227 | "output_type": "stream", 228 | "name": "stderr", 229 | "text": [ 230 | "100%|██████████| 38/38 [01:23<00:00, 2.19s/it]]\n" 231 | ] 232 | }, 233 | { 234 | "output_type": "stream", 235 | "name": "stdout", 236 | "text": [ 237 | "finished lucene-3.1\n" 238 | ] 239 | }, 240 | { 241 | "output_type": "stream", 242 | "name": "stderr", 243 | "text": [ 244 | "100%|██████████| 103/103 [03:52<00:00, 2.25s/it]\n", 245 | " 19%|█▉ | 22/115 [00:52<03:25, 2.21s/it]" 246 | ] 247 | }, 248 | { 249 | "output_type": "stream", 250 | "name": "stdout", 251 | "text": [ 252 | "finished activemq-5.3.0\n" 253 | ] 254 | }, 255 | { 256 | "output_type": "stream", 257 | "name": "stderr", 258 | "text": [ 259 | "100%|██████████| 71/71 [02:34<00:00, 2.17s/it]]\n", 260 | " 73%|███████▎ | 84/115 [03:18<01:21, 2.62s/it]" 261 | ] 262 | }, 263 | { 264 | "output_type": "stream", 265 | "name": "stdout", 266 | "text": [ 267 | "finished wicket-1.3.0-beta2\n" 268 | ] 269 | }, 270 | { 271 | "output_type": "stream", 272 | "name": "stderr", 273 | "text": [ 274 | "100%|██████████| 115/115 [04:27<00:00, 2.33s/it]\n" 275 | ] 276 | }, 277 | { 278 | "output_type": "stream", 279 | "name": "stdout", 280 | "text": [ 281 | "finished hbase-0.95.0\n" 282 | ] 283 | }, 284 | { 285 | "output_type": "stream", 286 | "name": "stderr", 287 | "text": [ 288 | "100%|██████████| 102/102 [03:49<00:00, 2.25s/it]\n" 289 | ] 290 | }, 291 | { 292 | "output_type": "stream", 293 | "name": "stdout", 294 | "text": [ 295 | "finished activemq-5.8.0\n" 296 | ] 297 | }, 298 | { 299 | "output_type": "stream", 300 | "name": "stderr", 301 | "text": [ 302 | "100%|██████████| 50/50 [01:51<00:00, 2.22s/it]]\n" 303 | ] 304 | }, 305 | { 306 | "output_type": "stream", 307 | "name": "stdout", 308 | "text": [ 309 | "finished wicket-1.5.3\n" 310 | ] 311 | }, 312 | { 313 | "output_type": "stream", 314 | "name": "stderr", 315 | "text": [ 316 | "100%|██████████| 89/89 [02:26<00:00, 1.65s/it]]\n" 317 | ] 318 | }, 319 | { 320 | "output_type": "stream", 321 | "name": "stdout", 322 | "text": [ 323 | "finished camel-2.9.0\n" 324 | ] 325 | }, 326 | { 327 | "output_type": "stream", 328 | "name": "stderr", 329 | "text": [ 330 | "100%|██████████| 110/110 [03:12<00:00, 1.75s/it]\n" 331 | ] 332 | }, 333 | { 334 | "output_type": "stream", 335 | "name": "stdout", 336 | "text": [ 337 | "finished hbase-0.95.2\n" 338 | ] 339 | }, 340 | { 341 | "output_type": "stream", 342 | "name": "stderr", 343 | "text": [ 344 | "100%|██████████| 73/73 [01:59<00:00, 1.64s/it]]\n" 345 | ] 346 | }, 347 | { 348 | "output_type": "stream", 349 | "name": "stdout", 350 | "text": [ 351 | "finished hive-0.10.0\n" 352 | ] 353 | }, 354 | { 355 | "output_type": "stream", 356 | "name": "stderr", 357 | "text": [ 358 | "100%|██████████| 111/111 [02:50<00:00, 1.54s/it]\n" 359 | ] 360 | }, 361 | { 362 | "output_type": "stream", 363 | "name": "stdout", 364 | "text": [ 365 | "finished camel-2.10.0\n" 366 | ] 367 | }, 368 | { 369 | "output_type": "stream", 370 | "name": "stderr", 371 | "text": [ 372 | "100%|██████████| 119/119 [02:56<00:00, 1.49s/it]\n" 373 | ] 374 | }, 375 | { 376 | "output_type": "stream", 377 | "name": "stdout", 378 | "text": [ 379 | "finished camel-2.11.0\n" 380 | ] 381 | }, 382 | { 383 | "output_type": "stream", 384 | "name": "stderr", 385 | "text": [ 386 | "100%|██████████| 171/171 [04:37<00:00, 1.62s/it]\n", 387 | " 19%|█▊ | 53/285 [01:23<05:52, 1.52s/it]" 388 | ] 389 | }, 390 | { 
391 | "output_type": "stream", 392 | "name": "stdout", 393 | "text": [ 394 | "finished hive-0.12.0\n" 395 | ] 396 | }, 397 | { 398 | "output_type": "stream", 399 | "name": "stderr", 400 | "text": [ 401 | "100%|██████████| 127/127 [03:22<00:00, 1.59s/it]\n", 402 | " 65%|██████▌ | 186/285 [04:46<02:27, 1.49s/it]" 403 | ] 404 | }, 405 | { 406 | "output_type": "stream", 407 | "name": "stdout", 408 | "text": [ 409 | "finished jruby-1.4.0\n" 410 | ] 411 | }, 412 | { 413 | "output_type": "stream", 414 | "name": "stderr", 415 | "text": [ 416 | "100%|██████████| 285/285 [06:29<00:00, 1.37s/it]\n" 417 | ] 418 | }, 419 | { 420 | "output_type": "stream", 421 | "name": "stdout", 422 | "text": [ 423 | "finished derby-10.3.1.4\n" 424 | ] 425 | } 426 | ], 427 | "metadata": { 428 | "scrolled": false 429 | } 430 | } 431 | ], 432 | "metadata": { 433 | "kernelspec": { 434 | "display_name": "Python 3", 435 | "language": "python", 436 | "name": "python3" 437 | }, 438 | "language_info": { 439 | "codemirror_mode": { 440 | "name": "ipython", 441 | "version": 3 442 | }, 443 | "file_extension": ".py", 444 | "mimetype": "text/x-python", 445 | "name": "python", 446 | "nbconvert_exporter": "python", 447 | "pygments_lexer": "ipython3", 448 | "version": "3.8.10" 449 | } 450 | }, 451 | "nbformat": 4, 452 | "nbformat_minor": 4 453 | } -------------------------------------------------------------------------------- /script/line-level-baseline/RF-line-level.py: -------------------------------------------------------------------------------- 1 | from sklearn.ensemble import RandomForestClassifier 2 | 3 | import os, sys, pickle 4 | 5 | import numpy as np 6 | 7 | from gensim.models import Word2Vec 8 | from torch.utils import data 9 | 10 | from tqdm import tqdm 11 | 12 | sys.path.append('../') 13 | from DeepLineDP_model import * 14 | from my_util import * 15 | 16 | model_dir = '../../output/model/RF-line-level/' 17 | result_dir = '../../output/RF-line-level-result/' 18 | 19 | if not os.path.exists(model_dir): 20 | os.makedirs(model_dir) 21 | 22 | if not os.path.exists(result_dir): 23 | os.makedirs(result_dir) 24 | 25 | max_grad_norm = 5 26 | embed_dim = 50 27 | word_gru_hidden_dim = 64 28 | sent_gru_hidden_dim = 64 29 | word_gru_num_layers = 1 30 | sent_gru_num_layers = 1 31 | word_att_dim = 64 32 | sent_att_dim = 64 33 | use_layer_norm = True 34 | 35 | to_lowercase = True 36 | 37 | def get_DeepLineDP_and_W2V(dataset_name): 38 | w2v_dir = get_w2v_path() 39 | 40 | word2vec_file_dir = os.path.join(w2v_dir,dataset_name+'-'+str(embed_dim)+'dim.bin') 41 | 42 | word2vec = Word2Vec.load('../'+word2vec_file_dir) 43 | print('load Word2Vec for',dataset_name,'finished') 44 | 45 | total_vocab = len(word2vec.wv.vocab) 46 | 47 | vocab_size = total_vocab +1 # for unknown tokens 48 | 49 | model = HierarchicalAttentionNetwork( 50 | vocab_size=vocab_size, 51 | embed_dim=embed_dim, 52 | word_gru_hidden_dim=word_gru_hidden_dim, 53 | sent_gru_hidden_dim=sent_gru_hidden_dim, 54 | word_gru_num_layers=word_gru_num_layers, 55 | sent_gru_num_layers=sent_gru_num_layers, 56 | word_att_dim=word_att_dim, 57 | sent_att_dim=sent_att_dim, 58 | use_layer_norm=use_layer_norm, 59 | dropout=0.001) 60 | 61 | checkpoint = torch.load('../../output/model/DeepLineDP/'+dataset_name+'/checkpoint_7epochs.pth') 62 | 63 | 64 | model.load_state_dict(checkpoint['model_state_dict']) 65 | 66 | model.sent_attention.word_attention.freeze_embeddings(True) 67 | 68 | model = model.cuda() 69 | model.eval() 70 | 71 | return model, word2vec 72 | 73 | def train_RF_model(dataset_name): 74 
| model, word2vec = get_DeepLineDP_and_W2V(dataset_name) 75 | 76 | clf = RandomForestClassifier(random_state=0, n_jobs=24) 77 | 78 | train_rel = all_train_releases[dataset_name] 79 | 80 | train_df = get_df(train_rel, is_baseline=True) 81 | 82 | line_rep_list = [] 83 | all_line_label = [] 84 | 85 | # loop to get line representation of each file in train data 86 | for filename, df in tqdm(train_df.groupby('filename')): 87 | 88 | code = df['code_line'].tolist() 89 | line_label = df['line-label'].tolist() 90 | 91 | all_line_label.extend(line_label) 92 | 93 | code2d = prepare_code2d(code, to_lowercase) 94 | 95 | code3d = [code2d] 96 | 97 | codevec = get_x_vec(code3d, word2vec) 98 | 99 | with torch.no_grad(): 100 | codevec_padded_tensor = torch.tensor(codevec) 101 | _, __, ___, line_rep = model(codevec_padded_tensor) 102 | 103 | numpy_line_rep = line_rep.cpu().detach().numpy() 104 | 105 | line_rep_list.append(numpy_line_rep) 106 | 107 | x = np.concatenate(line_rep_list) 108 | 109 | print('prepare data finished') 110 | 111 | clf.fit(x,all_line_label) 112 | 113 | pickle.dump(clf, open(model_dir+dataset_name+'-RF-model.bin','wb')) 114 | 115 | print('finished training model of',dataset_name) 116 | 117 | 118 | def predict_defective_line(dataset_name): 119 | model, word2vec = get_DeepLineDP_and_W2V(dataset_name) 120 | clf = pickle.load(open(model_dir+dataset_name+'-RF-model.bin','rb')) 121 | 122 | print('load model finished') 123 | 124 | test_rels = all_eval_releases[dataset_name][1:] 125 | 126 | for rel in test_rels: 127 | test_df = get_df(rel, is_baseline=True) 128 | 129 | test_df = test_df[test_df['file-label']==True] 130 | test_df = test_df.drop(['is_comment','is_test_file','is_blank'],axis=1) 131 | 132 | all_df_list = [] # store df for saving later... 133 | 134 | for filename, df in tqdm(test_df.groupby('filename')): 135 | 136 | code = df['code_line'].tolist() 137 | 138 | code2d = prepare_code2d(code, to_lowercase) 139 | 140 | code3d = [code2d] 141 | 142 | codevec = get_x_vec(code3d, word2vec) 143 | 144 | with torch.no_grad(): 145 | codevec_padded_tensor = torch.tensor(codevec) 146 | _, __, ___, line_rep = model(codevec_padded_tensor) 147 | 148 | numpy_line_rep = line_rep.cpu().detach().numpy() 149 | 150 | pred = clf.predict(numpy_line_rep) 151 | 152 | df['line-score-pred'] = pred.astype(int) 153 | 154 | all_df_list.append(df) 155 | 156 | all_df = pd.concat(all_df_list) 157 | 158 | all_df.to_csv(result_dir+rel+'-line-lvl-result.csv',index=False) 159 | 160 | print('finished',rel) 161 | 162 | proj_name = sys.argv[1] 163 | 164 | train_RF_model(proj_name) 165 | predict_defective_line(proj_name) 166 | -------------------------------------------------------------------------------- /script/line-level-baseline/ngram/commons-io-2.8.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awsm-research/DeepLineDP/517c7e230409220ef0d6f07f4cb19c569a5d05e5/script/line-level-baseline/ngram/commons-io-2.8.0.jar -------------------------------------------------------------------------------- /script/line-level-baseline/ngram/n_gram.java: -------------------------------------------------------------------------------- 1 | import java.io.IOException; 2 | import java.util.List; 3 | import java.util.Map; 4 | 5 | import org.apache.commons.io.FileUtils; 6 | 7 | import java.io.File; 8 | import java.util.DoubleSummaryStatistics; 9 | import java.util.HashMap; 10 | import java.util.stream.Collectors; 11 | 12 | import slp.core.counting.giga.GigaCounter; 13 | 
import slp.core.lexing.Lexer; 14 | import slp.core.lexing.code.JavaLexer; 15 | import slp.core.lexing.runners.LexerRunner; 16 | import slp.core.lexing.simple.WhitespaceLexer; 17 | import slp.core.modeling.Model; 18 | import slp.core.modeling.dynamic.CacheModel; 19 | import slp.core.modeling.mix.MixModel; 20 | import slp.core.modeling.ngram.JMModel; 21 | import slp.core.modeling.runners.ModelRunner; 22 | import slp.core.translating.Vocabulary; 23 | 24 | public class n_gram 25 | { 26 | public static String root_dir = "./n_gram_data/"; 27 | public static String result_dir = "./n_gram_result/"; 28 | 29 | public static String all_dataset[] = {"activemq","camel","derby","groovy","hbase","hive", "jruby","lucene","wicket"}; 30 | public static String all_train_releases[] = {"activemq-5.0.0","camel-1.4.0","derby-10.2.1.6","groovy-1_5_7","hbase-0.94.0", "hive-0.9.0","jruby-1.1","lucene-2.3.0","wicket-1.3.0-incubating-beta-1"}; 31 | public static String all_eval_releases[][] = {{"activemq-5.2.0","activemq-5.3.0","activemq-5.8.0"}, 32 | {"camel-2.10.0","camel-2.11.0"}, 33 | {"derby-10.5.1.1"}, 34 | {"groovy-1_6_BETA_2"}, 35 | {"hbase-0.95.2"}, 36 | {"hive-0.12.0"}, 37 | {"jruby-1.5.0","jruby-1.7.0.preview1"}, 38 | {"lucene-3.0.0","lucene-3.1"}, 39 | {"wicket-1.5.3"}}; 40 | 41 | public static String all_releases[][] = {{"activemq-5.0.0","activemq-5.1.0","activemq-5.2.0","activemq-5.3.0","activemq-5.8.0"}, 42 | {"camel-1.4.0","camel-2.9.0","camel-2.10.0","camel-2.11.0"}, 43 | {"derby-10.2.1.6","derby-10.3.1.4","derby-10.5.1.1"}, 44 | {"groovy-1_5_7","groovy-1_6_BETA_1","groovy-1_6_BETA_2"}, 45 | {"hbase-0.94.0","hbase-0.95.0","hbase-0.95.2"}, 46 | {"hive-0.9.0", "hive-0.10.0","hive-0.12.0"}, 47 | {"jruby-1.1", "jruby-1.4.0","jruby-1.5.0","jruby-1.7.0.preview1"}, 48 | {"lucene-2.3.0","lucene-2.9.0","lucene-3.0.0","lucene-3.1"}, 49 | {"wicket-1.3.0-incubating-beta-1", "wicket-1.3.0-beta2","wicket-1.5.3"}}; 50 | 51 | public static ModelRunner train_model(String train_release) 52 | { 53 | Map to_return = new HashMap(); 54 | File train = new File(root_dir+train_release+"/src"); 55 | Lexer lexer = new WhitespaceLexer(); // Use a Java lexer; if your code is already lexed, use whitespace or tokenized lexer 56 | LexerRunner lexerRunner = new LexerRunner(lexer, false); 57 | 58 | lexerRunner.setSentenceMarkers(true); // Add start and end markers to the files 59 | 60 | Vocabulary vocabulary = new Vocabulary(); // Create an empty vocabulary 61 | 62 | Model model = new JMModel(6, new GigaCounter()); // Standard smoothing for code, giga-counter for large corpora 63 | model = MixModel.standard(model, new CacheModel()); // Use a simple cache model; see JavaRunner for more options 64 | ModelRunner modelRunner = new ModelRunner(model, lexerRunner, vocabulary); // Use above lexer and vocabulary 65 | modelRunner.learnDirectory(train); // Teach the model all the data in "train" 66 | 67 | return modelRunner; 68 | } 69 | 70 | public static void predict_defective_lines(String train_release, String test_release, ModelRunner modelRunner) throws Exception 71 | { 72 | LexerRunner lexerRunner = modelRunner.getLexerRunner(); 73 | 74 | StringBuilder sb = new StringBuilder(); 75 | 76 | sb.append("train-release\ttest-release\tfile-name\tline-number\ttoken\ttoken-score\tline-score\n"); 77 | 78 | File test_java_dir = new File(root_dir + test_release+"/src/"); 79 | File java_files[] = test_java_dir.listFiles(); 80 | 81 | String line_num_path = root_dir + test_release+"/line_num/"; 82 | 83 | // loop each file here... 
84 | 85 | for(int j = 0; j linenum = FileUtils.readLines(new File(line_num_path+linenum_filename),"UTF-8"); 94 | 95 | List> fileEntropies = modelRunner.modelFile(test); 96 | List> fileTokens = lexerRunner.lexFile(test) // Let's also retrieve the tokens on each line 97 | .map(l -> l.collect(Collectors.toList())) 98 | .collect(Collectors.toList()); 99 | 100 | for (int i = 0; i < linenum.size(); i++) { 101 | List lineTokens = fileTokens.get(i); 102 | List lineEntropies = fileEntropies.get(i); 103 | 104 | String cur_line_num = linenum.get(i); 105 | 106 | // First use Java's stream API to summarize entropies on this line 107 | // (see modelRunner.getStats for summarizing file or directory results) 108 | DoubleSummaryStatistics lineStatistics = lineEntropies.stream() 109 | .mapToDouble(Double::doubleValue) 110 | .summaryStatistics(); 111 | double averageEntropy = lineStatistics.getAverage(); 112 | 113 | for(int k = 0; k< lineTokens.size(); k++) 114 | { 115 | String tok = lineTokens.get(k); 116 | double tok_score = lineEntropies.get(k); 117 | 118 | if(tok == "") 119 | continue; 120 | 121 | sb.append(train_release+"\t"+test_release+"\t"+filename_original+"\t"+cur_line_num+"\t"+tok+"\t"+tok_score+"\t"+averageEntropy+"\n"); 122 | } 123 | 124 | } 125 | } 126 | FileUtils.write(new File(result_dir+test_release+"-line-lvl-result.txt"), sb.toString(),"UTF-8"); 127 | } 128 | 129 | public static void train_eval_model(int dataset_idx) throws Exception 130 | { 131 | String dataset_name = all_dataset[dataset_idx]; 132 | String train_release = all_train_releases[dataset_idx]; 133 | String eval_release[] = all_eval_releases[dataset_idx]; 134 | 135 | ModelRunner modelRunner = train_model(train_release); 136 | 137 | System.out.println("finish training model for " + dataset_name); 138 | 139 | for(int idx = 0; idx' 61 | ''' 62 | code2d = [] 63 | 64 | for c in code_list: 65 | c = re.sub('\\s+',' ',c) 66 | 67 | if to_lowercase: 68 | c = c.lower() 69 | 70 | token_list = c.strip().split() 71 | total_tokens = len(token_list) 72 | 73 | token_list = token_list[:max_seq_len] 74 | 75 | if total_tokens < max_seq_len: 76 | token_list = token_list + ['']*(max_seq_len-total_tokens) 77 | 78 | code2d.append(token_list) 79 | 80 | return code2d 81 | 82 | def get_code3d_and_label(df, to_lowercase = False): 83 | ''' 84 | input 85 | df (DataFrame): a dataframe from get_df() 86 | output 87 | code3d (nested list): a list of code2d from prepare_code2d() 88 | all_file_label (list): a list of file-level label 89 | ''' 90 | 91 | code3d = [] 92 | all_file_label = [] 93 | 94 | for filename, group_df in df.groupby('filename'): 95 | 96 | file_label = bool(group_df['file-label'].unique()) 97 | 98 | code = list(group_df['code_line']) 99 | 100 | code2d = prepare_code2d(code, to_lowercase) 101 | code3d.append(code2d) 102 | 103 | all_file_label.append(file_label) 104 | 105 | return code3d, all_file_label 106 | 107 | def get_w2v_path(): 108 | 109 | return word2vec_dir 110 | 111 | def get_w2v_weight_for_deep_learning_models(word2vec_model, embed_dim): 112 | word2vec_weights = torch.FloatTensor(word2vec_model.wv.syn0).cuda() 113 | 114 | # add zero vector for unknown tokens 115 | word2vec_weights = torch.cat((word2vec_weights, torch.zeros(1,embed_dim).cuda())) 116 | 117 | return word2vec_weights 118 | 119 | def pad_code(code_list_3d,max_sent_len,limit_sent_len=True, mode='train'): 120 | paded = [] 121 | 122 | for file in code_list_3d: 123 | sent_list = [] 124 | for line in file: 125 | new_line = line 126 | if len(line) > max_seq_len: 127 | new_line = 
line[:max_seq_len] 128 | sent_list.append(new_line) 129 | 130 | 131 | if mode == 'train': 132 | if max_sent_len-len(file) > 0: 133 | for i in range(0,max_sent_len-len(file)): 134 | sent_list.append([0]*max_seq_len) 135 | 136 | if limit_sent_len: 137 | paded.append(sent_list[:max_sent_len]) 138 | else: 139 | paded.append(sent_list) 140 | 141 | return paded 142 | 143 | def get_dataloader(code_vec, label_list,batch_size, max_sent_len): 144 | y_tensor = torch.cuda.FloatTensor([label for label in label_list]) 145 | code_vec_pad = pad_code(code_vec,max_sent_len) 146 | tensor_dataset = TensorDataset(torch.tensor(code_vec_pad), y_tensor) 147 | 148 | dl = DataLoader(tensor_dataset,shuffle=True,batch_size=batch_size,drop_last=True) 149 | 150 | return dl 151 | 152 | def get_x_vec(code_3d, word2vec): 153 | x_vec = [[[word2vec.wv.vocab[token].index if token in word2vec.wv.vocab else len(word2vec.wv.vocab) for token in text] 154 | for text in texts] for texts in code_3d] 155 | 156 | return x_vec -------------------------------------------------------------------------------- /script/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os, re 3 | import numpy as np 4 | 5 | from my_util import * 6 | 7 | data_root_dir = '../datasets/original/' 8 | save_dir = "../datasets/preprocessed_data/" 9 | 10 | char_to_remove = ['+','-','*','/','=','++','--','\\','','','|','&','!'] 11 | 12 | if not os.path.exists(save_dir): 13 | os.makedirs(save_dir) 14 | 15 | file_lvl_dir = data_root_dir+'File-level/' 16 | line_lvl_dir = data_root_dir+'Line-level/' 17 | 18 | 19 | def is_comment_line(code_line, comments_list): 20 | ''' 21 | input 22 | code_line (string): source code in a line 23 | comments_list (list): a list that contains every comments 24 | output 25 | boolean value 26 | ''' 27 | 28 | code_line = code_line.strip() 29 | 30 | if len(code_line) == 0: 31 | return False 32 | elif code_line.startswith('//'): 33 | return True 34 | elif code_line in comments_list: 35 | return True 36 | 37 | return False 38 | 39 | def is_empty_line(code_line): 40 | ''' 41 | input 42 | code_line (string) 43 | output 44 | boolean value 45 | ''' 46 | 47 | if len(code_line.strip()) == 0: 48 | return True 49 | 50 | return False 51 | 52 | def preprocess_code_line(code_line): 53 | ''' 54 | input 55 | code_line (string) 56 | ''' 57 | 58 | code_line = re.sub("\'\'", "\'", code_line) 59 | code_line = re.sub("\".*?\"", "", code_line) 60 | code_line = re.sub("\'.*?\'", "", code_line) 61 | code_line = re.sub('\b\d+\b','',code_line) 62 | code_line = re.sub("\\[.*?\\]", '', code_line) 63 | code_line = re.sub("[\\.|,|:|;|{|}|(|)]", ' ', code_line) 64 | 65 | for char in char_to_remove: 66 | code_line = code_line.replace(char,' ') 67 | 68 | code_line = code_line.strip() 69 | 70 | return code_line 71 | 72 | def create_code_df(code_str, filename): 73 | ''' 74 | input 75 | code_str (string): a source code 76 | filename (string): a file name of source code 77 | 78 | output 79 | code_df (DataFrame): a dataframe of source code that contains the following columns 80 | - code_line (str): source code in a line 81 | - line_number (str): line number of source code line 82 | - is_comment (bool): boolean which indicates if a line is comment 83 | - is_blank_line(bool): boolean which indicates if a line is blank 84 | ''' 85 | 86 | df = pd.DataFrame() 87 | 88 | code_lines = code_str.splitlines() 89 | 90 | preprocess_code_lines = [] 91 | is_comments = [] 92 | is_blank_line = [] 93 | 94 | 95 | 
comments = re.findall(r'(/\*[\s\S]*?\*/)',code_str,re.DOTALL) 96 | comments_str = '\n'.join(comments) 97 | comments_list = comments_str.split('\n') 98 | 99 | for l in code_lines: 100 | l = l.strip() 101 | is_comment = is_comment_line(l,comments_list) 102 | is_comments.append(is_comment) 103 | # preprocess code here then check empty line... 104 | 105 | if not is_comment: 106 | l = preprocess_code_line(l) 107 | 108 | is_blank_line.append(is_empty_line(l)) 109 | preprocess_code_lines.append(l) 110 | 111 | if 'test' in filename: 112 | is_test = True 113 | else: 114 | is_test = False 115 | 116 | df['filename'] = [filename]*len(code_lines) 117 | df['is_test_file'] = [is_test]*len(code_lines) 118 | df['code_line'] = preprocess_code_lines 119 | df['line_number'] = np.arange(1,len(code_lines)+1) 120 | df['is_comment'] = is_comments 121 | df['is_blank'] = is_blank_line 122 | 123 | return df 124 | 125 | def preprocess_data(proj_name): 126 | 127 | cur_all_rel = all_releases[proj_name] 128 | 129 | for rel in cur_all_rel: 130 | file_level_data = pd.read_csv(file_lvl_dir+rel+'_ground-truth-files_dataset.csv', encoding='latin') 131 | line_level_data = pd.read_csv(line_lvl_dir+rel+'_defective_lines_dataset.csv', encoding='latin') 132 | 133 | file_level_data = file_level_data.fillna('') 134 | 135 | buggy_files = list(line_level_data['File'].unique()) 136 | 137 | preprocessed_df_list = [] 138 | 139 | for idx, row in file_level_data.iterrows(): 140 | 141 | filename = row['File'] 142 | 143 | if '.java' not in filename: 144 | continue 145 | 146 | code = row['SRC'] 147 | label = row['Bug'] 148 | 149 | code_df = create_code_df(code, filename) 150 | code_df['file-label'] = [label]*len(code_df) 151 | code_df['line-label'] = [False]*len(code_df) 152 | 153 | if filename in buggy_files: 154 | buggy_lines = list(line_level_data[line_level_data['File']==filename]['Line_number']) 155 | code_df['line-label'] = code_df['line_number'].isin(buggy_lines) 156 | 157 | if len(code_df) > 0: 158 | preprocessed_df_list.append(code_df) 159 | 160 | all_df = pd.concat(preprocessed_df_list) 161 | all_df.to_csv(save_dir+rel+".csv",index=False) 162 | print('finish release {}'.format(rel)) 163 | 164 | for proj in list(all_releases.keys()): 165 | preprocess_data(proj) 166 | 167 | -------------------------------------------------------------------------------- /script/train_model.py: -------------------------------------------------------------------------------- 1 | import os, re, argparse 2 | 3 | import torch.optim as optim 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from gensim.models import Word2Vec 9 | 10 | from tqdm import tqdm 11 | 12 | from sklearn.utils import compute_class_weight 13 | 14 | from DeepLineDP_model import * 15 | from my_util import * 16 | 17 | torch.manual_seed(0) 18 | 19 | arg = argparse.ArgumentParser() 20 | 21 | arg.add_argument('-dataset',type=str, default='activemq', help='software project name (lowercase)') 22 | arg.add_argument('-batch_size', type=int, default=32) 23 | arg.add_argument('-num_epochs', type=int, default=10) 24 | arg.add_argument('-embed_dim', type=int, default=50, help='word embedding size') 25 | arg.add_argument('-word_gru_hidden_dim', type=int, default=64, help='word attention hidden size') 26 | arg.add_argument('-sent_gru_hidden_dim', type=int, default=64, help='sentence attention hidden size') 27 | arg.add_argument('-word_gru_num_layers', type=int, default=1, help='number of GRU layer at word level') 28 | arg.add_argument('-sent_gru_num_layers', type=int, default=1, 
help='number of GRU layer at sentence level') 29 | arg.add_argument('-dropout', type=float, default=0.2, help='dropout rate') 30 | arg.add_argument('-lr', type=float, default=0.001, help='learning rate') 31 | arg.add_argument('-exp_name',type=str,default='') 32 | 33 | args = arg.parse_args() 34 | 35 | # model setting 36 | batch_size = args.batch_size 37 | num_epochs = args.num_epochs 38 | max_grad_norm = 5 39 | embed_dim = args.embed_dim 40 | word_gru_hidden_dim = args.word_gru_hidden_dim 41 | sent_gru_hidden_dim = args.sent_gru_hidden_dim 42 | word_gru_num_layers = args.word_gru_num_layers 43 | sent_gru_num_layers = args.sent_gru_num_layers 44 | word_att_dim = 64 45 | sent_att_dim = 64 46 | use_layer_norm = True 47 | dropout = args.dropout 48 | lr = args.lr 49 | 50 | save_every_epochs = 1 51 | exp_name = args.exp_name 52 | 53 | max_train_LOC = 900 54 | 55 | prediction_dir = '../output/prediction/DeepLineDP/' 56 | save_model_dir = '../output/model/DeepLineDP/' 57 | 58 | file_lvl_gt = '../datasets/preprocessed_data/' 59 | 60 | weight_dict = {} 61 | 62 | def get_loss_weight(labels): 63 | ''' 64 | input 65 | labels: a PyTorch tensor that contains labels 66 | output 67 | weight_tensor: a PyTorch tensor that contains weight of defect/clean class 68 | ''' 69 | label_list = labels.cpu().numpy().squeeze().tolist() 70 | weight_list = [] 71 | 72 | for lab in label_list: 73 | if lab == 0: 74 | weight_list.append(weight_dict['clean']) 75 | else: 76 | weight_list.append(weight_dict['defect']) 77 | 78 | weight_tensor = torch.tensor(weight_list).reshape(-1,1).cuda() 79 | return weight_tensor 80 | 81 | def train_model(dataset_name): 82 | 83 | loss_dir = '../output/loss/DeepLineDP/' 84 | actual_save_model_dir = save_model_dir+dataset_name+'/' 85 | 86 | if not exp_name == '': 87 | actual_save_model_dir = actual_save_model_dir+exp_name+'/' 88 | loss_dir = loss_dir + exp_name 89 | 90 | if not os.path.exists(actual_save_model_dir): 91 | os.makedirs(actual_save_model_dir) 92 | 93 | if not os.path.exists(loss_dir): 94 | os.makedirs(loss_dir) 95 | 96 | train_rel = all_train_releases[dataset_name] 97 | valid_rel = all_eval_releases[dataset_name][0] 98 | 99 | train_df = get_df(train_rel) 100 | valid_df = get_df(valid_rel) 101 | 102 | train_code3d, train_label = get_code3d_and_label(train_df, True) 103 | valid_code3d, valid_label = get_code3d_and_label(valid_df, True) 104 | 105 | sample_weights = compute_class_weight(class_weight = 'balanced', classes = np.unique(train_label), y = train_label) 106 | 107 | weight_dict['defect'] = np.max(sample_weights) 108 | weight_dict['clean'] = np.min(sample_weights) 109 | 110 | w2v_dir = get_w2v_path() 111 | 112 | word2vec_file_dir = os.path.join(w2v_dir,dataset_name+'-'+str(embed_dim)+'dim.bin') 113 | 114 | word2vec = Word2Vec.load(word2vec_file_dir) 115 | print('load Word2Vec for',dataset_name,'finished') 116 | 117 | word2vec_weights = get_w2v_weight_for_deep_learning_models(word2vec, embed_dim) 118 | 119 | vocab_size = len(word2vec.wv.vocab) + 1 # for unknown tokens 120 | 121 | x_train_vec = get_x_vec(train_code3d, word2vec) 122 | x_valid_vec = get_x_vec(valid_code3d, word2vec) 123 | 124 | max_sent_len = min(max([len(sent) for sent in (x_train_vec)]), max_train_LOC) 125 | 126 | train_dl = get_dataloader(x_train_vec,train_label,batch_size,max_sent_len) 127 | 128 | valid_dl = get_dataloader(x_valid_vec, valid_label,batch_size,max_sent_len) 129 | 130 | model = HierarchicalAttentionNetwork( 131 | vocab_size=vocab_size, 132 | embed_dim=embed_dim, 133 | 
word_gru_hidden_dim=word_gru_hidden_dim, 134 | sent_gru_hidden_dim=sent_gru_hidden_dim, 135 | word_gru_num_layers=word_gru_num_layers, 136 | sent_gru_num_layers=sent_gru_num_layers, 137 | word_att_dim=word_att_dim, 138 | sent_att_dim=sent_att_dim, 139 | use_layer_norm=use_layer_norm, 140 | dropout=dropout) 141 | 142 | model = model.cuda() 143 | model.sent_attention.word_attention.freeze_embeddings(False) 144 | 145 | optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()), lr=lr) 146 | 147 | criterion = nn.BCELoss() 148 | 149 | checkpoint_files = os.listdir(actual_save_model_dir) 150 | 151 | if '.ipynb_checkpoints' in checkpoint_files: 152 | checkpoint_files.remove('.ipynb_checkpoints') 153 | 154 | total_checkpoints = len(checkpoint_files) 155 | 156 | # no model is trained 157 | if total_checkpoints == 0: 158 | model.sent_attention.word_attention.init_embeddings(word2vec_weights) 159 | current_checkpoint_num = 1 160 | 161 | train_loss_all_epochs = [] 162 | val_loss_all_epochs = [] 163 | 164 | else: 165 | checkpoint_nums = [int(re.findall('\d+',s)[0]) for s in checkpoint_files] 166 | current_checkpoint_num = max(checkpoint_nums) 167 | 168 | checkpoint = torch.load(actual_save_model_dir+'checkpoint_'+str(current_checkpoint_num)+'epochs.pth') 169 | model.load_state_dict(checkpoint['model_state_dict']) 170 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 171 | 172 | loss_df = pd.read_csv(loss_dir+dataset_name+'-loss_record.csv') 173 | train_loss_all_epochs = list(loss_df['train_loss']) 174 | val_loss_all_epochs = list(loss_df['valid_loss']) 175 | 176 | current_checkpoint_num = current_checkpoint_num+1 # go to next epoch 177 | print('continue training model from epoch',current_checkpoint_num) 178 | 179 | for epoch in tqdm(range(current_checkpoint_num,num_epochs+1)): 180 | train_losses = [] 181 | val_losses = [] 182 | 183 | model.train() 184 | 185 | for inputs, labels in train_dl: 186 | 187 | inputs_cuda, labels_cuda = inputs.cuda(), labels.cuda() 188 | output, _, __, ___ = model(inputs_cuda) 189 | 190 | weight_tensor = get_loss_weight(labels) 191 | 192 | criterion.weight = weight_tensor 193 | 194 | loss = criterion(output, labels_cuda.reshape(batch_size,1)) 195 | 196 | train_losses.append(loss.item()) 197 | 198 | torch.cuda.empty_cache() 199 | 200 | loss.backward() 201 | nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) 202 | 203 | optimizer.step() 204 | 205 | torch.cuda.empty_cache() 206 | 207 | train_loss_all_epochs.append(np.mean(train_losses)) 208 | 209 | with torch.no_grad(): 210 | 211 | criterion.weight = None 212 | model.eval() 213 | 214 | for inputs, labels in valid_dl: 215 | 216 | inputs, labels = inputs.cuda(), labels.cuda() 217 | output, _, __, ___ = model(inputs) 218 | 219 | val_loss = criterion(output, labels.reshape(batch_size,1)) 220 | 221 | val_losses.append(val_loss.item()) 222 | 223 | val_loss_all_epochs.append(np.mean(val_losses)) 224 | 225 | if epoch % save_every_epochs == 0: 226 | print(dataset_name,'- at epoch:',str(epoch)) 227 | 228 | if exp_name == '': 229 | torch.save({ 230 | 'epoch': epoch, 231 | 'model_state_dict': model.state_dict(), 232 | 'optimizer_state_dict': optimizer.state_dict() 233 | }, 234 | actual_save_model_dir+'checkpoint_'+str(epoch)+'epochs.pth') 235 | else: 236 | torch.save({ 237 | 'epoch': epoch, 238 | 'model_state_dict': model.state_dict(), 239 | 'optimizer_state_dict': optimizer.state_dict() 240 | }, 241 | actual_save_model_dir+'checkpoint_'+exp_name+'_'+str(epoch)+'epochs.pth') 242 | 243 | loss_df 
= pd.DataFrame() 244 | loss_df['epoch'] = np.arange(1,len(train_loss_all_epochs)+1) 245 | loss_df['train_loss'] = train_loss_all_epochs 246 | loss_df['valid_loss'] = val_loss_all_epochs 247 | 248 | loss_df.to_csv(loss_dir+dataset_name+'-loss_record.csv',index=False) 249 | 250 | dataset_name = args.dataset 251 | train_model(dataset_name) -------------------------------------------------------------------------------- /script/train_word2vec.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | from gensim.models import Word2Vec 4 | 5 | import more_itertools 6 | 7 | from DeepLineDP_model import * 8 | from my_util import * 9 | 10 | 11 | def train_word2vec_model(dataset_name, embedding_dim = 50): 12 | 13 | w2v_path = get_w2v_path() 14 | 15 | save_path = w2v_path+'/'+dataset_name+'-'+str(embedding_dim)+'dim.bin' 16 | 17 | if os.path.exists(save_path): 18 | print('word2vec model at {} is already exists'.format(save_path)) 19 | return 20 | 21 | if not os.path.exists(w2v_path): 22 | os.makedirs(w2v_path) 23 | 24 | train_rel = all_train_releases[dataset_name] 25 | 26 | train_df = get_df(train_rel) 27 | 28 | train_code_3d, _ = get_code3d_and_label(train_df, True) 29 | 30 | all_texts = list(more_itertools.collapse(train_code_3d[:],levels=1)) 31 | 32 | word2vec = Word2Vec(all_texts,size=embedding_dim, min_count=1,sorted_vocab=1) 33 | 34 | word2vec.save(save_path) 35 | print('save word2vec model at path {} done'.format(save_path)) 36 | 37 | 38 | p = sys.argv[1] 39 | 40 | train_word2vec_model(p,50) 41 | --------------------------------------------------------------------------------
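
Supplementary note (not a file from the original repository): the line-level evaluation in script/get_evaluation_result.R ranks the lines of each file that is predicted as defective by their summed token attention scores, then computes three metrics over that ranking: IFA (the rank of the first truly defective line), Recall@20%LOC (the proportion of defective lines found within the top 20% of a file's lines), and Effort@20%Recall (the proportion of lines inspected while cumulative recall is still at most 20%). The short Python sketch below only illustrates these definitions for a single file; all function and variable names are ours, the R script's rounding details are simplified, and the R script remains the reference implementation.

from typing import List, Tuple


def line_level_metrics(scores: List[float], is_defective: List[bool]) -> Tuple[int, float, float]:
    """Illustrative sketch: compute (IFA, Recall@20%LOC, Effort@20%Recall) for one file.

    scores       -- one risk score per line (e.g., summed token attention)
    is_defective -- line-level ground truth, aligned with scores
    """
    # Rank lines from highest to lowest score (rank 1 = most risky line).
    ranked = sorted(zip(scores, is_defective), key=lambda x: -x[0])
    labels = [label for _, label in ranked]

    n_lines = len(labels)
    total_defective = sum(labels)
    if total_defective == 0:
        raise ValueError("file has no defective lines")

    # IFA: 1-based rank of the first actually defective line in the ranking.
    ifa = next(rank for rank, label in enumerate(labels, start=1) if label)

    # Recall@20%LOC: fraction of defective lines found in the top 20% of lines.
    top_20_loc = labels[: round(0.2 * n_lines)]
    recall_at_20_loc = sum(top_20_loc) / total_defective

    # Effort@20%Recall: fraction of lines whose cumulative recall is still <= 20%,
    # i.e., the inspection effort spent before exceeding 20% recall.
    found = 0
    effort_lines = 0
    for label in labels:
        found += label
        if found / total_defective <= 0.2:
            effort_lines += 1
    effort_at_20_recall = effort_lines / n_lines

    return ifa, recall_at_20_loc, effort_at_20_recall


if __name__ == "__main__":
    # Toy example: 10 lines, the 3rd and 7th lines are defective.
    scores = [0.1, 0.9, 0.8, 0.2, 0.05, 0.3, 0.7, 0.15, 0.0, 0.4]
    truth = [False, False, True, False, False, False, True, False, False, False]
    print(line_level_metrics(scores, truth))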