├── parser
│   ├── my-languages.so
│   ├── __init__.py
│   ├── build.sh
│   ├── build.py
│   ├── utils.py
│   └── DFG.py
├── requirements.txt
├── keywords
│   ├── go.txt
│   ├── python.txt
│   ├── ruby.txt
│   ├── java.txt
│   ├── javascript.txt
│   ├── php.txt
│   └── c_sharp.txt
├── LICENSE
├── syntax_match.py
├── README.md
├── utils.py
├── calc_code_bleu.py
├── dataflow_match.py
├── inference.py
├── pre-training.py
├── generate_encodings.py
├── fine-tuning.py
├── weighted_ngram_match.py
└── bleu.py

/parser/my-languages.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vl2g/floco/HEAD/parser/my-languages.so

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm==4.64.1
2 | transformers==4.25.1
3 | evaluate==0.4.0
4 | tree_sitter
5 | easyocr
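6 | # Note: parser/build.py relies on tree_sitter's Language.build_library, which
7 | # was removed in py-tree-sitter 0.22; if parser/my-languages.so needs to be
8 | # rebuilt, pinning tree_sitter<0.22 is the likely fix.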
--------------------------------------------------------------------------------
/keywords/go.txt:
--------------------------------------------------------------------------------
1 | break
2 | case
3 | chan
4 | const
5 | continue
6 | default
7 | defer
8 | else
9 | fallthrough
10 | for
11 | func
12 | go
13 | goto
14 | if
15 | import
16 | interface
17 | map
18 | package
19 | range
20 | return
21 | select
22 | struct
23 | switch
24 | type
25 | var
--------------------------------------------------------------------------------
/keywords/python.txt:
--------------------------------------------------------------------------------
1 |
False 2 | None 3 | True 4 | and 5 | as 6 | assert 7 | break 8 | class 9 | continue 10 | def 11 | del 12 | elif 13 | else 14 | except 15 | finally 16 | for 17 | from 18 | global 19 | if 20 | import 21 | in 22 | is 23 | lambda 24 | nonlocal 25 | not 26 | or 27 | pass 28 | raise 29 | return 30 | try 31 | while 32 | with 33 | yield -------------------------------------------------------------------------------- /parser/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .utils import (remove_comments_and_docstrings, 5 | tree_to_token_index, 6 | index_to_code_token, 7 | tree_to_variable_index) 8 | from .DFG import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp -------------------------------------------------------------------------------- /keywords/ruby.txt: -------------------------------------------------------------------------------- 1 | BEGIN 2 | END 3 | alias 4 | and 5 | begin 6 | break 7 | case 8 | class 9 | def 10 | module 11 | next 12 | nil 13 | not 14 | or 15 | redo 16 | rescue 17 | retry 18 | return 19 | elsif 20 | end 21 | false 22 | ensure 23 | for 24 | if 25 | true 26 | undef 27 | unless 28 | do 29 | else 30 | super 31 | then 32 | until 33 | when 34 | while 35 | defined? 36 | self -------------------------------------------------------------------------------- /parser/build.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/tree-sitter/tree-sitter-go 2 | git clone https://github.com/tree-sitter/tree-sitter-javascript 3 | git clone https://github.com/tree-sitter/tree-sitter-python 4 | git clone https://github.com/tree-sitter/tree-sitter-ruby 5 | git clone https://github.com/tree-sitter/tree-sitter-php 6 | git clone https://github.com/tree-sitter/tree-sitter-java 7 | git clone https://github.com/tree-sitter/tree-sitter-c-sharp 8 | python build.py 9 | -------------------------------------------------------------------------------- /parser/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
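# build.sh (above) clones the seven tree-sitter grammar repositories into this
# directory and then runs this script, which compiles them into the single
# shared library my-languages.so used by syntax_match.py and dataflow_match.py.
# Run it from inside parser/ with `bash build.sh`; a C compiler is required.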
3 | 
4 | from tree_sitter import Language, Parser
5 | 
6 | Language.build_library(
7 |     # Store the compiled library in the current directory
8 |     'my-languages.so',
9 | 
10 |     # Include one or more languages
11 |     [
12 |         'tree-sitter-go',
13 |         'tree-sitter-javascript',
14 |         'tree-sitter-python',
15 |         'tree-sitter-php',
16 |         'tree-sitter-java',
17 |         'tree-sitter-ruby',
18 |         'tree-sitter-c-sharp',
19 |     ]
20 | )
21 | 
22 | 
--------------------------------------------------------------------------------
/keywords/java.txt:
--------------------------------------------------------------------------------
1 | abstract
2 | assert
3 | boolean
4 | break
5 | byte
6 | case
7 | catch
8 | char
9 | class
10 | const
11 | continue
12 | default
13 | do
14 | double
15 | else
16 | enum
17 | extends
18 | final
19 | finally
20 | float
21 | for
22 | goto
23 | if
24 | implements
25 | import
26 | instanceof
27 | int
28 | interface
29 | long
30 | native
31 | new
32 | package
33 | private
34 | protected
35 | public
36 | return
37 | short
38 | static
39 | strictfp
40 | super
41 | switch
42 | synchronized
43 | this
44 | throw
45 | throws
46 | transient
47 | try
48 | void
49 | volatile
50 | while
--------------------------------------------------------------------------------
/keywords/javascript.txt:
--------------------------------------------------------------------------------
1 | abstract
2 | arguments
3 | boolean
4 | break
5 | byte
6 | case
7 | catch
8 | char
9 | const
10 | continue
11 | debugger
12 | default
13 | delete
14 | do
15 | double
16 | else
17 | eval
18 | false
19 | final
20 | finally
21 | float
22 | for
23 | function
24 | goto
25 | if
26 | implements
27 | in
28 | instanceof
29 | int
30 | interface
31 | let
32 | long
33 | native
34 | new
35 | null
36 | package
37 | private
38 | protected
39 | public
40 | return
41 | short
42 | static
43 | switch
44 | synchronized
45 | this
46 | throw
47 | throws
48 | transient
49 | true
50 | try
51 | typeof
52 | var
53 | void
54 | volatile
55 | while
56 | with
57 | yield
--------------------------------------------------------------------------------
/keywords/php.txt:
--------------------------------------------------------------------------------
1 | __halt_compiler()
2 | abstract
3 | and
4 | array()
5 | as
6 | break
7 | callable
8 | case
9 | catch
10 | class
11 | clone
12 | const
13 | continue
14 | declare
15 | default
16 | die()
17 | do
18 | echo
19 | else
20 | elseif
21 | empty()
22 | enddeclare
23 | endfor
24 | endforeach
25 | endif
26 | endswitch
27 | endwhile
28 | eval()
29 | exit()
30 | extends
31 | final
32 | finally
33 | for
34 | foreach
35 | function
36 | global
37 | goto
38 | if
39 | implements
40 | include
41 | include_once
42 | instanceof
43 | insteadof
44 | interface
45 | isset()
46 | list()
47 | namespace
48 | new
49 | or
50 | print
51 | private
52 | protected
53 | public
54 | require
55 | require_once
56 | return
57 | static
58 | switch
59 | throw
60 | trait
61 | try
62 | unset()
63 | use
64 | var
65 | while
66 | xor
67 | yield
68 | __CLASS__
69 | __DIR__
70 | __FILE__
71 | __FUNCTION__
72 | __LINE__
73 | __METHOD__
74 | __NAMESPACE__
75 | __TRAIT__
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 vl2g
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without
restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /keywords/c_sharp.txt: -------------------------------------------------------------------------------- 1 | abstract 2 | as 3 | base 4 | bool 5 | break 6 | byte 7 | case 8 | catch 9 | char 10 | checked 11 | class 12 | const 13 | continue 14 | decimal 15 | default 16 | delegate 17 | do 18 | double 19 | else 20 | enum 21 | event 22 | explicit 23 | extern 24 | false 25 | finally 26 | fixed 27 | float 28 | for 29 | foreach 30 | goto 31 | if 32 | implicit 33 | in 34 | int 35 | interface 36 | internal 37 | is 38 | lock 39 | long 40 | namespace 41 | new 42 | null 43 | object 44 | operator 45 | out 46 | override 47 | params 48 | private 49 | protected 50 | public 51 | readonly 52 | ref 53 | return 54 | sbyte 55 | sealed 56 | short 57 | sizeof 58 | stackalloc 59 | static 60 | string 61 | struct 62 | switch 63 | this 64 | throw 65 | true 66 | try 67 | typeof 68 | uint 69 | ulong 70 | unchecked 71 | unsafe 72 | ushort 73 | using 74 | virtual 75 | void 76 | volatile 77 | while 78 | add 79 | alias 80 | ascending 81 | async 82 | await 83 | by 84 | descending 85 | dynamic 86 | equals 87 | from 88 | get 89 | global 90 | group 91 | into 92 | join 93 | let 94 | nameof 95 | notnull 96 | on 97 | orderby 98 | partial 99 | remove 100 | select 101 | set 102 | unmanaged 103 | value 104 | var 105 | when 106 | where 107 | yield -------------------------------------------------------------------------------- /syntax_match.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
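# Syntax match (one of the four CodeBLEU components): the candidate and each
# reference are parsed with tree-sitter, and the score is the fraction of
# reference AST sub-trees (compared via their S-expressions) that also appear
# somewhere in the candidate's AST.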
3 | 
4 | from pathlib import Path
5 | from tree_sitter import Language, Parser
6 | from parser import (
7 |     DFG_python,
8 |     DFG_java,
9 |     DFG_ruby,
10 |     DFG_go,
11 |     DFG_php,
12 |     DFG_javascript,
13 |     DFG_csharp,
14 |     remove_comments_and_docstrings,
15 |     tree_to_token_index,
16 |     index_to_code_token,
17 |     tree_to_variable_index
18 | )
19 | 
20 | dfg_function = {
21 |     'python': DFG_python,
22 |     'java': DFG_java,
23 |     'ruby': DFG_ruby,
24 |     'go': DFG_go,
25 |     'php': DFG_php,
26 |     'javascript': DFG_javascript,
27 |     'c_sharp': DFG_csharp,
28 | }
29 | 
30 | root_directory = Path(__file__).parent
31 | PARSER_LOCATION = root_directory.joinpath("parser/my-languages.so")
32 | 
33 | 
34 | def calc_syntax_match(references, candidate, lang):
35 |     return corpus_syntax_match([references], [candidate], lang)
36 | 
37 | 
38 | def corpus_syntax_match(references, candidates, lang):
39 |     LANGUAGE = Language(PARSER_LOCATION, lang)
40 |     parser = Parser()
41 |     parser.set_language(LANGUAGE)
42 |     match_count = 0
43 |     total_count = 0
44 | 
45 |     for i in range(len(candidates)):
46 |         references_sample = references[i]
47 |         candidate = candidates[i]
48 |         for reference in references_sample:
49 |             try:
50 |                 candidate = remove_comments_and_docstrings(candidate, lang)
51 |             except:
52 |                 pass
53 |             try:
54 |                 reference = remove_comments_and_docstrings(reference, lang)
55 |             except:
56 |                 pass
57 | 
58 |             candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
59 | 
60 |             reference_tree = parser.parse(bytes(reference, 'utf8')).root_node
61 | 
62 |             def get_all_sub_trees(root_node):
63 |                 node_stack = []
64 |                 sub_tree_sexp_list = []
65 |                 depth = 1
66 |                 node_stack.append([root_node, depth])
67 |                 while len(node_stack) != 0:
68 |                     cur_node, cur_depth = node_stack.pop()
69 |                     sub_tree_sexp_list.append([cur_node.sexp(), cur_depth])
70 |                     for child_node in cur_node.children:
71 |                         if len(child_node.children) != 0:
72 |                             depth = cur_depth + 1
73 |                             node_stack.append([child_node, depth])
74 |                 return sub_tree_sexp_list
75 | 
76 |             cand_sexps = [x[0] for x in get_all_sub_trees(candidate_tree)]
77 |             ref_sexps = get_all_sub_trees(reference_tree)
78 | 
79 |             for sub_tree, depth in ref_sexps:
80 |                 if sub_tree in cand_sexps:
81 |                     match_count += 1
82 |             total_count += len(ref_sexps)
83 | 
84 |     score = match_count / total_count
85 |     return score
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FloCo
2 | **Official implementation of the ICDAR 2023 paper "Towards Making Flowchart Images Machine Interpretable"**
3 | 
4 | [Paper](https://vl2g.github.io/projects/floco/docs/FLOCO-ICDAR2023.pdf) | [Project Page](https://vl2g.github.io/projects/floco/)
5 | ## Requirements
6 | * Use **python >= 3.10.8**.
Conda is recommended: [https://docs.anaconda.com/anaconda/install/linux/](https://docs.anaconda.com/anaconda/install/linux/)
7 | 
8 | * Use **pytorch 1.13.1 with CUDA 11.6**
9 | 
10 | * Install the remaining dependencies from `requirements.txt`
11 | 
12 | **To set up the environment**
13 | ```
14 | # create new env flow
15 | $ conda create -n flow python=3.10.8
16 | 
17 | # activate flow
18 | $ conda activate flow
19 | 
20 | # install pytorch, torchvision
21 | $ conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
22 | 
23 | # install other dependencies
24 | $ pip install -r requirements.txt
25 | ```
26 | 
27 | ## Training
28 | 
29 | ### Preparing the dataset
30 | - Download the dataset [here](https://drive.google.com/file/d/1wkt0Tf8YpFFEx18jeB3ofUWl1xP3qzAn/view?usp=sharing) and unzip it.
31 | - The dataset directory should have the following structure:
32 | ```
33 | [FloCo]
34 | ├── Train
35 | │   ├── codes
36 | │   ├── flowchart
37 | │   ├── png
38 | │   └── svg
39 | ├── Validation
40 | │   ├── codes
41 | │   ├── flowchart
42 | │   ├── png
43 | │   └── svg
44 | └── Test
45 |     ├── codes
46 |     ├── flowchart
47 |     ├── png
48 |     └── svg
49 | ```
50 | 
51 | ### Generating sequence encodings
52 | - Encode the flowchart images into sequential text encodings for the train, validation and test sets separately:
53 | ```
54 | # Set path to the folder containing the png flowchart images
55 | # Set path to text file to save the encodings
56 | $ python generate_encodings.py
57 | ```
58 | - The dataset directory should now look like:
59 | ```
60 | [FloCo]
61 | ├── Train
62 | │   ├── codes
63 | │   ├── flowchart
64 | │   ├── png
65 | │   ├── svg
66 | │   └── encodings.txt
67 | ├── Validation
68 | │   ├── codes
69 | │   ├── flowchart
70 | │   ├── png
71 | │   ├── svg
72 | │   └── encodings.txt
73 | └── Test
74 |     ├── codes
75 |     ├── flowchart
76 |     ├── png
77 |     ├── svg
78 |     └── encodings.txt
79 | ```
80 | 
81 | ### Pre-train the model
82 | ```
83 | # Set path to augmented python codes and train set flowchart encodings
84 | # Set path to save model checkpoints and train logs
85 | $ python pre-training.py
86 | ```
87 | ### Fine-tune the pre-trained model
88 | ```
89 | # Set path to training and validation flowchart encodings and python codes respectively
90 | # Set path to save model checkpoints and train logs
91 | $ python fine-tuning.py
92 | ```
93 | 
94 | ## Inference
95 | - Generate Python code for unseen flowchart images using the best checkpoint of the trained model:
96 | ```
97 | # Set path to flowchart encodings and python codes belonging to the test dataset
98 | # Define path to the best checkpoint saved above
99 | # Set path to save the generated codes
100 | $ python inference.py
101 | ```
102 | 
103 | ## Cite us
104 | - If you find this work useful for your research, please consider citing:
105 | ```
106 | @inproceedings{shukla2023floco,
107 |     author = "Shukla, Shreya and
108 |               Gatti, Prajwal and
109 |               Kumar, Yogesh and
110 |               Yadav, Vikash and
111 |               Mishra, Anand",
112 |     title = "Towards Making Flowchart Images Machine Interpretable",
113 |     booktitle = "ICDAR",
114 |     year = "2023",
115 | }
116 | ```
117 | 
118 | ## Acknowledgements
119 | This repo uses scripts from https://github.com/salesforce/CodeT5/tree/main/evaluator/CodeBLEU to compute BLEU and CodeBLEU scores.
120 | 
121 | Code provided by https://huggingface.co/Salesforce/codet5-small helped in implementing FloCo-T5.
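
## Encoding format (for reference)
`generate_encodings.py` stores each split's `encodings.txt` as a JSON dictionary mapping an image id to its flowchart encoding: `text,SHAPE` pairs joined by `[SEP]`, where `SHAPE` is one of `RECTANGLE`, `OVAL`, `PARALLELOGRAM` or `DIAMOND` (diamonds are numbered per flowchart). An illustrative entry, with a made-up id and text:
```
{"1024": "Start,OVAL [SEP] input n,PARALLELOGRAM [SEP] n % 2 == 0,DIAMOND1 [SEP] print even,RECTANGLE [SEP] End,OVAL"}
```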
122 | -------------------------------------------------------------------------------- /parser/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import re 5 | from io import StringIO 6 | import tokenize 7 | 8 | 9 | def remove_comments_and_docstrings(source, lang): 10 | if lang in ['python']: 11 | """ 12 | Returns 'source' minus comments and docstrings. 13 | """ 14 | io_obj = StringIO(source) 15 | out = "" 16 | prev_toktype = tokenize.INDENT 17 | last_lineno = -1 18 | last_col = 0 19 | for tok in tokenize.generate_tokens(io_obj.readline): 20 | token_type = tok[0] 21 | token_string = tok[1] 22 | start_line, start_col = tok[2] 23 | end_line, end_col = tok[3] 24 | ltext = tok[4] 25 | if start_line > last_lineno: 26 | last_col = 0 27 | if start_col > last_col: 28 | out += (" " * (start_col - last_col)) 29 | # Remove comments: 30 | if token_type == tokenize.COMMENT: 31 | pass 32 | # This series of conditionals removes docstrings: 33 | elif token_type == tokenize.STRING: 34 | if prev_toktype != tokenize.INDENT: 35 | # This is likely a docstring; double-check we're not inside an operator: 36 | if prev_toktype != tokenize.NEWLINE: 37 | if start_col > 0: 38 | out += token_string 39 | else: 40 | out += token_string 41 | prev_toktype = token_type 42 | last_col = end_col 43 | last_lineno = end_line 44 | temp = [] 45 | for x in out.split('\n'): 46 | if x.strip() != "": 47 | temp.append(x) 48 | return '\n'.join(temp) 49 | elif lang in ['ruby']: 50 | return source 51 | else: 52 | def replacer(match): 53 | s = match.group(0) 54 | if s.startswith('/'): 55 | return " " # note: a space and not an empty string 56 | else: 57 | return s 58 | 59 | pattern = re.compile( 60 | r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', 61 | re.DOTALL | re.MULTILINE 62 | ) 63 | temp = [] 64 | for x in re.sub(pattern, replacer, source).split('\n'): 65 | if x.strip() != "": 66 | temp.append(x) 67 | return '\n'.join(temp) 68 | 69 | 70 | def tree_to_token_index(root_node): 71 | if (len(root_node.children) == 0 or root_node.type == 'string') and root_node.type != 'comment': 72 | return [(root_node.start_point, root_node.end_point)] 73 | else: 74 | code_tokens = [] 75 | for child in root_node.children: 76 | code_tokens += tree_to_token_index(child) 77 | return code_tokens 78 | 79 | 80 | def tree_to_variable_index(root_node, index_to_code): 81 | if (len(root_node.children) == 0 or root_node.type == 'string') and root_node.type != 'comment': 82 | index = (root_node.start_point, root_node.end_point) 83 | _, code = index_to_code[index] 84 | if root_node.type != code: 85 | return [(root_node.start_point, root_node.end_point)] 86 | else: 87 | return [] 88 | else: 89 | code_tokens = [] 90 | for child in root_node.children: 91 | code_tokens += tree_to_variable_index(child, index_to_code) 92 | return code_tokens 93 | 94 | 95 | def index_to_code_token(index, code): 96 | start_point = index[0] 97 | end_point = index[1] 98 | if start_point[0] == end_point[0]: 99 | s = code[start_point[0]][start_point[1]:end_point[1]] 100 | else: 101 | s = "" 102 | s += code[start_point[0]][start_point[1]:] 103 | for i in range(start_point[0] + 1, end_point[0]): 104 | s += code[i] 105 | s += code[end_point[0]][:end_point[1]] 106 | return s 107 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # 
Natural Language Toolkit: Utility functions 2 | # 3 | # Copyright (C) 2001-2020 NLTK Project 4 | # Author: Steven Bird 5 | # URL: 6 | # For license information, see LICENSE.TXT 7 | 8 | from itertools import chain 9 | 10 | 11 | def pad_sequence( 12 | sequence, 13 | n, 14 | pad_left=False, 15 | pad_right=False, 16 | left_pad_symbol=None, 17 | right_pad_symbol=None, 18 | ): 19 | """ 20 | Returns a padded sequence of items before ngram extraction. 21 | >>> list(pad_sequence([1,2,3,4,5], 2, pad_left=True, pad_right=True, left_pad_symbol='', right_pad_symbol='')) 22 | ['', 1, 2, 3, 4, 5, ''] 23 | >>> list(pad_sequence([1,2,3,4,5], 2, pad_left=True, left_pad_symbol='')) 24 | ['', 1, 2, 3, 4, 5] 25 | >>> list(pad_sequence([1,2,3,4,5], 2, pad_right=True, right_pad_symbol='')) 26 | [1, 2, 3, 4, 5, ''] 27 | :param sequence: the source data to be padded 28 | :type sequence: sequence or iter 29 | :param n: the degree of the ngrams 30 | :type n: int 31 | :param pad_left: whether the ngrams should be left-padded 32 | :type pad_left: bool 33 | :param pad_right: whether the ngrams should be right-padded 34 | :type pad_right: bool 35 | :param left_pad_symbol: the symbol to use for left padding (default is None) 36 | :type left_pad_symbol: any 37 | :param right_pad_symbol: the symbol to use for right padding (default is None) 38 | :type right_pad_symbol: any 39 | :rtype: sequence or iter 40 | """ 41 | sequence = iter(sequence) 42 | if pad_left: 43 | sequence = chain((left_pad_symbol,) * (n - 1), sequence) 44 | if pad_right: 45 | sequence = chain(sequence, (right_pad_symbol,) * (n - 1)) 46 | return sequence 47 | 48 | 49 | # add a flag to pad the sequence so we get peripheral ngrams? 50 | 51 | 52 | def ngrams( 53 | sequence, 54 | n, 55 | pad_left=False, 56 | pad_right=False, 57 | left_pad_symbol=None, 58 | right_pad_symbol=None, 59 | ): 60 | """ 61 | Return the ngrams generated from a sequence of items, as an iterator. 62 | For example: 63 | >>> from nltk.util import ngrams 64 | >>> list(ngrams([1,2,3,4,5], 3)) 65 | [(1, 2, 3), (2, 3, 4), (3, 4, 5)] 66 | Wrap with list for a list version of this function. 
Set pad_left 67 | or pad_right to true in order to get additional ngrams: 68 | >>> list(ngrams([1,2,3,4,5], 2, pad_right=True)) 69 | [(1, 2), (2, 3), (3, 4), (4, 5), (5, None)] 70 | >>> list(ngrams([1,2,3,4,5], 2, pad_right=True, right_pad_symbol='')) 71 | [(1, 2), (2, 3), (3, 4), (4, 5), (5, '')] 72 | >>> list(ngrams([1,2,3,4,5], 2, pad_left=True, left_pad_symbol='')) 73 | [('', 1), (1, 2), (2, 3), (3, 4), (4, 5)] 74 | >>> list(ngrams([1,2,3,4,5], 2, pad_left=True, pad_right=True, left_pad_symbol='', right_pad_symbol='')) 75 | [('', 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, '')] 76 | :param sequence: the source data to be converted into ngrams 77 | :type sequence: sequence or iter 78 | :param n: the degree of the ngrams 79 | :type n: int 80 | :param pad_left: whether the ngrams should be left-padded 81 | :type pad_left: bool 82 | :param pad_right: whether the ngrams should be right-padded 83 | :type pad_right: bool 84 | :param left_pad_symbol: the symbol to use for left padding (default is None) 85 | :type left_pad_symbol: any 86 | :param right_pad_symbol: the symbol to use for right padding (default is None) 87 | :type right_pad_symbol: any 88 | :rtype: sequence or iter 89 | """ 90 | sequence = pad_sequence( 91 | sequence, n, pad_left, pad_right, left_pad_symbol, right_pad_symbol 92 | ) 93 | 94 | history = [] 95 | while n > 1: 96 | # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator 97 | try: 98 | next_item = next(sequence) 99 | except StopIteration: 100 | # no more data, terminate the generator 101 | return 102 | history.append(next_item) 103 | n -= 1 104 | for item in sequence: 105 | history.append(item) 106 | yield tuple(history) 107 | del history[0] 108 | -------------------------------------------------------------------------------- /calc_code_bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
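# Usage sketch for compute_codebleu (defined below); the strings here are
# illustrative:
#
#   from calc_code_bleu import compute_codebleu
#   hyps = ["def add(a, b):\n    return a + b"]       # one hypothesis per example
#   refs = [["def add(x, y):\n    return x + y"]]     # a list of references per example
#   score, parts = compute_codebleu(hyps, refs, "python")
#   # parts == (ngram_match, weighted_ngram_match, syntax_match, dataflow_match)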
3 | 
4 | # -*- coding:utf-8 -*-
5 | import json
6 | import argparse
7 | 
8 | import bleu as bleu
9 | import weighted_ngram_match as weighted_ngram_match
10 | import syntax_match as syntax_match
11 | import dataflow_match as dataflow_match
12 | 
13 | from pathlib import Path
14 | 
15 | root_directory = Path(__file__).parent
16 | 
17 | 
18 | def make_weights(reference_tokens, key_word_list):
19 |     return {token: 1 if token in key_word_list else 0.2 \
20 |             for token in reference_tokens}
21 | 
22 | 
23 | def compute_codebleu(hypothesis, references, lang, params='0.25,0.25,0.25,0.25'):
24 |     alpha, beta, gamma, theta = [float(x) for x in params.split(',')]
25 | 
26 |     # calculate ngram match (BLEU)
27 |     tokenized_hyps = [x.split() for x in hypothesis]
28 |     tokenized_refs = [[x.split() for x in reference] for reference in references]
29 | 
30 |     ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps)
31 | 
32 |     # calculate weighted ngram match
33 |     kw_file = root_directory.joinpath("keywords/{}.txt".format(lang))
34 |     keywords = [x.strip() for x in open(kw_file, 'r', encoding='utf-8').readlines()]
35 | 
36 |     tokenized_refs_with_weights = \
37 |         [
38 |             [
39 |                 [
40 |                     reference_tokens, make_weights(reference_tokens, keywords)
41 |                 ] for reference_tokens in reference
42 |             ] for reference in tokenized_refs
43 |         ]
44 | 
45 |     weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(tokenized_refs_with_weights, tokenized_hyps)
46 | 
47 |     # calculate syntax match
48 |     syntax_match_score = syntax_match.corpus_syntax_match(references, hypothesis, lang)
49 | 
50 |     # calculate dataflow match
51 |     dataflow_match_score = dataflow_match.corpus_dataflow_match(references, hypothesis, lang)
52 | 
53 |     code_bleu_score = alpha * ngram_match_score \
54 |                       + beta * weighted_ngram_match_score \
55 |                       + gamma * syntax_match_score \
56 |                       + theta * dataflow_match_score
57 | 
58 |     return code_bleu_score, (ngram_match_score, weighted_ngram_match_score, syntax_match_score, dataflow_match_score)
59 | 
60 | 
61 | def main():
62 |     parser = argparse.ArgumentParser()
63 |     parser.add_argument('--refs', type=str, nargs='+', required=True, help='reference files')
64 |     parser.add_argument('--json_refs', action='store_true', help='reference files are JSON files')
65 |     parser.add_argument('--hyp', type=str, required=True, help='hypothesis file')
66 |     parser.add_argument('--lang', type=str, required=True,
67 |                         choices=['java', 'javascript', 'c_sharp', 'php', 'go', 'python', 'ruby'],
68 |                         help='programming language')
69 |     parser.add_argument('--params', type=str, default='0.25,0.25,0.25,0.25',
70 |                         help='comma-separated weights alpha, beta, gamma and theta')
71 | 
72 |     args = parser.parse_args()
73 | 
74 |     # List(List(String))
75 |     # -> length of the outer List is number of references per translation
76 |     # -> length of the inner List is number of total examples
77 |     pre_references = [
78 |         [x.strip() for x in open(file, 'r', encoding='utf-8').readlines()]
79 |         for file in args.refs
80 |     ]
81 |     # List(String)
82 |     hypothesis = [x.strip() for x in open(args.hyp, 'r', encoding='utf-8').readlines()]
83 | 
84 |     for i in range(len(pre_references)):
85 |         assert len(hypothesis) == len(pre_references[i])
86 | 
87 |     references = []
88 |     for i in range(len(hypothesis)):
89 |         ref_for_instance = []
90 |         for j in range(len(pre_references)):
91 |             if args.json_refs:
92 |                 _ref = json.loads(pre_references[j][i])
93 |                 ref_for_instance.append(_ref['code'])
94 |             else:
95 |                 ref_for_instance.append(pre_references[j][i])
96 | 
references.append(ref_for_instance)
97 | 
98 |     assert len(references) == len(hypothesis)
99 | 
100 |     # references is List(List(String)) where the inner List is a
101 |     # list of reference translations for one example.
102 |     code_bleu_score, (ngram_match_score, weighted_ngram_match_score, syntax_match_score, dataflow_match_score) = \
103 |         compute_codebleu(hypothesis, references, args.lang, args.params)
104 |     print('ngram match: {0}, weighted ngram match: {1}, syntax_match: {2}, dataflow_match: {3}'.
105 |           format(ngram_match_score, weighted_ngram_match_score, syntax_match_score, dataflow_match_score))
106 |     print('CodeBLEU score: %.2f' % (code_bleu_score * 100.0))
107 | 
108 | 
109 | if __name__ == '__main__':
110 |     main()
111 | 
--------------------------------------------------------------------------------
/dataflow_match.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | from pathlib import Path
5 | from tree_sitter import Language, Parser
6 | from parser import (
7 |     DFG_python,
8 |     DFG_java,
9 |     DFG_ruby, DFG_go,
10 |     DFG_php,
11 |     DFG_javascript,
12 |     DFG_csharp,
13 |     remove_comments_and_docstrings,
14 |     tree_to_token_index,
15 |     index_to_code_token,
16 |     tree_to_variable_index
17 | )
18 | 
19 | dfg_function = {
20 |     'python': DFG_python,
21 |     'java': DFG_java,
22 |     'ruby': DFG_ruby,
23 |     'go': DFG_go,
24 |     'php': DFG_php,
25 |     'javascript': DFG_javascript,
26 |     'c_sharp': DFG_csharp,
27 | }
28 | 
29 | root_directory = Path(__file__).parent
30 | PARSER_LOCATION = root_directory.joinpath("parser/my-languages.so")
31 | 
32 | 
33 | def calc_dataflow_match(references, candidate, lang):
34 |     return corpus_dataflow_match([references], [candidate], lang)
35 | 
36 | 
37 | def corpus_dataflow_match(references, candidates, lang):
38 |     LANGUAGE = Language(PARSER_LOCATION, lang)
39 |     parser = Parser()
40 |     parser.set_language(LANGUAGE)
41 |     parser = [parser, dfg_function[lang]]
42 |     match_count = 0
43 |     total_count = 0
44 | 
45 |     for i in range(len(candidates)):
46 |         references_sample = references[i]
47 |         candidate = candidates[i]
48 |         for reference in references_sample:
49 |             try:
50 |                 candidate = remove_comments_and_docstrings(candidate, lang)
51 |             except:
52 |                 pass
53 |             try:
54 |                 reference = remove_comments_and_docstrings(reference, lang)
55 |             except:
56 |                 pass
57 | 
58 |             cand_dfg = get_data_flow(candidate, parser)
59 |             ref_dfg = get_data_flow(reference, parser)
60 | 
61 |             normalized_cand_dfg = normalize_dataflow(cand_dfg)
62 |             normalized_ref_dfg = normalize_dataflow(ref_dfg)
63 | 
64 |             if len(normalized_ref_dfg) > 0:
65 |                 total_count += len(normalized_ref_dfg)
66 |                 for dataflow in normalized_ref_dfg:
67 |                     if dataflow in normalized_cand_dfg:
68 |                         match_count += 1
69 |                         normalized_cand_dfg.remove(dataflow)
70 | 
71 |     score = match_count / total_count if total_count > 0 else 1.0
72 |     return score
73 | 
74 | 
75 | def get_data_flow(code, parser):
76 |     try:
77 |         tree = parser[0].parse(bytes(code, 'utf8'))
78 |         root_node = tree.root_node
79 |         tokens_index = tree_to_token_index(root_node)
80 |         code = code.split('\n')
81 |         code_tokens = [index_to_code_token(x, code) for x in tokens_index]
82 |         index_to_code = {}
83 |         for idx, (index, code) in enumerate(zip(tokens_index, code_tokens)):
84 |             index_to_code[index] = (idx, code)
85 |         try:
86 |             DFG, _ = parser[1](root_node, index_to_code, {})
87 |         except:
88 |             DFG = []
89 |         DFG = sorted(DFG, key=lambda x: x[1])
90 |         indices = set()
91 |         for d in DFG:
92 |             if len(d[-1]) != 0:
93 |                 indices.add(d[1])
94 |                 for x in d[-1]:
95 |                     indices.add(x)
96 |         new_DFG = []
97 |         for d in DFG:
98 |             if d[1] in indices:
99 |                 new_DFG.append(d)
100 |         codes = code_tokens
101 |         dfg = new_DFG
102 |     except:
103 |         codes = code.split()
104 |         dfg = []
105 |     # merge nodes
106 |     dic = {}
107 |     for d in dfg:
108 |         if d[1] not in dic:
109 |             dic[d[1]] = d
110 |         else:
111 |             dic[d[1]] = (d[0], d[1], d[2], list(set(dic[d[1]][3] + d[3])), list(set(dic[d[1]][4] + d[4])))
112 |     DFG = []
113 |     for d in dic:
114 |         DFG.append(dic[d])
115 |     dfg = DFG
116 |     return dfg
117 | 
118 | 
119 | def normalize_dataflow_item(dataflow_item):
120 |     var_name = dataflow_item[0]
121 |     var_pos = dataflow_item[1]
122 |     relationship = dataflow_item[2]
123 |     par_vars_name_list = dataflow_item[3]
124 |     par_vars_pos_list = dataflow_item[4]
125 | 
126 |     var_names = list(set(par_vars_name_list + [var_name]))
127 |     norm_names = {}
128 |     for i in range(len(var_names)):
129 |         norm_names[var_names[i]] = 'var_' + str(i)
130 | 
131 |     norm_var_name = norm_names[var_name]
132 |     relationship = dataflow_item[2]
133 |     norm_par_vars_name_list = [norm_names[x] for x in par_vars_name_list]
134 | 
135 |     return (norm_var_name, relationship, norm_par_vars_name_list)
136 | 
137 | 
138 | def normalize_dataflow(dataflow):
139 |     var_dict = {}
140 |     i = 0
141 |     normalized_dataflow = []
142 |     for item in dataflow:
143 |         var_name = item[0]
144 |         relationship = item[2]
145 |         par_vars_name_list = item[3]
146 |         for name in par_vars_name_list:
147 |             if name not in var_dict:
148 |                 var_dict[name] = 'var_' + str(i)
149 |                 i += 1
150 |         if var_name not in var_dict:
151 |             var_dict[var_name] = 'var_' + str(i)
152 |             i += 1
153 |         normalized_dataflow.append((var_dict[var_name], relationship, [var_dict[x] for x in par_vars_name_list]))
154 |     return normalized_dataflow
155 | 
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | from tqdm.auto import tqdm
2 | import json
3 | import torch
4 | from torch.utils.data import Dataset
5 | from evaluate import load
6 | from calc_code_bleu import compute_codebleu
7 | from transformers import RobertaTokenizer, T5ForConditionalGeneration
8 | import warnings
9 | warnings.filterwarnings('ignore')
10 | 
11 | 
12 | def data_visualisation(code_pth, encodings_pth):
13 |     with open(encodings_pth) as file:
14 |         data = file.read()
15 |     js = json.loads(data)
16 |     image_ids = list(js.keys())
17 |     encodings = list(js.values())
18 |     python_codes = []
19 |     for id in image_ids:
20 |         cdp = code_pth + str(id) + '.py'
21 |         lines = ''
22 |         with open(cdp, 'r') as file:
23 |             lines = file.read()
24 |         # drop a leading comment line, if present
25 |         if lines.startswith('#'):
26 |             if '\n' in lines:
27 |                 lines = lines.split('\n', 1)[1]
28 |             else:
29 |                 lines = ''
30 |         python_codes.append(lines)
31 |     return image_ids, encodings, python_codes
32 | 
33 | 
34 | def CodeT5_tokenize():
35 |     tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
36 |     return tokenizer
37 | 
38 | 
39 | class CustomDataset(Dataset):
40 |     def __init__(self, input_ids, attention_mask, output, imageids):
41 |         self.input_ids = input_ids
42 |         self.attention_mask = attention_mask
43 |         self.output = output
44 |         self.imageids = imageids
45 | 
46 |     def __len__(self):
47 |         return self.input_ids.shape[0]
48 | 
49 |     def __getitem__(self, idx):
50 |         return (self.imageids[idx], self.input_ids[idx], self.attention_mask[idx],
self.output[idx]) 51 | 52 | 53 | def data_loading(test_set, batch_size): 54 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False) 55 | return test_loader 56 | 57 | 58 | def CodeT5_model(): 59 | model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base') 60 | return model 61 | 62 | 63 | def writing_results(results_pth, test_loader, tokenizer, model, device): 64 | bleu_score=0.0 65 | codebleu_score=0.0 66 | exact_match=0.0 67 | for image_id, input_id, attention_mask, code in tqdm(test_loader): 68 | input_id = input_id.to(device) 69 | attention_mask = attention_mask.to(device) 70 | code = code.to(device) 71 | # Generating the code from the model 72 | outputs = model.generate(input_ids = input_id, attention_mask = attention_mask, return_dict_in_generate=True, output_scores=True, max_length=1024) 73 | out = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True) 74 | program = tokenizer.batch_decode(code, skip_special_tokens=True) 75 | encod = tokenizer.batch_decode(input_id, skip_special_tokens=True) 76 | exact_match_metric = load("exact_match") 77 | # Calculating the metrics for each generated code 78 | for i in range(len(program)): 79 | codebleu = compute_codebleu([out[i]], [[program[i]]], 'python')[0] 80 | bleu = compute_codebleu([out[i]], [[program[i]]], 'python')[1][0] 81 | EM = exact_match_metric.compute(predictions=[out[i]], references=[program[i]])['exact_match'] 82 | bleu_score+=bleu 83 | exact_match+=EM 84 | codebleu_score+=codebleu 85 | # Writing results to the file 86 | file=open(results_pth, 'a') 87 | for i in range(len(image_id)): 88 | file.write(str(image_id[i])+".png\n") 89 | file.write("Encoding from tokenizer : \n " + encod[i]) 90 | file.write("\n") 91 | file.write("Original Python Program : \n " + program[i]) 92 | file.write("\n") 93 | file.write("Output : \n " + out[i]) 94 | file.write("\n \n ") 95 | 96 | bleu = bleu_score/len(test_loader.dataset) 97 | EM = exact_match/len(test_loader.dataset) 98 | codebleu = codebleu_score/len(test_loader.dataset) 99 | 100 | print(bleu, codebleu, EM) 101 | 102 | 103 | def run(): 104 | # Path to the test codes 105 | test_code_pth = '' 106 | # Path to the test encodings 107 | test_encodings_pth = '' 108 | # Path to the trained model checkpoints 109 | trained_model_pth = '' 110 | # Path to file where the generated codes will be stored 111 | results_pth = '' 112 | # Batch size for the test data 113 | batch_size = 16 114 | 115 | # Loading test data 116 | test_image_ids, test_encodings, test_codes = data_visualisation(test_code_pth, test_encodings_pth) 117 | print(len(test_image_ids)) 118 | # Tokenize the test data with CodeT5 tokenizer 119 | tokenizer = CodeT5_tokenize() 120 | tokenizer.add_tokens(['[SEP]', 'PARALLELOGRAM', 'RECTANGLE', 'OVAL', 'DIAMOND'], special_tokens=True) 121 | test_input = tokenizer(test_encodings, padding='max_length', truncation=True, return_tensors='pt', max_length=512) 122 | with tokenizer.as_target_tokenizer(): 123 | test_labels = tokenizer(test_codes, padding='max_length', truncation=True, return_tensors='pt', max_length=512) 124 | 125 | # Create the test dataset and dataloader 126 | test_set = CustomDataset(test_input['input_ids'], test_input['attention_mask'], test_labels['input_ids'], test_image_ids) 127 | test_loader = data_loading(test_set, batch_size) 128 | 129 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 130 | # Load pre-trained CodeT5 model from HuggingFace 131 | model = CodeT5_model() 132 | 
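    # Match the embedding matrix to the tokenizer vocabulary, which now also
    # contains the added [SEP] and shape tokens, before the fine-tuned weights
    # are loaded below.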
model.resize_token_embeddings(len(tokenizer)) 133 | model = model.to(device) 134 | 135 | # Load the fine-tuned model for inference on test data 136 | model.load_state_dict(torch.load(trained_model_pth, map_location=torch.device('cuda:0'))) 137 | 138 | # Generate the results 139 | writing_results(results_pth, test_loader, tokenizer, model, device) 140 | 141 | 142 | run() 143 | -------------------------------------------------------------------------------- /pre-training.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm.auto import tqdm 3 | import random 4 | import json 5 | import torch 6 | from torch.utils.data import Dataset 7 | from transformers import get_polynomial_decay_schedule_with_warmup 8 | from transformers import RobertaTokenizer, T5ForConditionalGeneration 9 | import warnings 10 | warnings.filterwarnings('ignore') 11 | 12 | 13 | def data_loading(codes_pth, encodings_pth): 14 | with open(encodings_pth) as file: 15 | data = file.read() 16 | js = json.loads(data) 17 | encodings = list(js.values()) 18 | python_codes=[] 19 | for code in os.listdir(codes_pth): 20 | cdp = codes_pth+code 21 | lines='' 22 | file = open(cdp, 'r') 23 | lines = file.read() 24 | python_codes.append(lines) 25 | inputs = encodings 26 | inputs.extend(python_codes) 27 | random.shuffle(inputs) 28 | return inputs 29 | 30 | 31 | class CustomDataset(Dataset): 32 | def __init__(self, input_ids, attention_mask, labels): 33 | self.input_ids = input_ids 34 | self.attention_mask = attention_mask 35 | self.labels = labels 36 | def __getitem__(self, idx): 37 | return (self.input_ids[idx], self.attention_mask[idx], self.labels[idx]) 38 | def __len__(self): 39 | return self.input_ids.shape[0] 40 | 41 | 42 | def tokenize(inputs): 43 | tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base') 44 | tokenizer.add_tokens(['[SEP]', 'PARALLELOGRAM', 'RECTANGLE', 'OVAL', 'DIAMOND'], special_tokens=True) 45 | tokenized_inputs = tokenizer(inputs, return_tensors='pt', max_length=512, truncation=True, padding='max_length') 46 | return tokenizer, tokenized_inputs 47 | 48 | 49 | def masking(tokenized_inputs): 50 | # create a copy of input_ids tensor to be used as labels for MLM 51 | tokenized_inputs['labels'] = tokenized_inputs.input_ids.detach().clone() 52 | # create a random tensor of same shape as input_ids 53 | rand = torch.rand(tokenized_inputs.input_ids.shape) 54 | # mask 15% of the tokens in each sequence, while ignoring [CLS], [PAD] and [SEP] tokens 55 | mask_arr = (rand < 0.15) * (tokenized_inputs.input_ids != 0) * (tokenized_inputs.input_ids != 1) * (tokenized_inputs.input_ids != 2) * (tokenized_inputs.input_ids != 32100) 56 | 57 | # replace selected tokens with [MASK] token 58 | selection = [] 59 | for i in range(tokenized_inputs.input_ids.shape[0]): 60 | selection.append(torch.flatten(mask_arr[i].nonzero()).tolist()) 61 | 62 | for i in range(tokenized_inputs.input_ids.shape[0]): 63 | tokenized_inputs.input_ids[i, selection[i]] = 4 64 | 65 | return tokenized_inputs 66 | 67 | 68 | def MLM_model(): 69 | model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base') 70 | return model 71 | 72 | 73 | def train(epochs, train_loader, model, optim, scheduler, device, logs_pth, checkpoints_pth): 74 | best_loss=10000 75 | best_epoch=1 76 | for epoch in range(epochs): 77 | train_running_loss = 0.0 78 | model.train() 79 | for input_ids, attention_mask, labels in tqdm(train_loader): 80 | optim.zero_grad() 81 | 82 | input_ids = input_ids.to(device) 83 | 
attention_mask = attention_mask.to(device)
84 |             labels = labels.to(device)
85 | 
86 |             outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
87 |             # extract loss
88 |             loss = outputs.loss
89 | 
90 |             train_running_loss += loss.item()*input_ids.size(0)
91 |             # compute gradients for every parameter that needs a grad update
92 |             loss.backward()
93 |             # update parameters
94 |             optim.step()
95 |             scheduler.step()
96 | 
97 |         train_epoch_loss = train_running_loss/len(train_loader.dataset)
98 |         file = open(logs_pth, 'a')
99 |         file.write("Train Loss after " + str(epoch+1) + " epoch is : " + str(train_epoch_loss) + "\n")
100 | 
101 |         if train_epoch_loss < best_loss:
102 |             best_epoch = epoch+1
103 |             best_loss = train_epoch_loss
104 |             torch.save(model.state_dict(), os.path.join(checkpoints_pth, '{}.pth'.format(epoch+1)))
105 | 
106 |     print("Best epoch is ", best_epoch, " with loss ", best_loss)
107 | 
108 | 
109 | def run():
110 |     # path to augmented codes
111 |     train_codes_pth = ''
112 |     # path to encodings of FloCo train set
113 |     train_encodings_pth = ''
114 |     # path to a text file to save the logs
115 |     logs_pth = ''
116 |     # path to save checkpoints
117 |     checkpoints_pth = ''
118 | 
119 |     # load and shuffle the train data
120 |     train_inputs = data_loading(train_codes_pth, train_encodings_pth)
121 |     print("Pre-training with ", len(train_inputs), " samples")
122 | 
123 |     # tokenize the train data with CodeT5 tokenizer
124 |     tokenizer, tokenized_train_inputs = tokenize(train_inputs)
125 | 
126 |     # mask tokens randomly with a probability of 15%
127 |     masked_train_inputs = masking(tokenized_train_inputs)
128 | 
129 |     train_set = CustomDataset(masked_train_inputs['input_ids'], masked_train_inputs['attention_mask'], masked_train_inputs['labels'])
130 | 
131 |     train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True)
132 | 
133 |     # Load pre-trained CodeT5 model from HuggingFace
134 |     model = MLM_model()
135 |     model.resize_token_embeddings(len(tokenizer))
136 |     device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
137 |     model = model.to(device)
138 | 
139 |     # Define hyperparameters
140 |     lr = 0.00001
141 |     epochs = 100
142 |     num_batches = len(train_loader)
143 |     num_warmup_steps = 1100
144 |     num_training_steps = epochs*num_batches
145 |     optim = torch.optim.Adam(model.parameters(), lr=lr)
146 |     scheduler = get_polynomial_decay_schedule_with_warmup(optim, num_warmup_steps, num_training_steps, power=2)
147 | 
148 |     # Train the model and save the checkpoints and logs
149 |     train(epochs, train_loader, model, optim, scheduler, device, logs_pth, checkpoints_pth)
150 | 
151 | 
152 | run()
153 | 
--------------------------------------------------------------------------------
/generate_encodings.py:
--------------------------------------------------------------------------------
1 | import easyocr
2 | import cv2
3 | import os
4 | import json
5 | 
6 | 
7 | def getTextCoordinates(image, reader):
8 |     results = reader.readtext(image, paragraph = True)
9 |     l = []
10 | 
11 |     for (bbox, text) in results:
12 |         print(text, bbox)
13 |         l.append((text, bbox))
14 |     return l
15 | 
16 | 
17 | def annotate_text(pathsrc):
18 |     dictionary = {}
19 |     reader = easyocr.Reader(['en'])
20 | 
21 |     for filename in os.listdir(pathsrc):
22 |         path = os.path.join(pathsrc, filename)
23 |         image = cv2.imread(path)
24 |         textlist = getTextCoordinates(image, reader)
25 |         dictionary[filename] = textlist
26 | 
27 |     return dictionary
28 | 
29 | 
30 | def getShapeCoordinates(img):
31 |     gray = cv2.cvtColor(img,
cv2.COLOR_BGR2GRAY) 32 | _, threshold = cv2.threshold(gray, 127, 255, cv2.THRESH_OTSU) 33 | contours, _ = cv2.findContours(threshold, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 34 | shapes = {'Rectangle': [], 'Oval': [], 'Diamond': [], 'Parallelogram': []} 35 | i = 0 36 | for contour in contours: 37 | if i == 0: 38 | i = 1 39 | continue 40 | 41 | approx = cv2.approxPolyDP(contour, 0.01 * cv2.arcLength(contour, True), True) 42 | area = cv2.contourArea(contour) 43 | 44 | if area > 1000: 45 | 46 | if len(approx) == 4: 47 | x, y = [], [] # coordinates of blocks 48 | 49 | for l in approx: 50 | pair = l[0] 51 | x.append(pair[0]) 52 | y.append(pair[1]) 53 | 54 | if (abs(x[1]-x[0])<6 and abs(x[3]-x[2])<6 and abs(y[3]-y[0])<6 and abs(y[2]-y[1])<6) or (abs(x[3]-x[0])<6 and abs(x[1]-x[2])<6 and abs(y[1]-y[0])<6 and abs(y[2]-y[3])<6): 55 | shapes['Rectangle'].append(approx) 56 | 57 | elif (abs(x[1]-x[3])>6 and abs(x[2]-x[1])<50) or (abs(x[2]-x[0])>6 and abs(x[0]-x[1])<50): 58 | shapes['Parallelogram'].append(approx) 59 | 60 | elif (abs(x[1]-x[3])<6 and abs(x[2]-x[1])>50) or (abs(x[2]-x[0])<6 and abs(x[0]-x[1])>50): 61 | shapes['Diamond'].append(approx) 62 | 63 | elif len(approx) >=8 and len(approx) <= 13: 64 | shapes['Oval'].append(approx) 65 | 66 | return shapes 67 | 68 | 69 | def annotate_shapes(pathsrc): 70 | dictionary = {} 71 | 72 | for filename in os.listdir(pathsrc): 73 | path = os.path.join(pathsrc, filename) 74 | image = cv2.imread(path) 75 | shapes = getShapeCoordinates(image) 76 | dictionary[filename] = shapes 77 | 78 | return dictionary 79 | 80 | 81 | def find_min_max(coordinate_list): 82 | min_x, min_y, max_x, max_y = coordinate_list[0][0], coordinate_list[0][1], coordinate_list[0][0], coordinate_list[0][1] 83 | for x, y in coordinate_list: 84 | if min_x > x: 85 | min_x = x 86 | if min_y > y: 87 | min_y = y 88 | if max_x < x: 89 | max_x = x 90 | if max_y < y: 91 | max_y = y 92 | return (min_x, min_y, max_x, max_y) 93 | 94 | def is_within(text_coordinates, min_x, min_y, max_x, max_y): 95 | tl, tr, bl, br = text_coordinates[0], text_coordinates[1], text_coordinates[2], text_coordinates[3] 96 | flag = False 97 | if (min_x < tl[0] < max_x) and (min_x < tr[0] < max_x) and (min_x < bl[0] < max_x) and (min_x < br[0] < max_x) and (min_y < tl[1] < max_y) and (min_y < tr[1] < max_y) and (min_y < bl[1] < max_y) and (min_y < br[1] < max_y): 98 | flag = True 99 | return flag 100 | 101 | def narray_to_list(narray): 102 | l = [] 103 | for points in narray: 104 | l.append([points[0][0], points[0][1]]) 105 | return l 106 | 107 | def associate_shape_each(text, text_coordinates, shape_coordinates): 108 | SHAPE = None 109 | SHAPELIST = ['Rectangle', 'Diamond', 'Parallelogram', 'Oval'] 110 | dcount = 0 111 | 112 | for shape in SHAPELIST: 113 | for narray in shape_coordinates[shape]: 114 | l = narray_to_list(narray) 115 | min_x, min_y, max_x, max_y = find_min_max(l) 116 | flag = is_within(text_coordinates, min_x, min_y, max_x, max_y) 117 | if flag: 118 | SHAPE = shape.upper() 119 | 120 | return SHAPE 121 | 122 | def get_diamond_coordinates(text_coordinates, shape_coordinates): 123 | 124 | for narray in shape_coordinates['Diamond']: 125 | l = narray_to_list(narray) 126 | min_x, min_y, max_x, max_y = find_min_max(l) 127 | flag = is_within(text_coordinates, min_x, min_y, max_x, max_y) 128 | if flag: 129 | return l 130 | 131 | def find_centroid(l): 132 | sum_x, sum_y = 0, 0 133 | for x, y in l: 134 | sum_x += x 135 | sum_y += y 136 | 137 | centroid = (sum_x//4, sum_y//4) 138 | return centroid 139 | 140 | def 
find_distance(p1, p2): 141 | dis = (p1[0] - p2[0])**2 + (p1[1] - p2[1])**2 142 | return dis 143 | 144 | def associate_nearest_diamond(text_coord, diamond_coordinates): 145 | text_centroid = find_centroid(text_coord) 146 | min_dis = float('inf') 147 | nearest = None 148 | for diamond in diamond_coordinates: 149 | diamond_centroid = find_centroid(diamond_coordinates[diamond]) 150 | dis = find_distance(text_centroid, diamond_centroid) 151 | if dis < min_dis: 152 | min_dis = dis 153 | nearest = diamond 154 | return nearest 155 | 156 | def associate_shape(name, text_dict, shape_dict): 157 | shape_coordinates = shape_dict[name] 158 | dcount = 0 159 | diamond_coordinates = {} 160 | text_shape_coord_list = [] 161 | text_shape_list = [] 162 | encoding = '' 163 | for text, text_coordinates in text_dict[name]: 164 | SHAPE = associate_shape_each(text, text_coordinates, shape_coordinates) 165 | if SHAPE == 'DIAMOND': 166 | dcount += 1 167 | SHAPE += str(dcount) 168 | diamond_coordinates[SHAPE] = get_diamond_coordinates(text_coordinates, shape_coordinates) 169 | text_shape_coord_list.append((text, SHAPE, text_coordinates)) 170 | 171 | 172 | for text, shape, text_coord in text_shape_coord_list: 173 | if shape == None: 174 | SHAPE = associate_nearest_diamond(text_coord, diamond_coordinates) 175 | text_shape_list.append((text, SHAPE)) 176 | if SHAPE == None: 177 | encoding += '{'+text+',None},' 178 | else: 179 | encoding += '{'+text+','+SHAPE+'},' 180 | else: 181 | text_shape_list.append((text, shape)) 182 | encoding += '{'+text+','+shape+'},' 183 | #print(text_shape_list) 184 | encoding = encoding[:-1] 185 | return (text_shape_list, encoding) 186 | 187 | 188 | def annotate_encodings(pathsrc, text_dict, shape_dict, encodings_pth): 189 | dictionary_tuple = {} 190 | dictionary_string = {} 191 | dictionary_modified_string = {} 192 | 193 | for filename in os.listdir(pathsrc): 194 | encoding_tuple, encoding_string = associate_shape(filename, text_dict, shape_dict) 195 | dictionary_tuple[filename[:-4]] = encoding_tuple 196 | dictionary_string[filename[:-4]] = encoding_string 197 | dictionary_modified_string[filename[:-4]] = encoding_string[1:-1].replace('},{', ' [SEP] ') 198 | 199 | with open(encodings_pth, 'w') as convert_file: 200 | convert_file.write(json.dumps(dictionary_modified_string)) 201 | 202 | 203 | def get_encodings(): 204 | # Set path to png flowchart images 205 | pngpath = "" 206 | # Set path to a file to save encodings 207 | encodings_pth = "" 208 | 209 | # Get text inside flowchart blocks and on arrowheads 210 | # along with their coordinates with respect to the flowchart image 211 | # using easyocr 212 | text_dict = annotate_text(pngpath) 213 | 214 | # Get shape coordinates of flowchart blocks and categorize them 215 | # using contour detection into Rectangle, Diamond, Parallelogram and Oval 216 | shape_dict = annotate_shapes(pngpath) 217 | 218 | annotate_encodings(pngpath, text_dict, shape_dict, encodings_pth) 219 | 220 | 221 | get_encodings() -------------------------------------------------------------------------------- /fine-tuning.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm.auto import tqdm 3 | import json 4 | import torch 5 | from torch.utils.data import Dataset 6 | from evaluate import load 7 | from calc_code_bleu import compute_codebleu 8 | from transformers import get_polynomial_decay_schedule_with_warmup 9 | from transformers import RobertaTokenizer, T5ForConditionalGeneration 10 | import warnings 11 | 
warnings.filterwarnings('ignore') 12 | 13 | 14 | def data_visualisation(code_pth, encodings_pth): 15 | with open(encodings_pth) as file: 16 | data = file.read() 17 | js = json.loads(data) 18 | image_ids = list(js.keys()) 19 | encodings = list(js.values()) 20 | python_codes=[] 21 | for id in image_ids: 22 | cdp = code_pth+str(id)+'.py' 23 | lines='' 24 | file = open(cdp, 'r') 25 | if file.read()[0]=='#': 26 | file = open(cdp, 'r') 27 | next(file) 28 | lines = file.read() 29 | else: 30 | file = open(cdp, 'r') 31 | lines = file.read() 32 | python_codes.append(lines) 33 | return image_ids, encodings, python_codes 34 | 35 | 36 | def CodeT5_tokenize(): 37 | tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base') 38 | return tokenizer 39 | 40 | 41 | class CustomDataset(Dataset): 42 | def __init__(self, input_ids, attention_mask, output, lbl_input_ids, imageids): 43 | self.input_ids = input_ids 44 | self.attention_mask = attention_mask 45 | self.output = output 46 | self.lbl_input_ids = lbl_input_ids 47 | self.imageids = imageids 48 | 49 | def __len__(self): 50 | return self.input_ids.shape[0] 51 | 52 | def __getitem__(self, idx): 53 | return (self.imageids[idx], self.input_ids[idx], self.attention_mask[idx], self.lbl_input_ids[idx], self.output[idx]) 54 | 55 | 56 | def data_loading(train_set, val_set, batch_size): 57 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True) 58 | val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True) 59 | return train_loader, val_loader 60 | 61 | 62 | def CodeT5_model(): 63 | model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base') 64 | return model 65 | 66 | 67 | def train(train_loader, val_loader, tokenizer, num_epochs, model, device, optimizer, scheduler, logs_pth, checkpoints_pth): 68 | best_codebleu=0.0 69 | best_epoch=1 70 | exact_match_metric = load("exact_match") 71 | for epoch in range(num_epochs): 72 | train_running_loss = 0.0 73 | model.train() 74 | print("Epoch: ", epoch) 75 | bleu_score=0.0 76 | codebleu_score=0.0 77 | exact_match=0.0 78 | for image_id, input_id, attention_mask, label, code in tqdm(train_loader): 79 | optimizer.zero_grad() 80 | input_id = input_id.to(device) 81 | attention_mask = attention_mask.to(device) 82 | label = label.to(device) 83 | # forward pass to get outputs 84 | outputs = model(input_ids = input_id, attention_mask = attention_mask, labels = label) 85 | # calculate and backpropagate loss to update parameters 86 | loss = outputs.loss 87 | train_running_loss+=loss.item()*input_id.size(0) 88 | loss.backward() 89 | optimizer.step() 90 | scheduler.step() 91 | 92 | # Change and print the learning rate after every epoch 93 | for param_group in optimizer.param_groups: 94 | lr = param_group['lr'] 95 | 96 | # calculate average loss for the epoch 97 | train_epoch_loss = train_running_loss/len(train_loader.dataset) 98 | file=open(logs_pth, 'a') 99 | file.write("Learning Rate: " + str(lr) + "\n") 100 | file.write("Train Loss after " + str(epoch+1) + " epoch is : " + str(train_epoch_loss)+ "\n") 101 | 102 | bleu_score=0.0 103 | codebleu_score=0.0 104 | exact_match=0.0 105 | model.eval() 106 | with torch.no_grad(): 107 | for image_id, input_id, attention_mask, label, code in tqdm(val_loader): 108 | input_id = input_id.to(device) 109 | attention_mask = attention_mask.to(device) 110 | code = code.to(device) 111 | # Generating codes from the model 112 | outputs = model.generate(input_ids = input_id, attention_mask = attention_mask, 
return_dict_in_generate=True, output_scores=True, max_length=1024) 113 | # Decode the output to get the predicted code 114 | out = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True) 115 | # Decode the label to get the actual code 116 | program = tokenizer.batch_decode(code, skip_special_tokens=True) 117 | 118 | # calculate codebleu, bleu and exact match scores for the predicted codes with respect to the actual codes 119 | for i in range(len(program)): 120 | codebleu = compute_codebleu([out[i]], [[program[i]]], 'python')[0] 121 | bleu = compute_codebleu([out[i]], [[program[i]]], 'python')[1][0] 122 | EM = exact_match_metric.compute(predictions=[out[i]], references=[program[i]])['exact_match'] 123 | bleu_score+=bleu 124 | exact_match+=EM 125 | codebleu_score+=codebleu 126 | # calculate the average loss, bleu, exact match and codebleu scores for the validation set 127 | val_epoch_bleu = bleu_score/len(val_loader.dataset) 128 | val_epoch_EM = exact_match/len(val_loader.dataset) 129 | val_epoch_codebleu = codebleu_score/len(val_loader.dataset) 130 | 131 | file.write("Validation Bleu score after " + str(epoch+1) + " epoch is : " + str(val_epoch_bleu)+ "\n") 132 | file.write("Validation Exact Match score after " + str(epoch+1) + " epoch is : " + str(val_epoch_EM)+ "\n") 133 | file.write("Validation CodeBleu score after " + str(epoch+1) + " epoch is : " + str(val_epoch_codebleu)+ "\n\n") 134 | # Save the model with the best validation codebleu score 135 | if val_epoch_codebleu > best_codebleu: 136 | best_epoch = epoch+1 137 | best_codebleu = val_epoch_codebleu 138 | torch.save(model.state_dict(), os.path.join(checkpoints_pth, '{}.pth'.format(epoch+1))) 139 | 140 | file=open(logs_pth, 'a') 141 | file.write("\n Best validation loss for " + str(best_epoch) + " epoch is : " + str(best_codebleu)+ "\n") 142 | 143 | 144 | def run(): 145 | # Path to the train codes 146 | train_code_pth = '' 147 | # Path to the train encodings 148 | train_encodings_pth = '' 149 | # Path to the validation codes 150 | val_code_pth = '' 151 | # Path to the validation encodings 152 | val_encodings_pth = '' 153 | # path to a text file to save the logs 154 | logs_pth = '' 155 | # path to save checkpoints 156 | checkpoints_pth = '' 157 | # batch size for fine-tuning 158 | batch_size = 16 159 | 160 | # Load the train and validation data from respective folders 161 | train_image_ids, train_encodings, train_codes = data_visualisation(train_code_pth, train_encodings_pth) 162 | val_image_ids, val_encodings, val_codes = data_visualisation(val_code_pth, val_encodings_pth) 163 | 164 | # tokenize the train and validation data with CodeT5 tokenizer 165 | tokenizer = CodeT5_tokenize() 166 | tokenizer.add_tokens(['[SEP]', 'PARALLELOGRAM', 'RECTANGLE', 'OVAL', 'DIAMOND'], special_tokens=True) 167 | train_input = tokenizer(train_encodings, padding='max_length', truncation=True, return_tensors='pt', max_length=512) 168 | val_input = tokenizer(val_encodings, padding='max_length', truncation=True, return_tensors='pt', max_length=512) 169 | with tokenizer.as_target_tokenizer(): 170 | train_labels = tokenizer(train_codes, padding='max_length', truncation=True, return_tensors='pt', max_length=512) 171 | val_labels = tokenizer(val_codes, padding='max_length', truncation=True, return_tensors='pt', max_length=512) 172 | 173 | # Set the labels to -100 for the padding tokens 174 | train_lbl_input_ids = torch.clone(train_labels['input_ids']) 175 | val_lbl_input_ids = torch.clone(val_labels['input_ids']) 176 | for mask in 
range(train_labels['attention_mask'].shape[0]): 177 | indices = (train_labels['attention_mask'][mask] == 0).nonzero(as_tuple=True)[0] 178 | train_lbl_input_ids[mask][indices] = -100 179 | 180 | for mask in range(val_labels['attention_mask'].shape[0]): 181 | indices = (val_labels['attention_mask'][mask] == 0).nonzero(as_tuple=True)[0] 182 | val_lbl_input_ids[mask][indices] = -100 183 | 184 | # Create the train and validation dataset and dataloaders 185 | train_set = CustomDataset(train_input['input_ids'], train_input['attention_mask'], train_labels['input_ids'], train_lbl_input_ids, train_image_ids) 186 | val_set = CustomDataset(val_input['input_ids'], val_input['attention_mask'], val_labels['input_ids'], val_lbl_input_ids, val_image_ids) 187 | train_loader, val_loader = data_loading(train_set, val_set, batch_size) 188 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 189 | 190 | # Load pre-trained CodeT5 model from HuggingFace 191 | model = CodeT5_model() 192 | model.resize_token_embeddings(len(tokenizer)) 193 | model = model.to(device) 194 | 195 | # Define hyperparameters 196 | num_epochs = 100 197 | lr = 0.00001 198 | num_batches = len(train_loader) 199 | num_training_steps = num_epochs*num_batches 200 | num_warmup_steps = 2450 201 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 202 | scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, power=2) 203 | # Train the model and save the checkpoints and logs 204 | train(train_loader, val_loader, tokenizer, num_epochs, model, device, optimizer, scheduler, logs_pth, checkpoints_pth) 205 | 206 | 207 | run() 208 | -------------------------------------------------------------------------------- /weighted_ngram_match.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT license. 4 | 5 | # Natural Language Toolkit: BLEU Score 6 | # 7 | # Copyright (C) 2001-2020 NLTK Project 8 | # Authors: Chin Yee Lee, Hengfeng Li, Ruxin Hou, Calvin Tanujaya Lim 9 | # Contributors: Björn Mattsson, Dmitrijs Milajevs, Liling Tan 10 | # URL: 11 | # For license information, see LICENSE.TXT 12 | 13 | """BLEU score implementation.""" 14 | 15 | import math 16 | import sys 17 | from fractions import Fraction 18 | import warnings 19 | from collections import Counter 20 | 21 | from utils import ngrams 22 | 23 | 24 | def sentence_bleu( 25 | references, 26 | hypothesis, 27 | weights=(0.25, 0.25, 0.25, 0.25), 28 | smoothing_function=None, 29 | auto_reweigh=False, 30 | ): 31 | """ 32 | Calculate BLEU score (Bilingual Evaluation Understudy) from 33 | Papineni, Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu. 2002. 34 | "BLEU: a method for automatic evaluation of machine translation." 35 | In Proceedings of ACL. http://www.aclweb.org/anthology/P02-1040.pdf 36 | >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 37 | ... 'ensures', 'that', 'the', 'military', 'always', 38 | ... 'obeys', 'the', 'commands', 'of', 'the', 'party'] 39 | >>> hypothesis2 = ['It', 'is', 'to', 'insure', 'the', 'troops', 40 | ... 'forever', 'hearing', 'the', 'activity', 'guidebook', 41 | ... 'that', 'party', 'direct'] 42 | >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 43 | ... 'ensures', 'that', 'the', 'military', 'will', 'forever', 44 | ... 
'heed', 'Party', 'commands'] 45 | >>> reference2 = ['It', 'is', 'the', 'guiding', 'principle', 'which', 46 | ... 'guarantees', 'the', 'military', 'forces', 'always', 47 | ... 'being', 'under', 'the', 'command', 'of', 'the', 48 | ... 'Party'] 49 | >>> reference3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 50 | ... 'army', 'always', 'to', 'heed', 'the', 'directions', 51 | ... 'of', 'the', 'party'] 52 | >>> sentence_bleu([reference1, reference2, reference3], hypothesis1) # doctest: +ELLIPSIS 53 | 0.5045... 54 | If there is no ngram overlap for any order of n-grams, BLEU returns the 55 | value 0. This is because the precision for the order of n-grams without 56 | overlap is 0, and the geometric mean in the final BLEU score computation 57 | multiplies the 0 with the precision of other n-grams. This results in 0 58 | (independently of the precision of the other n-gram orders). The following 59 | example has zero 3-gram and 4-gram overlaps: 60 | >>> round(sentence_bleu([reference1, reference2, reference3], hypothesis2),4) # doctest: +ELLIPSIS 61 | 0.0 62 | To avoid this harsh behaviour when no ngram overlaps are found, a smoothing 63 | function can be used. 64 | >>> chencherry = SmoothingFunction() 65 | >>> sentence_bleu([reference1, reference2, reference3], hypothesis2, 66 | ... smoothing_function=chencherry.method1) # doctest: +ELLIPSIS 67 | 0.0370... 68 | The default BLEU calculates a score for up to 4-grams using uniform 69 | weights (this is called BLEU-4). To evaluate your translations with 70 | higher/lower order ngrams, use customized weights. E.g. when accounting 71 | for up to 5-grams with uniform weights (this is called BLEU-5) use: 72 | >>> weights = (1./5., 1./5., 1./5., 1./5., 1./5.) 73 | >>> sentence_bleu([reference1, reference2, reference3], hypothesis1, weights) # doctest: +ELLIPSIS 74 | 0.3920... 75 | :param references: reference sentences 76 | :type references: list(list(str)) 77 | :param hypothesis: a hypothesis sentence 78 | :type hypothesis: list(str) 79 | :param weights: weights for unigrams, bigrams, trigrams and so on 80 | :type weights: list(float) 81 | :param smoothing_function: 82 | :type smoothing_function: SmoothingFunction 83 | :param auto_reweigh: Option to re-normalize the weights uniformly. 84 | :type auto_reweigh: bool 85 | :return: The sentence-level BLEU score. 86 | :rtype: float 87 | """ 88 | return corpus_bleu( 89 | [references], [hypothesis], weights, smoothing_function, auto_reweigh 90 | ) 91 | 92 | 93 | def corpus_bleu( 94 | list_of_references, 95 | hypotheses, 96 | weights=(0.25, 0.25, 0.25, 0.25), 97 | smoothing_function=None, 98 | auto_reweigh=False, 99 | ): 100 | """ 101 | Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all 102 | the hypotheses and their respective references. 103 | Instead of averaging the sentence level BLEU scores (i.e. macro-average 104 | precision), the original BLEU metric (Papineni et al. 2002) accounts for 105 | the micro-average precision (i.e. summing the numerators and denominators 106 | for each hypothesis-reference(s) pair before the division). 107 | >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 108 | ... 'ensures', 'that', 'the', 'military', 'always', 109 | ... 'obeys', 'the', 'commands', 'of', 'the', 'party'] 110 | >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 111 | ... 'ensures', 'that', 'the', 'military', 'will', 'forever', 112 | ... 'heed', 'Party', 'commands'] 113 | >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which', 114 | ... 'guarantees', 'the', 'military', 'forces', 'always', 115 | ... 'being', 'under', 'the', 'command', 'of', 'the', 'Party'] 116 | >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 117 | ... 'army', 'always', 'to', 'heed', 'the', 'directions', 118 | ... 'of', 'the', 'party'] 119 | >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was', 120 | ... 'interested', 'in', 'world', 'history'] 121 | >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history', 122 | ... 'because', 'he', 'read', 'the', 'book'] 123 | >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]] 124 | >>> hypotheses = [hyp1, hyp2] 125 | >>> corpus_bleu(list_of_references, hypotheses) # doctest: +ELLIPSIS 126 | 0.5920... 127 | The example below shows that corpus_bleu() is different from averaging 128 | sentence_bleu() over the hypotheses. 129 | >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1) 130 | >>> score2 = sentence_bleu([ref2a], hyp2) 131 | >>> (score1 + score2) / 2 # doctest: +ELLIPSIS 132 | 0.6223... 133 | :param list_of_references: a corpus of lists of reference sentences, w.r.t. hypotheses 134 | :type list_of_references: list(list(list(str))) 135 | :param hypotheses: a list of hypothesis sentences 136 | :type hypotheses: list(list(str)) 137 | :param weights: weights for unigrams, bigrams, trigrams and so on 138 | :type weights: list(float) 139 | :param smoothing_function: 140 | :type smoothing_function: SmoothingFunction 141 | :param auto_reweigh: Option to re-normalize the weights uniformly. 142 | :type auto_reweigh: bool 143 | :return: The corpus-level BLEU score. 144 | :rtype: float 145 | """ 146 | # Before proceeding to compute BLEU, perform sanity checks. 147 | 148 | p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches. 149 | p_denominators = Counter() # Key = ngram order, and value = no. of ngram in ref. 150 | hyp_lengths, ref_lengths = 0, 0 151 | 152 | assert len(list_of_references) == len(hypotheses), ( 153 | "The number of hypotheses and their reference(s) should be the " "same " 154 | ) 155 | 156 | # Iterate through each hypothesis and their corresponding references. 157 | for references, hypothesis in zip(list_of_references, hypotheses): 158 | # For each order of ngram, calculate the numerator and 159 | # denominator for the corpus-level modified precision. 160 | for i, _ in enumerate(weights, start=1): 161 | p_i_numerator, p_i_denominator = modified_recall(references, hypothesis, i) 162 | p_numerators[i] += p_i_numerator 163 | p_denominators[i] += p_i_denominator 164 | 165 | # Calculate the hypothesis length and the closest reference length. 166 | # Adds them to the corpus-level hypothesis and reference counts. 167 | hyp_len = len(hypothesis) 168 | hyp_lengths += hyp_len 169 | ref_lengths += closest_ref_length(references, hyp_len) 170 | 171 | # Calculate corpus-level brevity penalty. 172 | bp = brevity_penalty(ref_lengths, hyp_lengths) 173 | 174 | # Uniformly re-weighting based on maximum hypothesis lengths if largest 175 | # order of n-grams < 4 and weights is set at default. 176 | if auto_reweigh: 177 | if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25): 178 | weights = (1 / hyp_lengths,) * hyp_lengths 179 | 180 | # Collects the various recall values for the different ngram orders.
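# NOTE (annotation): unlike the plain BLEU implementation in bleu.py, each entry
# collected into p_n below is kept as a raw (numerator, denominator) tuple rather
# than a Fraction, because the keyword-weighted unigram counts produced by
# modified_recall() need not be integers; the Fraction construction is left
# commented out inside modified_recall() for the same reason.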
181 | p_n = [ 182 | (p_numerators[i], p_denominators[i]) 183 | for i, _ in enumerate(weights, start=1) 184 | ] 185 | 186 | # Returns 0 if there's no matching n-grams 187 | # We only need to check for p_numerators[1] == 0, since if there's 188 | # no unigrams, there won't be any higher order ngrams. 189 | if p_numerators[1] == 0: 190 | return 0 191 | 192 | # If there's no smoothing, set use method0 from SmoothinFunction class. 193 | if not smoothing_function: 194 | smoothing_function = SmoothingFunction().method1 195 | # Smoothen the modified precision. 196 | # Note: smoothing_function() may convert values into floats; 197 | # it tries to retain the Fraction object as much as the 198 | # smoothing method allows. 199 | p_n = smoothing_function( 200 | p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths 201 | ) 202 | # pdb.set_trace() 203 | s = (w_i * math.log(p_i[0] / p_i[1]) for w_i, p_i in zip(weights, p_n)) 204 | s = bp * math.exp(math.fsum(s)) 205 | return s 206 | 207 | 208 | def modified_recall(references, hypothesis, n): 209 | """ 210 | Calculate modified ngram recall. 211 | :param references: A list of reference translations. 212 | :type references: list(list(str)) 213 | :param hypothesis: A hypothesis translation. 214 | :type hypothesis: list(str) 215 | :param n: The ngram order. 216 | :type n: int 217 | :return: BLEU's modified precision for the nth order ngram. 218 | :rtype: Fraction 219 | """ 220 | # Extracts all ngrams in hypothesis 221 | # Set an empty Counter if hypothesis is empty. 222 | # pdb.set_trace() 223 | numerator = 0 224 | denominator = 0 225 | 226 | counts = Counter(ngrams(hypothesis, n)) if len(hypothesis) >= n else Counter() 227 | # Extract a union of references' counts. 228 | # max_counts = reduce(or_, [Counter(ngrams(ref, n)) for ref in references]) 229 | max_counts = {} 230 | for reference_and_weights in references: 231 | reference = reference_and_weights[0] 232 | weights = reference_and_weights[1] 233 | reference_counts = ( 234 | Counter(ngrams(reference, n)) if len(reference) >= n else Counter() 235 | ) 236 | # for ngram in reference_counts: 237 | # max_counts[ngram] = max(max_counts.get(ngram, 0), counts[ngram]) 238 | clipped_counts = { 239 | ngram: min(count, counts[ngram]) for ngram, count in reference_counts.items() 240 | } 241 | # reweight 242 | if n == 1 and len(weights) == len(reference_counts): 243 | def weighted_sum(weights, counts): 244 | sum_counts = 0 245 | for ngram, count in counts.items(): 246 | sum_counts += count * (weights[ngram[0]] if ngram[0] in weights else 1) 247 | return sum_counts 248 | 249 | numerator += weighted_sum(weights, clipped_counts) 250 | denominator += max(1, weighted_sum(weights, reference_counts)) 251 | 252 | else: 253 | numerator += sum(clipped_counts.values()) 254 | denominator += max(1, sum(reference_counts.values())) 255 | 256 | # # Assigns the intersection between hypothesis and references' counts. 257 | # clipped_counts = { 258 | # ngram: min(count, max_counts[ngram]) for ngram, count in counts.items() 259 | # } 260 | 261 | # numerator += sum(clipped_counts.values()) 262 | # # Ensures that denominator is minimum 1 to avoid ZeroDivisionError. 263 | # # Usually this happens when the ngram order is > len(reference). 
264 | # denominator += max(1, sum(counts.values())) 265 | 266 | # return Fraction(numerator, denominator, _normalize=False) 267 | return numerator, denominator 268 | 269 | 270 | def closest_ref_length(references, hyp_len): 271 | """ 272 | This function finds the reference that is the closest length to the 273 | hypothesis. The closest reference length is referred to as the *r* variable 274 | in the brevity penalty formula of Papineni et al. (2002). 275 | :param references: A list of reference translations. 276 | :type references: list(list(str)) 277 | :param hyp_len: The length of the hypothesis. 278 | :type hyp_len: int 279 | :return: The length of the reference that's closest to the hypothesis. 280 | :rtype: int 281 | """ 282 | ref_lens = (len(reference) for reference in references) 283 | closest_ref_len = min( 284 | ref_lens, key=lambda ref_len: (abs(ref_len - hyp_len), ref_len) 285 | ) 286 | return closest_ref_len 287 | 288 | 289 | def brevity_penalty(closest_ref_len, hyp_len): 290 | """ 291 | Calculate brevity penalty. 292 | As the modified n-gram precision alone still rewards overly short 293 | sentences, a brevity penalty is used to adjust the overall BLEU 294 | score according to length. 295 | An example from the paper: there are three references with lengths 12, 15 296 | and 17, and a concise hypothesis of length 12. The brevity penalty is 1. 297 | >>> reference1 = list('aaaaaaaaaaaa') # i.e. ['a'] * 12 298 | >>> reference2 = list('aaaaaaaaaaaaaaa') # i.e. ['a'] * 15 299 | >>> reference3 = list('aaaaaaaaaaaaaaaaa') # i.e. ['a'] * 17 300 | >>> hypothesis = list('aaaaaaaaaaaa') # i.e. ['a'] * 12 301 | >>> references = [reference1, reference2, reference3] 302 | >>> hyp_len = len(hypothesis) 303 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 304 | >>> brevity_penalty(closest_ref_len, hyp_len) 305 | 1.0 306 | In case a hypothesis translation is shorter than the references, penalty is 307 | applied. 308 | >>> references = [['a'] * 28, ['a'] * 28] 309 | >>> hypothesis = ['a'] * 12 310 | >>> hyp_len = len(hypothesis) 311 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 312 | >>> brevity_penalty(closest_ref_len, hyp_len) 313 | 0.2635971381157267 314 | The length of the closest reference is used to compute the penalty. If the 315 | length of a hypothesis is 12, and the reference lengths are 13 and 2, the 316 | penalty is applied because the hypothesis length (12) is less than the 317 | closest reference length (13). 318 | >>> references = [['a'] * 13, ['a'] * 2] 319 | >>> hypothesis = ['a'] * 12 320 | >>> hyp_len = len(hypothesis) 321 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 322 | >>> brevity_penalty(closest_ref_len, hyp_len) # doctest: +ELLIPSIS 323 | 0.9200... 324 | The brevity penalty doesn't depend on reference order. More importantly, 325 | when two reference sentences are at the same distance, the shortest 326 | reference sentence length is used.
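(In the example below, a hypothesis of length 12 is equally close to references of lengths 13 and 11; the tie is broken in favour of the shorter reference, so *r* = 11, and because the hypothesis is longer than that, the penalty is 1 under either reference ordering.)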
327 | >>> references = [['a'] * 13, ['a'] * 11] 328 | >>> hypothesis = ['a'] * 12 329 | >>> hyp_len = len(hypothesis) 330 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 331 | >>> bp1 = brevity_penalty(closest_ref_len, hyp_len) 332 | >>> hyp_len = len(hypothesis) 333 | >>> closest_ref_len = closest_ref_length(reversed(references), hyp_len) 334 | >>> bp2 = brevity_penalty(closest_ref_len, hyp_len) 335 | >>> bp1 == bp2 == 1 336 | True 337 | A test example from mteval-v13a.pl (starting from the line 705): 338 | >>> references = [['a'] * 11, ['a'] * 8] 339 | >>> hypothesis = ['a'] * 7 340 | >>> hyp_len = len(hypothesis) 341 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 342 | >>> brevity_penalty(closest_ref_len, hyp_len) # doctest: +ELLIPSIS 343 | 0.8668... 344 | >>> references = [['a'] * 11, ['a'] * 8, ['a'] * 6, ['a'] * 7] 345 | >>> hypothesis = ['a'] * 7 346 | >>> hyp_len = len(hypothesis) 347 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 348 | >>> brevity_penalty(closest_ref_len, hyp_len) 349 | 1.0 350 | :param hyp_len: The length of the hypothesis for a single sentence OR the 351 | sum of all the hypotheses' lengths for a corpus 352 | :type hyp_len: int 353 | :param closest_ref_len: The length of the closest reference for a single 354 | hypothesis OR the sum of all the closest references for every hypotheses. 355 | :type closest_ref_len: int 356 | :return: BLEU's brevity penalty. 357 | :rtype: float 358 | """ 359 | if hyp_len > closest_ref_len: 360 | return 1 361 | # If hypothesis is empty, brevity penalty = 0 should result in BLEU = 0.0 362 | elif hyp_len == 0: 363 | return 0 364 | else: 365 | return math.exp(1 - closest_ref_len / hyp_len) 366 | 367 | 368 | class SmoothingFunction: 369 | """ 370 | This is an implementation of the smoothing techniques 371 | for segment-level BLEU scores that was presented in 372 | Boxing Chen and Collin Cherry (2014) A Systematic Comparison of 373 | Smoothing Techniques for Sentence-Level BLEU. In WMT14. 374 | http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf 375 | """ 376 | 377 | def __init__(self, epsilon=0.1, alpha=5, k=5): 378 | """ 379 | This will initialize the parameters required for the various smoothing 380 | techniques, the default values are set to the numbers used in the 381 | experiments from Chen and Cherry (2014). 382 | >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 'ensures', 383 | ... 'that', 'the', 'military', 'always', 'obeys', 'the', 384 | ... 'commands', 'of', 'the', 'party'] 385 | >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 'ensures', 386 | ... 'that', 'the', 'military', 'will', 'forever', 'heed', 387 | ... 'Party', 'commands'] 388 | >>> chencherry = SmoothingFunction() 389 | >>> print(sentence_bleu([reference1], hypothesis1)) # doctest: +ELLIPSIS 390 | 0.4118... 391 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method0)) # doctest: +ELLIPSIS 392 | 0.4118... 393 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method1)) # doctest: +ELLIPSIS 394 | 0.4118... 395 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method2)) # doctest: +ELLIPSIS 396 | 0.4489... 397 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method3)) # doctest: +ELLIPSIS 398 | 0.4118... 399 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method4)) # doctest: +ELLIPSIS 400 | 0.4118... 
401 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method5)) # doctest: +ELLIPSIS 402 | 0.4905... 403 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method6)) # doctest: +ELLIPSIS 404 | 0.4135... 405 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method7)) # doctest: +ELLIPSIS 406 | 0.4905... 407 | :param epsilon: the epsilon value used in method 1 408 | :type epsilon: float 409 | :param alpha: the alpha value used in method 6 410 | :type alpha: int 411 | :param k: the k value used in method 4 412 | :type k: int 413 | """ 414 | self.epsilon = epsilon 415 | self.alpha = alpha 416 | self.k = k 417 | 418 | def method0(self, p_n, *args, **kwargs): 419 | """ 420 | No smoothing. 421 | """ 422 | p_n_new = [] 423 | for i, p_i in enumerate(p_n): 424 | if p_i[0] != 0: 425 | p_n_new.append(p_i) 426 | else: 427 | _msg = str( 428 | "\nThe hypothesis contains 0 counts of {}-gram overlaps.\n" 429 | "Therefore the BLEU score evaluates to 0, independently of\n" 430 | "how many N-gram overlaps of lower order it contains.\n" 431 | "Consider using lower n-gram order or use " 432 | "SmoothingFunction()" 433 | ).format(i + 1) 434 | warnings.warn(_msg) 435 | # When numerator==0, whether denominator==0 or !=0, the result 436 | # for the precision score should be equal to 0 or undefined. 437 | # Due to BLEU geometric mean computation in logarithm space, 438 | # we need to return sys.float_info.min such that 439 | # math.log(sys.float_info.min) returns a 0 precision score. 440 | p_n_new.append(sys.float_info.min) 441 | return p_n_new 442 | 443 | def method1(self, p_n, *args, **kwargs): 444 | """ 445 | Smoothing method 1: Add *epsilon* counts to precision with 0 counts. 446 | """ 447 | return [ 448 | ((p_i[0] + self.epsilon), p_i[1]) 449 | if p_i[0] == 0 450 | else p_i 451 | for p_i in p_n 452 | ] 453 | 454 | def method2(self, p_n, *args, **kwargs): 455 | """ 456 | Smoothing method 2: Add 1 to both numerator and denominator from 457 | Chin-Yew Lin and Franz Josef Och (2004) Automatic evaluation of 458 | machine translation quality using longest common subsequence and 459 | skip-bigram statistics. In ACL04. 460 | """ 461 | return [ 462 | (p_i[0] + 1, p_i[1] + 1) 463 | for p_i in p_n 464 | ] 465 | 466 | def method3(self, p_n, *args, **kwargs): 467 | """ 468 | Smoothing method 3: NIST geometric sequence smoothing 469 | The smoothing is computed by taking 1 / ( 2^k ), instead of 0, for each 470 | precision score whose matching n-gram count is null. 471 | k is 1 for the first 'n' value for which the n-gram match count is null. 472 | For example, if the text contains: 473 | - one 2-gram match 474 | - and (consequently) two 1-gram matches 475 | the n-gram count for each individual precision score would be: 476 | - n=1 => prec_count = 2 (two unigrams) 477 | - n=2 => prec_count = 1 (one bigram) 478 | - n=3 => prec_count = 1/2 (no trigram, taking 'smoothed' value of 1 / ( 2^k ), with k=1) 479 | - n=4 => prec_count = 1/4 (no fourgram, taking 'smoothed' value of 1 / ( 2^k ), with k=2) 480 | """ 481 | incvnt = 1 # From the mteval-v13a.pl, it's referred to as k.
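        # NOTE (annotation): each zero-count precision is replaced below by
        # 1 / (2**incvnt * denominator), i.e. a pseudo-count of 1/2, then 1/4, ...
        # over the original denominator, with incvnt advancing at every further
        # zero, as in mteval-v13a.pl. Caveat: in this file p_n holds plain
        # (numerator, denominator) tuples, so the Fraction-style
        # .numerator/.denominator access below only works if the entries are
        # first converted back to Fractions; the default smoothing applied by
        # corpus_bleu() here is method1, so this path is normally not exercised.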
482 | for i, p_i in enumerate(p_n): 483 | if p_i.numerator == 0: 484 | p_n[i] = 1 / (2 ** incvnt * p_i.denominator) 485 | incvnt += 1 486 | return p_n 487 | 488 | def method4(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): 489 | """ 490 | Smoothing method 4: 491 | Shorter translations may have inflated precision values due to having 492 | smaller denominators; therefore, we give them proportionally 493 | smaller smoothed counts. Instead of scaling to 1/(2^k), Chen and Cherry 494 | suggests dividing by 1/ln(len(T)), where T is the length of the translation. 495 | """ 496 | hyp_len = hyp_len if hyp_len else len(hypothesis) 497 | for i, p_i in enumerate(p_n): 498 | if p_i.numerator == 0 and hyp_len != 0: 499 | incvnt = i + 1 * self.k / math.log( 500 | hyp_len 501 | ) # Note that this K is different from the K from NIST. 502 | p_n[i] = incvnt / p_i.denominator 503 | return p_n 504 | 505 | def method5(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): 506 | """ 507 | Smoothing method 5: 508 | The matched counts for similar values of n should be similar. To a 509 | calculate the n-gram matched count, it averages the n−1, n and n+1 gram 510 | matched counts. 511 | """ 512 | hyp_len = hyp_len if hyp_len else len(hypothesis) 513 | m = {} 514 | # Requires an precision value for an addition ngram order. 515 | p_n_plus1 = p_n + [modified_precision(references, hypothesis, 5)] 516 | m[-1] = p_n[0] + 1 517 | for i, p_i in enumerate(p_n): 518 | p_n[i] = (m[i - 1] + p_i + p_n_plus1[i + 1]) / 3 519 | m[i] = p_n[i] 520 | return p_n 521 | 522 | def method6(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): 523 | """ 524 | Smoothing method 6: 525 | Interpolates the maximum likelihood estimate of the precision *p_n* with 526 | a prior estimate *pi0*. The prior is estimated by assuming that the ratio 527 | between pn and pn−1 will be the same as that between pn−1 and pn−2; from 528 | Gao and He (2013) Training MRF-Based Phrase Translation Models using 529 | Gradient Ascent. In NAACL. 530 | """ 531 | hyp_len = hyp_len if hyp_len else len(hypothesis) 532 | # This smoothing only works when p_1 and p_2 is non-zero. 533 | # Raise an error with an appropriate message when the input is too short 534 | # to use this smoothing technique. 535 | assert p_n[2], "This smoothing method requires non-zero precision for bigrams." 536 | for i, p_i in enumerate(p_n): 537 | if i in [0, 1]: # Skips the first 2 orders of ngrams. 538 | continue 539 | else: 540 | pi0 = 0 if p_n[i - 2] == 0 else p_n[i - 1] ** 2 / p_n[i - 2] 541 | # No. of ngrams in translation that matches the reference. 542 | m = p_i.numerator 543 | # No. of ngrams in translation. 544 | l = sum(1 for _ in ngrams(hypothesis, i + 1)) 545 | # Calculates the interpolated precision. 546 | p_n[i] = (m + self.alpha * pi0) / (l + self.alpha) 547 | return p_n 548 | 549 | def method7(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): 550 | """ 551 | Smoothing method 7: 552 | Interpolates methods 4 and 5. 
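That is, the zero counts are first filled in with method 4's length-scaled values, and method 5's neighbour averaging is then applied to the result.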
553 | """ 554 | hyp_len = hyp_len if hyp_len else len(hypothesis) 555 | p_n = self.method4(p_n, references, hypothesis, hyp_len) 556 | p_n = self.method5(p_n, references, hypothesis, hyp_len) 557 | return p_n 558 | -------------------------------------------------------------------------------- /bleu.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Natural Language Toolkit: BLEU Score 3 | # 4 | # Copyright (C) 2001-2020 NLTK Project 5 | # Authors: Chin Yee Lee, Hengfeng Li, Ruxin Hou, Calvin Tanujaya Lim 6 | # Contributors: Björn Mattsson, Dmitrijs Milajevs, Liling Tan 7 | # URL: 8 | # For license information, see LICENSE.TXT 9 | 10 | """BLEU score implementation.""" 11 | 12 | import math 13 | import sys 14 | from fractions import Fraction 15 | import warnings 16 | from collections import Counter 17 | from utils import ngrams 18 | 19 | 20 | def sentence_bleu( 21 | references, 22 | hypothesis, 23 | weights=(0.25, 0.25, 0.25, 0.25), 24 | smoothing_function=None, 25 | auto_reweigh=False, 26 | ): 27 | """ 28 | Calculate BLEU score (Bilingual Evaluation Understudy) from 29 | Papineni, Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu. 2002. 30 | "BLEU: a method for automatic evaluation of machine translation." 31 | In Proceedings of ACL. http://www.aclweb.org/anthology/P02-1040.pdf 32 | >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 33 | ... 'ensures', 'that', 'the', 'military', 'always', 34 | ... 'obeys', 'the', 'commands', 'of', 'the', 'party'] 35 | >>> hypothesis2 = ['It', 'is', 'to', 'insure', 'the', 'troops', 36 | ... 'forever', 'hearing', 'the', 'activity', 'guidebook', 37 | ... 'that', 'party', 'direct'] 38 | >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 39 | ... 'ensures', 'that', 'the', 'military', 'will', 'forever', 40 | ... 'heed', 'Party', 'commands'] 41 | >>> reference2 = ['It', 'is', 'the', 'guiding', 'principle', 'which', 42 | ... 'guarantees', 'the', 'military', 'forces', 'always', 43 | ... 'being', 'under', 'the', 'command', 'of', 'the', 44 | ... 'Party'] 45 | >>> reference3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 46 | ... 'army', 'always', 'to', 'heed', 'the', 'directions', 47 | ... 'of', 'the', 'party'] 48 | >>> sentence_bleu([reference1, reference2, reference3], hypothesis1) # doctest: +ELLIPSIS 49 | 0.5045... 50 | If there is no ngrams overlap for any order of n-grams, BLEU returns the 51 | value 0. This is because the precision for the order of n-grams without 52 | overlap is 0, and the geometric mean in the final BLEU score computation 53 | multiplies the 0 with the precision of other n-grams. This results in 0 54 | (independently of the precision of the othe n-gram orders). The following 55 | example has zero 3-gram and 4-gram overlaps: 56 | >>> round(sentence_bleu([reference1, reference2, reference3], hypothesis2),4) # doctest: +ELLIPSIS 57 | 0.0 58 | To avoid this harsh behaviour when no ngram overlaps are found a smoothing 59 | function can be used. 60 | >>> chencherry = SmoothingFunction() 61 | >>> sentence_bleu([reference1, reference2, reference3], hypothesis2, 62 | ... smoothing_function=chencherry.method1) # doctest: +ELLIPSIS 63 | 0.0370... 64 | The default BLEU calculates a score for up to 4-grams using uniform 65 | weights (this is called BLEU-4). To evaluate your translations with 66 | higher/lower order ngrams, use customized weights. E.g. 
when accounting 67 | for up to 5-grams with uniform weights (this is called BLEU-5) use: 68 | >>> weights = (1./5., 1./5., 1./5., 1./5., 1./5.) 69 | >>> sentence_bleu([reference1, reference2, reference3], hypothesis1, weights) # doctest: +ELLIPSIS 70 | 0.3920... 71 | :param references: reference sentences 72 | :type references: list(list(str)) 73 | :param hypothesis: a hypothesis sentence 74 | :type hypothesis: list(str) 75 | :param weights: weights for unigrams, bigrams, trigrams and so on 76 | :type weights: list(float) 77 | :param smoothing_function: 78 | :type smoothing_function: SmoothingFunction 79 | :param auto_reweigh: Option to re-normalize the weights uniformly. 80 | :type auto_reweigh: bool 81 | :return: The sentence-level BLEU score. 82 | :rtype: float 83 | """ 84 | return corpus_bleu( 85 | [references], [hypothesis], weights, smoothing_function, auto_reweigh 86 | ) 87 | 88 | 89 | def corpus_bleu( 90 | list_of_references, 91 | hypotheses, 92 | weights=(0.25, 0.25, 0.25, 0.25), 93 | smoothing_function=None, 94 | auto_reweigh=False, 95 | ): 96 | """ 97 | Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all 98 | the hypotheses and their respective references. 99 | Instead of averaging the sentence level BLEU scores (i.e. marco-average 100 | precision), the original BLEU metric (Papineni et al. 2002) accounts for 101 | the micro-average precision (i.e. summing the numerators and denominators 102 | for each hypothesis-reference(s) pairs before the division). 103 | >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 104 | ... 'ensures', 'that', 'the', 'military', 'always', 105 | ... 'obeys', 'the', 'commands', 'of', 'the', 'party'] 106 | >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 107 | ... 'ensures', 'that', 'the', 'military', 'will', 'forever', 108 | ... 'heed', 'Party', 'commands'] 109 | >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which', 110 | ... 'guarantees', 'the', 'military', 'forces', 'always', 111 | ... 'being', 'under', 'the', 'command', 'of', 'the', 'Party'] 112 | >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 113 | ... 'army', 'always', 'to', 'heed', 'the', 'directions', 114 | ... 'of', 'the', 'party'] 115 | >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was', 116 | ... 'interested', 'in', 'world', 'history'] 117 | >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history', 118 | ... 'because', 'he', 'read', 'the', 'book'] 119 | >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]] 120 | >>> hypotheses = [hyp1, hyp2] 121 | >>> corpus_bleu(list_of_references, hypotheses) # doctest: +ELLIPSIS 122 | 0.5920... 123 | The example below show that corpus_bleu() is different from averaging 124 | sentence_bleu() for hypotheses 125 | >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1) 126 | >>> score2 = sentence_bleu([ref2a], hyp2) 127 | >>> (score1 + score2) / 2 # doctest: +ELLIPSIS 128 | 0.6223... 129 | :param list_of_references: a corpus of lists of reference sentences, w.r.t. hypotheses 130 | :type list_of_references: list(list(list(str))) 131 | :param hypotheses: a list of hypothesis sentences 132 | :type hypotheses: list(list(str)) 133 | :param weights: weights for unigrams, bigrams, trigrams and so on 134 | :type weights: list(float) 135 | :param smoothing_function: 136 | :type smoothing_function: SmoothingFunction 137 | :param auto_reweigh: Option to re-normalize the weights uniformly. 138 | :type auto_reweigh: bool 139 | :return: The corpus-level BLEU score. 
140 | :rtype: float 141 | """ 142 | # Before proceeding to compute BLEU, perform sanity checks. 143 | 144 | p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches. 145 | p_denominators = Counter() # Key = ngram order, and value = no. of ngram in ref. 146 | hyp_lengths, ref_lengths = 0, 0 147 | 148 | assert len(list_of_references) == len(hypotheses), ( 149 | "The number of hypotheses and their reference(s) should be the " "same " 150 | ) 151 | 152 | # Iterate through each hypothesis and their corresponding references. 153 | for references, hypothesis in zip(list_of_references, hypotheses): 154 | # For each order of ngram, calculate the numerator and 155 | # denominator for the corpus-level modified precision. 156 | for i, _ in enumerate(weights, start=1): 157 | p_i = modified_precision(references, hypothesis, i) 158 | p_numerators[i] += p_i.numerator 159 | p_denominators[i] += p_i.denominator 160 | 161 | # Calculate the hypothesis length and the closest reference length. 162 | # Adds them to the corpus-level hypothesis and reference counts. 163 | hyp_len = len(hypothesis) 164 | hyp_lengths += hyp_len 165 | ref_lengths += closest_ref_length(references, hyp_len) 166 | 167 | # Calculate corpus-level brevity penalty. 168 | bp = brevity_penalty(ref_lengths, hyp_lengths) 169 | 170 | # Uniformly re-weighting based on maximum hypothesis lengths if largest 171 | # order of n-grams < 4 and weights is set at default. 172 | if auto_reweigh: 173 | if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25): 174 | weights = (1 / hyp_lengths,) * hyp_lengths 175 | 176 | # Collects the various precision values for the different ngram orders. 177 | p_n = [ 178 | Fraction(p_numerators[i], p_denominators[i], _normalize=False) 179 | for i, _ in enumerate(weights, start=1) 180 | ] 181 | 182 | # Returns 0 if there's no matching n-grams 183 | # We only need to check for p_numerators[1] == 0, since if there's 184 | # no unigrams, there won't be any higher order ngrams. 185 | if p_numerators[1] == 0: 186 | return 0 187 | 188 | # If there's no smoothing, set use method0 from SmoothinFunction class. 189 | if not smoothing_function: 190 | smoothing_function = SmoothingFunction().method1 191 | # Smoothen the modified precision. 192 | # Note: smoothing_function() may convert values into floats; 193 | # it tries to retain the Fraction object as much as the 194 | # smoothing method allows. 195 | p_n = smoothing_function( 196 | p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths 197 | ) 198 | s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n)) 199 | s = bp * math.exp(math.fsum(s)) 200 | return s 201 | 202 | 203 | def modified_precision(references, hypothesis, n): 204 | """ 205 | Calculate modified ngram precision. 206 | The normal precision method may lead to some wrong translations with 207 | high-precision, e.g., the translation, in which a word of reference 208 | repeats several times, has very high precision. 209 | This function only returns the Fraction object that contains the numerator 210 | and denominator necessary to calculate the corpus-level precision. 211 | To calculate the modified precision for a single pair of hypothesis and 212 | references, cast the Fraction object into a float. 213 | The famous "the the the ... " example shows that you can get BLEU precision 214 | by duplicating high frequency words. 
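In the example below, every one of the seven hypothesis unigrams is 'the', but 'the' occurs at most twice within any single reference, so the clipped count is 2 and the modified unigram precision is 2/7.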
215 | >>> reference1 = 'the cat is on the mat'.split() 216 | >>> reference2 = 'there is a cat on the mat'.split() 217 | >>> hypothesis1 = 'the the the the the the the'.split() 218 | >>> references = [reference1, reference2] 219 | >>> float(modified_precision(references, hypothesis1, n=1)) # doctest: +ELLIPSIS 220 | 0.2857... 221 | In the modified n-gram precision, a reference word will be considered 222 | exhausted after a matching hypothesis word is identified, e.g. 223 | >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 224 | ... 'ensures', 'that', 'the', 'military', 'will', 225 | ... 'forever', 'heed', 'Party', 'commands'] 226 | >>> reference2 = ['It', 'is', 'the', 'guiding', 'principle', 'which', 227 | ... 'guarantees', 'the', 'military', 'forces', 'always', 228 | ... 'being', 'under', 'the', 'command', 'of', 'the', 229 | ... 'Party'] 230 | >>> reference3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 231 | ... 'army', 'always', 'to', 'heed', 'the', 'directions', 232 | ... 'of', 'the', 'party'] 233 | >>> hypothesis = 'of the'.split() 234 | >>> references = [reference1, reference2, reference3] 235 | >>> float(modified_precision(references, hypothesis, n=1)) 236 | 1.0 237 | >>> float(modified_precision(references, hypothesis, n=2)) 238 | 1.0 239 | An example of a normal machine translation hypothesis: 240 | >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 241 | ... 'ensures', 'that', 'the', 'military', 'always', 242 | ... 'obeys', 'the', 'commands', 'of', 'the', 'party'] 243 | >>> hypothesis2 = ['It', 'is', 'to', 'insure', 'the', 'troops', 244 | ... 'forever', 'hearing', 'the', 'activity', 'guidebook', 245 | ... 'that', 'party', 'direct'] 246 | >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 247 | ... 'ensures', 'that', 'the', 'military', 'will', 248 | ... 'forever', 'heed', 'Party', 'commands'] 249 | >>> reference2 = ['It', 'is', 'the', 'guiding', 'principle', 'which', 250 | ... 'guarantees', 'the', 'military', 'forces', 'always', 251 | ... 'being', 'under', 'the', 'command', 'of', 'the', 252 | ... 'Party'] 253 | >>> reference3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 254 | ... 'army', 'always', 'to', 'heed', 'the', 'directions', 255 | ... 'of', 'the', 'party'] 256 | >>> references = [reference1, reference2, reference3] 257 | >>> float(modified_precision(references, hypothesis1, n=1)) # doctest: +ELLIPSIS 258 | 0.9444... 259 | >>> float(modified_precision(references, hypothesis2, n=1)) # doctest: +ELLIPSIS 260 | 0.5714... 261 | >>> float(modified_precision(references, hypothesis1, n=2)) # doctest: +ELLIPSIS 262 | 0.5882352941176471 263 | >>> float(modified_precision(references, hypothesis2, n=2)) # doctest: +ELLIPSIS 264 | 0.07692... 265 | :param references: A list of reference translations. 266 | :type references: list(list(str)) 267 | :param hypothesis: A hypothesis translation. 268 | :type hypothesis: list(str) 269 | :param n: The ngram order. 270 | :type n: int 271 | :return: BLEU's modified precision for the nth order ngram. 272 | :rtype: Fraction 273 | """ 274 | # Extracts all ngrams in hypothesis 275 | # Set an empty Counter if hypothesis is empty. 276 | 277 | counts = Counter(ngrams(hypothesis, n)) if len(hypothesis) >= n else Counter() 278 | # Extract a union of references' counts. 
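    # NOTE (annotation): for each hypothesis ngram we record the *maximum* count
    # observed in any single reference, and the hypothesis counts are then clipped
    # to that ceiling. With the 'the the the ...' example from the docstring,
    # max_counts[('the',)] == 2, so the clipped numerator is min(7, 2) == 2.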
279 | # max_counts = reduce(or_, [Counter(ngrams(ref, n)) for ref in references]) 280 | max_counts = {} 281 | for reference in references: 282 | reference_counts = ( 283 | Counter(ngrams(reference, n)) if len(reference) >= n else Counter() 284 | ) 285 | for ngram in counts: 286 | max_counts[ngram] = max(max_counts.get(ngram, 0), reference_counts[ngram]) 287 | 288 | # Assigns the intersection between hypothesis and references' counts. 289 | clipped_counts = { 290 | ngram: min(count, max_counts[ngram]) for ngram, count in counts.items() 291 | } 292 | 293 | numerator = sum(clipped_counts.values()) 294 | # Ensures that denominator is minimum 1 to avoid ZeroDivisionError. 295 | # Usually this happens when the ngram order is > len(reference). 296 | denominator = max(1, sum(counts.values())) 297 | 298 | return Fraction(numerator, denominator, _normalize=False) 299 | 300 | 301 | def closest_ref_length(references, hyp_len): 302 | """ 303 | This function finds the reference that is the closest length to the 304 | hypothesis. The closest reference length is referred to as *r* variable 305 | from the brevity penalty formula in Papineni et. al. (2002) 306 | :param references: A list of reference translations. 307 | :type references: list(list(str)) 308 | :param hyp_len: The length of the hypothesis. 309 | :type hyp_len: int 310 | :return: The length of the reference that's closest to the hypothesis. 311 | :rtype: int 312 | """ 313 | ref_lens = (len(reference) for reference in references) 314 | closest_ref_len = min( 315 | ref_lens, key=lambda ref_len: (abs(ref_len - hyp_len), ref_len) 316 | ) 317 | return closest_ref_len 318 | 319 | 320 | def brevity_penalty(closest_ref_len, hyp_len): 321 | """ 322 | Calculate brevity penalty. 323 | As the modified n-gram precision still has the problem from the short 324 | length sentence, brevity penalty is used to modify the overall BLEU 325 | score according to length. 326 | An example from the paper. There are three references with length 12, 15 327 | and 17. And a concise hypothesis of the length 12. The brevity penalty is 1. 328 | >>> reference1 = list('aaaaaaaaaaaa') # i.e. ['a'] * 12 329 | >>> reference2 = list('aaaaaaaaaaaaaaa') # i.e. ['a'] * 15 330 | >>> reference3 = list('aaaaaaaaaaaaaaaaa') # i.e. ['a'] * 17 331 | >>> hypothesis = list('aaaaaaaaaaaa') # i.e. ['a'] * 12 332 | >>> references = [reference1, reference2, reference3] 333 | >>> hyp_len = len(hypothesis) 334 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 335 | >>> brevity_penalty(closest_ref_len, hyp_len) 336 | 1.0 337 | In case a hypothesis translation is shorter than the references, penalty is 338 | applied. 339 | >>> references = [['a'] * 28, ['a'] * 28] 340 | >>> hypothesis = ['a'] * 12 341 | >>> hyp_len = len(hypothesis) 342 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 343 | >>> brevity_penalty(closest_ref_len, hyp_len) 344 | 0.2635971381157267 345 | The length of the closest reference is used to compute the penalty. If the 346 | length of a hypothesis is 12, and the reference lengths are 13 and 2, the 347 | penalty is applied because the hypothesis length (12) is less then the 348 | closest reference length (13). 349 | >>> references = [['a'] * 13, ['a'] * 2] 350 | >>> hypothesis = ['a'] * 12 351 | >>> hyp_len = len(hypothesis) 352 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 353 | >>> brevity_penalty(closest_ref_len, hyp_len) # doctest: +ELLIPSIS 354 | 0.9200... 355 | The brevity penalty doesn't depend on reference order. 
356 |     when two reference sentences are at the same distance, the shortest
357 |     reference sentence length is used.
358 |     >>> references = [['a'] * 13, ['a'] * 11]
359 |     >>> hypothesis = ['a'] * 12
360 |     >>> hyp_len = len(hypothesis)
361 |     >>> closest_ref_len = closest_ref_length(references, hyp_len)
362 |     >>> bp1 = brevity_penalty(closest_ref_len, hyp_len)
363 |     >>> hyp_len = len(hypothesis)
364 |     >>> closest_ref_len = closest_ref_length(reversed(references), hyp_len)
365 |     >>> bp2 = brevity_penalty(closest_ref_len, hyp_len)
366 |     >>> bp1 == bp2 == 1
367 |     True
368 |     A test example from mteval-v13a.pl (starting at line 705):
369 |     >>> references = [['a'] * 11, ['a'] * 8]
370 |     >>> hypothesis = ['a'] * 7
371 |     >>> hyp_len = len(hypothesis)
372 |     >>> closest_ref_len = closest_ref_length(references, hyp_len)
373 |     >>> brevity_penalty(closest_ref_len, hyp_len) # doctest: +ELLIPSIS
374 |     0.8668...
375 |     >>> references = [['a'] * 11, ['a'] * 8, ['a'] * 6, ['a'] * 7]
376 |     >>> hypothesis = ['a'] * 7
377 |     >>> hyp_len = len(hypothesis)
378 |     >>> closest_ref_len = closest_ref_length(references, hyp_len)
379 |     >>> brevity_penalty(closest_ref_len, hyp_len)
380 |     1.0
381 |     :param hyp_len: The length of the hypothesis for a single sentence, OR the
382 |     sum of all the hypotheses' lengths for a corpus.
383 |     :type hyp_len: int
384 |     :param closest_ref_len: The length of the closest reference for a single
385 |     hypothesis, OR the sum of the closest reference lengths for every hypothesis in a corpus.
386 |     :type closest_ref_len: int
387 |     :return: BLEU's brevity penalty.
388 |     :rtype: float
389 |     """
390 |     if hyp_len > closest_ref_len:
391 |         return 1
392 |     # If the hypothesis is empty, brevity penalty = 0 should result in BLEU = 0.0.
393 |     elif hyp_len == 0:
394 |         return 0
395 |     else:
396 |         return math.exp(1 - closest_ref_len / hyp_len)
397 |
398 |
399 | class SmoothingFunction:
400 |     """
401 |     This is an implementation of the smoothing techniques
402 |     for segment-level BLEU scores presented in
403 |     Boxing Chen and Colin Cherry (2014) A Systematic Comparison of
404 |     Smoothing Techniques for Sentence-Level BLEU. In WMT14.
405 |     http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf
406 |     """
407 |
408 |     def __init__(self, epsilon=0.1, alpha=5, k=5):
409 |         """
410 |         This initializes the parameters required for the various smoothing
411 |         techniques; the default values are set to the numbers used in the
412 |         experiments from Chen and Cherry (2014).
413 |         >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 'ensures',
414 |         ...                'that', 'the', 'military', 'always', 'obeys', 'the',
415 |         ...                'commands', 'of', 'the', 'party']
416 |         >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 'ensures',
417 |         ...               'that', 'the', 'military', 'will', 'forever', 'heed',
418 |         ...               'Party', 'commands']
419 |         >>> chencherry = SmoothingFunction()
420 |         >>> print(sentence_bleu([reference1], hypothesis1)) # doctest: +ELLIPSIS
421 |         0.4118...
422 |         >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method0)) # doctest: +ELLIPSIS
423 |         0.4118...
424 |         >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method1)) # doctest: +ELLIPSIS
425 |         0.4118...
426 |         >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method2)) # doctest: +ELLIPSIS
427 |         0.4489...
428 |         >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method3)) # doctest: +ELLIPSIS
429 |         0.4118...
430 |         >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method4)) # doctest: +ELLIPSIS
431 |         0.4118...
432 |         >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method5)) # doctest: +ELLIPSIS
433 |         0.4905...
434 |         >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method6)) # doctest: +ELLIPSIS
435 |         0.4135...
436 |         >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method7)) # doctest: +ELLIPSIS
437 |         0.4905...
438 |         :param epsilon: the epsilon value used in method 1
439 |         :type epsilon: float
440 |         :param alpha: the alpha value used in method 6
441 |         :type alpha: int
442 |         :param k: the k value used in method 4
443 |         :type k: int
444 |         """
445 |         self.epsilon = epsilon
446 |         self.alpha = alpha
447 |         self.k = k
448 |
449 |     def method0(self, p_n, *args, **kwargs):
450 |         """
451 |         No smoothing.
452 |         """
453 |         p_n_new = []
454 |         for i, p_i in enumerate(p_n):
455 |             if p_i.numerator != 0:
456 |                 p_n_new.append(p_i)
457 |             else:
458 |                 _msg = str(
459 |                     "\nThe hypothesis contains 0 counts of {}-gram overlaps.\n"
460 |                     "Therefore the BLEU score evaluates to 0, independently of\n"
461 |                     "how many N-gram overlaps of lower order it contains.\n"
462 |                     "Consider using lower n-gram order or use "
463 |                     "SmoothingFunction()"
464 |                 ).format(i + 1)
465 |                 warnings.warn(_msg)
466 |                 # When numerator == 0 (whether or not the denominator is 0),
467 |                 # the precision score should be 0 or undefined. Because BLEU's
468 |                 # geometric mean is computed in logarithm space, we return
469 |                 # sys.float_info.min instead, so that math.log() stays defined
470 |                 # and the overall score effectively collapses to 0.
471 |                 p_n_new.append(sys.float_info.min)
472 |         return p_n_new
473 |
474 |     def method1(self, p_n, *args, **kwargs):
475 |         """
476 |         Smoothing method 1: Add *epsilon* counts to precision with 0 counts.
477 |         """
478 |         return [
479 |             (p_i.numerator + self.epsilon) / p_i.denominator
480 |             if p_i.numerator == 0
481 |             else p_i
482 |             for p_i in p_n
483 |         ]
484 |
485 |     def method2(self, p_n, *args, **kwargs):
486 |         """
487 |         Smoothing method 2: Add 1 to both numerator and denominator, from
488 |         Chin-Yew Lin and Franz Josef Och (2004) Automatic evaluation of
489 |         machine translation quality using longest common subsequence and
490 |         skip-bigram statistics. In ACL04.
491 |         """
492 |         return [
493 |             Fraction(p_i.numerator + 1, p_i.denominator + 1, _normalize=False)
494 |             for p_i in p_n
495 |         ]
496 |
497 |     def method3(self, p_n, *args, **kwargs):
498 |         """
499 |         Smoothing method 3: NIST geometric sequence smoothing.
500 |         The smoothing is computed by taking 1 / ( 2^k ), instead of 0, for each
501 |         precision score whose matching n-gram count is null.
502 |         k is 1 for the first 'n' value for which the n-gram match count is null.
503 |         For example, if the text contains:
504 |         - one 2-gram match
505 |         - and (consequently) two 1-gram matches
506 |         the n-gram count for each individual precision score would be:
507 |         - n=1 => prec_count = 2   (two unigrams)
508 |         - n=2 => prec_count = 1   (one bigram)
509 |         - n=3 => prec_count = 1/2 (no trigram, taking the 'smoothed' value of 1 / ( 2^k ), with k=1)
510 |         - n=4 => prec_count = 1/4 (no four-gram, taking the 'smoothed' value of 1 / ( 2^k ), with k=2)
511 |         """
512 |         incvnt = 1  # In mteval-v13a.pl, this is referred to as k.
513 |         for i, p_i in enumerate(p_n):
514 |             if p_i.numerator == 0:
515 |                 p_n[i] = 1 / (2 ** incvnt * p_i.denominator)
516 |                 incvnt += 1
517 |         return p_n
518 |
519 |     def method4(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
520 |         """
521 |         Smoothing method 4:
522 |         Shorter translations may have inflated precision values due to having
523 |         smaller denominators; therefore, we give them proportionally
524 |         smaller smoothed counts. Instead of scaling to 1/(2^k), Chen and Cherry
525 |         suggest dividing by 1/ln(len(T)), where T is the length of the translation.
526 |         """
527 |         hyp_len = hyp_len if hyp_len else len(hypothesis)
528 |         for i, p_i in enumerate(p_n):
529 |             if p_i.numerator == 0 and hyp_len != 0:
530 |                 incvnt = i + 1 * self.k / math.log(
531 |                     hyp_len
532 |                 )  # Note that this K is different from the K from NIST.
533 |                 p_n[i] = incvnt / p_i.denominator
534 |         return p_n
535 |
536 |     def method5(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
537 |         """
538 |         Smoothing method 5:
539 |         The matched counts for similar values of n should be similar. To
540 |         calculate the n-gram matched count, it averages the n−1, n and n+1 gram
541 |         matched counts.
542 |         """
543 |         hyp_len = hyp_len if hyp_len else len(hypothesis)
544 |         m = {}
545 |         # Requires a precision value for one additional ngram order.
546 |         p_n_plus1 = p_n + [modified_precision(references, hypothesis, 5)]
547 |         m[-1] = p_n[0] + 1
548 |         for i, p_i in enumerate(p_n):
549 |             p_n[i] = (m[i - 1] + p_i + p_n_plus1[i + 1]) / 3
550 |             m[i] = p_n[i]
551 |         return p_n
552 |
553 |     def method6(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
554 |         """
555 |         Smoothing method 6:
556 |         Interpolates the maximum likelihood estimate of the precision *p_n* with
557 |         a prior estimate *pi0*. The prior is estimated by assuming that the ratio
558 |         between pn and pn−1 will be the same as that between pn−1 and pn−2; from
559 |         Gao and He (2013) Training MRF-Based Phrase Translation Models using
560 |         Gradient Ascent. In NAACL.
561 |         """
562 |         hyp_len = hyp_len if hyp_len else len(hypothesis)
563 |         # This smoothing only works when p_1 and p_2 are non-zero.
564 |         # Raise an error with an appropriate message when the input is too short
565 |         # to use this smoothing technique.
566 |         assert p_n[2], "This smoothing method requires non-zero precision for bigrams."
567 |         for i, p_i in enumerate(p_n):
568 |             if i in [0, 1]:  # Skips the first 2 orders of ngrams.
569 |                 continue
570 |             else:
571 |                 pi0 = 0 if p_n[i - 2] == 0 else p_n[i - 1] ** 2 / p_n[i - 2]
572 |                 # No. of ngrams in the translation that match the reference.
573 |                 m = p_i.numerator
574 |                 # No. of ngrams in the translation.
575 |                 l = sum(1 for _ in ngrams(hypothesis, i + 1))
576 |                 # Calculates the interpolated precision.
577 |                 p_n[i] = (m + self.alpha * pi0) / (l + self.alpha)
578 |         return p_n
579 |
580 |     def method7(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
581 |         """
582 |         Smoothing method 7:
583 |         Interpolates methods 4 and 5.
584 |         """
585 |         hyp_len = hyp_len if hyp_len else len(hypothesis)
586 |         p_n = self.method4(p_n, references, hypothesis, hyp_len)
587 |         p_n = self.method5(p_n, references, hypothesis, hyp_len)
588 |         return p_n
589 |
--------------------------------------------------------------------------------
/parser/DFG.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
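# Overview (a sketch of the convention these functions follow, inferred from
# how CodeBLEU's dataflow_match.py drives them rather than documented here):
# each DFG_<language> function below recursively walks a tree_sitter parse
# tree and returns (DFG, states), where DFG is a list of edges
#   (token, token_position, edge_type, source_tokens, source_positions)
# with edge_type 'comesFrom' for a read of an earlier definition and
# 'computedFrom' for an assignment, and `states` maps each variable name to
# the token positions that most recently defined it. Loop constructs are
# traversed twice (`for i in range(2)`) so that values flowing around the
# loop back-edge are also captured.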
3 | 4 | from tree_sitter import Language, Parser 5 | from .utils import ( 6 | remove_comments_and_docstrings, 7 | tree_to_token_index, 8 | index_to_code_token, 9 | tree_to_variable_index 10 | ) 11 | 12 | 13 | def DFG_python(root_node, index_to_code, states): 14 | assignment = ['assignment', 'augmented_assignment', 'for_in_clause'] 15 | if_statement = ['if_statement'] 16 | for_statement = ['for_statement'] 17 | while_statement = ['while_statement'] 18 | do_first_statement = ['for_in_clause'] 19 | def_statement = ['default_parameter'] 20 | states = states.copy() 21 | if (len(root_node.children) == 0 or root_node.type == 'string') and root_node.type != 'comment': 22 | idx, code = index_to_code[(root_node.start_point, root_node.end_point)] 23 | if root_node.type == code: 24 | return [], states 25 | elif code in states: 26 | return [(code, idx, 'comesFrom', [code], states[code].copy())], states 27 | else: 28 | if root_node.type == 'identifier': 29 | states[code] = [idx] 30 | return [(code, idx, 'comesFrom', [], [])], states 31 | elif root_node.type in def_statement: 32 | name = root_node.child_by_field_name('name') 33 | value = root_node.child_by_field_name('value') 34 | DFG = [] 35 | if value is None: 36 | indexs = tree_to_variable_index(name, index_to_code) 37 | for index in indexs: 38 | idx, code = index_to_code[index] 39 | DFG.append((code, idx, 'comesFrom', [], [])) 40 | states[code] = [idx] 41 | return sorted(DFG, key=lambda x: x[1]), states 42 | else: 43 | name_indexs = tree_to_variable_index(name, index_to_code) 44 | value_indexs = tree_to_variable_index(value, index_to_code) 45 | temp, states = DFG_python(value, index_to_code, states) 46 | DFG += temp 47 | for index1 in name_indexs: 48 | idx1, code1 = index_to_code[index1] 49 | for index2 in value_indexs: 50 | idx2, code2 = index_to_code[index2] 51 | DFG.append((code1, idx1, 'comesFrom', [code2], [idx2])) 52 | states[code1] = [idx1] 53 | return sorted(DFG, key=lambda x: x[1]), states 54 | elif root_node.type in assignment: 55 | if root_node.type == 'for_in_clause': 56 | right_nodes = [root_node.children[-1]] 57 | left_nodes = [root_node.child_by_field_name('left')] 58 | else: 59 | if root_node.child_by_field_name('right') is None: 60 | return [], states 61 | left_nodes = [x for x in root_node.child_by_field_name('left').children if x.type != ','] 62 | right_nodes = [x for x in root_node.child_by_field_name('right').children if x.type != ','] 63 | if len(right_nodes) != len(left_nodes): 64 | left_nodes = [root_node.child_by_field_name('left')] 65 | right_nodes = [root_node.child_by_field_name('right')] 66 | if len(left_nodes) == 0: 67 | left_nodes = [root_node.child_by_field_name('left')] 68 | if len(right_nodes) == 0: 69 | right_nodes = [root_node.child_by_field_name('right')] 70 | DFG = [] 71 | for node in right_nodes: 72 | temp, states = DFG_python(node, index_to_code, states) 73 | DFG += temp 74 | 75 | for left_node, right_node in zip(left_nodes, right_nodes): 76 | left_tokens_index = tree_to_variable_index(left_node, index_to_code) 77 | right_tokens_index = tree_to_variable_index(right_node, index_to_code) 78 | temp = [] 79 | for token1_index in left_tokens_index: 80 | idx1, code1 = index_to_code[token1_index] 81 | temp.append((code1, idx1, 'computedFrom', [index_to_code[x][1] for x in right_tokens_index], 82 | [index_to_code[x][0] for x in right_tokens_index])) 83 | states[code1] = [idx1] 84 | DFG += temp 85 | return sorted(DFG, key=lambda x: x[1]), states 86 | elif root_node.type in if_statement: 87 | DFG = [] 88 | current_states 
= states.copy() 89 | others_states = [] 90 | tag = False 91 | if 'else' in root_node.type: 92 | tag = True 93 | for child in root_node.children: 94 | if 'else' in child.type: 95 | tag = True 96 | if child.type not in ['elif_clause', 'else_clause']: 97 | temp, current_states = DFG_python(child, index_to_code, current_states) 98 | DFG += temp 99 | else: 100 | temp, new_states = DFG_python(child, index_to_code, states) 101 | DFG += temp 102 | others_states.append(new_states) 103 | others_states.append(current_states) 104 | if tag is False: 105 | others_states.append(states) 106 | new_states = {} 107 | for dic in others_states: 108 | for key in dic: 109 | if key not in new_states: 110 | new_states[key] = dic[key].copy() 111 | else: 112 | new_states[key] += dic[key] 113 | for key in new_states: 114 | new_states[key] = sorted(list(set(new_states[key]))) 115 | return sorted(DFG, key=lambda x: x[1]), new_states 116 | elif root_node.type in for_statement: 117 | DFG = [] 118 | for i in range(2): 119 | right_nodes = [x for x in root_node.child_by_field_name('right').children if x.type != ','] 120 | left_nodes = [x for x in root_node.child_by_field_name('left').children if x.type != ','] 121 | if len(right_nodes) != len(left_nodes): 122 | left_nodes = [root_node.child_by_field_name('left')] 123 | right_nodes = [root_node.child_by_field_name('right')] 124 | if len(left_nodes) == 0: 125 | left_nodes = [root_node.child_by_field_name('left')] 126 | if len(right_nodes) == 0: 127 | right_nodes = [root_node.child_by_field_name('right')] 128 | for node in right_nodes: 129 | temp, states = DFG_python(node, index_to_code, states) 130 | DFG += temp 131 | for left_node, right_node in zip(left_nodes, right_nodes): 132 | left_tokens_index = tree_to_variable_index(left_node, index_to_code) 133 | right_tokens_index = tree_to_variable_index(right_node, index_to_code) 134 | temp = [] 135 | for token1_index in left_tokens_index: 136 | idx1, code1 = index_to_code[token1_index] 137 | temp.append((code1, idx1, 'computedFrom', [index_to_code[x][1] for x in right_tokens_index], 138 | [index_to_code[x][0] for x in right_tokens_index])) 139 | states[code1] = [idx1] 140 | DFG += temp 141 | if root_node.children[-1].type == "block": 142 | temp, states = DFG_python(root_node.children[-1], index_to_code, states) 143 | DFG += temp 144 | dic = {} 145 | for x in DFG: 146 | if (x[0], x[1], x[2]) not in dic: 147 | dic[(x[0], x[1], x[2])] = [x[3], x[4]] 148 | else: 149 | dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) 150 | dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) 151 | DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] 152 | return sorted(DFG, key=lambda x: x[1]), states 153 | elif root_node.type in while_statement: 154 | DFG = [] 155 | for i in range(2): 156 | for child in root_node.children: 157 | temp, states = DFG_python(child, index_to_code, states) 158 | DFG += temp 159 | dic = {} 160 | for x in DFG: 161 | if (x[0], x[1], x[2]) not in dic: 162 | dic[(x[0], x[1], x[2])] = [x[3], x[4]] 163 | else: 164 | dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) 165 | dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) 166 | DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] 167 | return sorted(DFG, key=lambda x: x[1]), states 168 | else: 169 | DFG = [] 170 | for child in root_node.children: 171 | if child.type in do_first_statement: 172 
| temp, states = DFG_python(child, index_to_code, states)
173 |                 DFG += temp
174 |         for child in root_node.children:
175 |             if child.type not in do_first_statement:
176 |                 temp, states = DFG_python(child, index_to_code, states)
177 |                 DFG += temp
178 |
179 |         return sorted(DFG, key=lambda x: x[1]), states
180 |
181 |
182 | def DFG_java(root_node, index_to_code, states):
183 |     assignment = ['assignment_expression']
184 |     def_statement = ['variable_declarator']
185 |     increment_statement = ['update_expression']
186 |     if_statement = ['if_statement', 'else']
187 |     for_statement = ['for_statement']
188 |     enhanced_for_statement = ['enhanced_for_statement']
189 |     while_statement = ['while_statement']
190 |     do_first_statement = []
191 |     states = states.copy()
192 |     if (len(root_node.children) == 0 or root_node.type == 'string') and root_node.type != 'comment':
193 |         idx, code = index_to_code[(root_node.start_point, root_node.end_point)]
194 |         if root_node.type == code:
195 |             return [], states
196 |         elif code in states:
197 |             return [(code, idx, 'comesFrom', [code], states[code].copy())], states
198 |         else:
199 |             if root_node.type == 'identifier':
200 |                 states[code] = [idx]
201 |             return [(code, idx, 'comesFrom', [], [])], states
202 |     elif root_node.type in def_statement:
203 |         name = root_node.child_by_field_name('name')
204 |         value = root_node.child_by_field_name('value')
205 |         DFG = []
206 |         if value is None:
207 |             indexs = tree_to_variable_index(name, index_to_code)
208 |             for index in indexs:
209 |                 idx, code = index_to_code[index]
210 |                 DFG.append((code, idx, 'comesFrom', [], []))
211 |                 states[code] = [idx]
212 |             return sorted(DFG, key=lambda x: x[1]), states
213 |         else:
214 |             name_indexs = tree_to_variable_index(name, index_to_code)
215 |             value_indexs = tree_to_variable_index(value, index_to_code)
216 |             temp, states = DFG_java(value, index_to_code, states)
217 |             DFG += temp
218 |             for index1 in name_indexs:
219 |                 idx1, code1 = index_to_code[index1]
220 |                 for index2 in value_indexs:
221 |                     idx2, code2 = index_to_code[index2]
222 |                     DFG.append((code1, idx1, 'comesFrom', [code2], [idx2]))
223 |                 states[code1] = [idx1]
224 |             return sorted(DFG, key=lambda x: x[1]), states
225 |     elif root_node.type in assignment:
226 |         left_nodes = root_node.child_by_field_name('left')
227 |         right_nodes = root_node.child_by_field_name('right')
228 |         DFG = []
229 |         temp, states = DFG_java(right_nodes, index_to_code, states)
230 |         DFG += temp
231 |         name_indexs = tree_to_variable_index(left_nodes, index_to_code)
232 |         value_indexs = tree_to_variable_index(right_nodes, index_to_code)
233 |         for index1 in name_indexs:
234 |             idx1, code1 = index_to_code[index1]
235 |             for index2 in value_indexs:
236 |                 idx2, code2 = index_to_code[index2]
237 |                 DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
238 |             states[code1] = [idx1]
239 |         return sorted(DFG, key=lambda x: x[1]), states
240 |     elif root_node.type in increment_statement:
241 |         DFG = []
242 |         indexs = tree_to_variable_index(root_node, index_to_code)
243 |         for index1 in indexs:
244 |             idx1, code1 = index_to_code[index1]
245 |             for index2 in indexs:
246 |                 idx2, code2 = index_to_code[index2]
247 |                 DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
248 |             states[code1] = [idx1]
249 |         return sorted(DFG, key=lambda x: x[1]), states
250 |     elif root_node.type in if_statement:
251 |         DFG = []
252 |         current_states = states.copy()
253 |         others_states = []
254 |         flag = tag = False  # `tag` is read below even when there is no else-branch
255 |         if 'else' in root_node.type:
256 |             tag = True
257 |         for child in root_node.children:
258 |             if 'else' in
child.type: 259 | tag = True 260 | if child.type not in if_statement and flag is False: 261 | temp, current_states = DFG_java(child, index_to_code, current_states) 262 | DFG += temp 263 | else: 264 | flag = True 265 | temp, new_states = DFG_java(child, index_to_code, states) 266 | DFG += temp 267 | others_states.append(new_states) 268 | others_states.append(current_states) 269 | if tag is False: 270 | others_states.append(states) 271 | new_states = {} 272 | for dic in others_states: 273 | for key in dic: 274 | if key not in new_states: 275 | new_states[key] = dic[key].copy() 276 | else: 277 | new_states[key] += dic[key] 278 | for key in new_states: 279 | new_states[key] = sorted(list(set(new_states[key]))) 280 | return sorted(DFG, key=lambda x: x[1]), new_states 281 | elif root_node.type in for_statement: 282 | DFG = [] 283 | for child in root_node.children: 284 | temp, states = DFG_java(child, index_to_code, states) 285 | DFG += temp 286 | flag = False 287 | for child in root_node.children: 288 | if flag: 289 | temp, states = DFG_java(child, index_to_code, states) 290 | DFG += temp 291 | elif child.type == "local_variable_declaration": 292 | flag = True 293 | dic = {} 294 | for x in DFG: 295 | if (x[0], x[1], x[2]) not in dic: 296 | dic[(x[0], x[1], x[2])] = [x[3], x[4]] 297 | else: 298 | dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) 299 | dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) 300 | DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] 301 | return sorted(DFG, key=lambda x: x[1]), states 302 | elif root_node.type in enhanced_for_statement: 303 | name = root_node.child_by_field_name('name') 304 | value = root_node.child_by_field_name('value') 305 | body = root_node.child_by_field_name('body') 306 | DFG = [] 307 | for i in range(2): 308 | temp, states = DFG_java(value, index_to_code, states) 309 | DFG += temp 310 | name_indexs = tree_to_variable_index(name, index_to_code) 311 | value_indexs = tree_to_variable_index(value, index_to_code) 312 | for index1 in name_indexs: 313 | idx1, code1 = index_to_code[index1] 314 | for index2 in value_indexs: 315 | idx2, code2 = index_to_code[index2] 316 | DFG.append((code1, idx1, 'computedFrom', [code2], [idx2])) 317 | states[code1] = [idx1] 318 | temp, states = DFG_java(body, index_to_code, states) 319 | DFG += temp 320 | dic = {} 321 | for x in DFG: 322 | if (x[0], x[1], x[2]) not in dic: 323 | dic[(x[0], x[1], x[2])] = [x[3], x[4]] 324 | else: 325 | dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) 326 | dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) 327 | DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] 328 | return sorted(DFG, key=lambda x: x[1]), states 329 | elif root_node.type in while_statement: 330 | DFG = [] 331 | for i in range(2): 332 | for child in root_node.children: 333 | temp, states = DFG_java(child, index_to_code, states) 334 | DFG += temp 335 | dic = {} 336 | for x in DFG: 337 | if (x[0], x[1], x[2]) not in dic: 338 | dic[(x[0], x[1], x[2])] = [x[3], x[4]] 339 | else: 340 | dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) 341 | dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) 342 | DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] 343 | return sorted(DFG, key=lambda x: x[1]), states 344 | else: 345 | DFG = [] 346 | for 
child in root_node.children: 347 | if child.type in do_first_statement: 348 | temp, states = DFG_java(child, index_to_code, states) 349 | DFG += temp 350 | for child in root_node.children: 351 | if child.type not in do_first_statement: 352 | temp, states = DFG_java(child, index_to_code, states) 353 | DFG += temp 354 | 355 | return sorted(DFG, key=lambda x: x[1]), states 356 | 357 | 358 | def DFG_csharp(root_node, index_to_code, states): 359 | assignment = ['assignment_expression'] 360 | def_statement = ['variable_declarator'] 361 | increment_statement = ['postfix_unary_expression'] 362 | if_statement = ['if_statement', 'else'] 363 | for_statement = ['for_statement'] 364 | enhanced_for_statement = ['for_each_statement'] 365 | while_statement = ['while_statement'] 366 | do_first_statement = [] 367 | states = states.copy() 368 | if (len(root_node.children) == 0 or root_node.type == 'string') and root_node.type != 'comment': 369 | idx, code = index_to_code[(root_node.start_point, root_node.end_point)] 370 | if root_node.type == code: 371 | return [], states 372 | elif code in states: 373 | return [(code, idx, 'comesFrom', [code], states[code].copy())], states 374 | else: 375 | if root_node.type == 'identifier': 376 | states[code] = [idx] 377 | return [(code, idx, 'comesFrom', [], [])], states 378 | elif root_node.type in def_statement: 379 | if len(root_node.children) == 2: 380 | name = root_node.children[0] 381 | value = root_node.children[1] 382 | else: 383 | name = root_node.children[0] 384 | value = None 385 | DFG = [] 386 | if value is None: 387 | indexs = tree_to_variable_index(name, index_to_code) 388 | for index in indexs: 389 | idx, code = index_to_code[index] 390 | DFG.append((code, idx, 'comesFrom', [], [])) 391 | states[code] = [idx] 392 | return sorted(DFG, key=lambda x: x[1]), states 393 | else: 394 | name_indexs = tree_to_variable_index(name, index_to_code) 395 | value_indexs = tree_to_variable_index(value, index_to_code) 396 | temp, states = DFG_csharp(value, index_to_code, states) 397 | DFG += temp 398 | for index1 in name_indexs: 399 | idx1, code1 = index_to_code[index1] 400 | for index2 in value_indexs: 401 | idx2, code2 = index_to_code[index2] 402 | DFG.append((code1, idx1, 'comesFrom', [code2], [idx2])) 403 | states[code1] = [idx1] 404 | return sorted(DFG, key=lambda x: x[1]), states 405 | elif root_node.type in assignment: 406 | left_nodes = root_node.child_by_field_name('left') 407 | right_nodes = root_node.child_by_field_name('right') 408 | DFG = [] 409 | temp, states = DFG_csharp(right_nodes, index_to_code, states) 410 | DFG += temp 411 | name_indexs = tree_to_variable_index(left_nodes, index_to_code) 412 | value_indexs = tree_to_variable_index(right_nodes, index_to_code) 413 | for index1 in name_indexs: 414 | idx1, code1 = index_to_code[index1] 415 | for index2 in value_indexs: 416 | idx2, code2 = index_to_code[index2] 417 | DFG.append((code1, idx1, 'computedFrom', [code2], [idx2])) 418 | states[code1] = [idx1] 419 | return sorted(DFG, key=lambda x: x[1]), states 420 | elif root_node.type in increment_statement: 421 | DFG = [] 422 | indexs = tree_to_variable_index(root_node, index_to_code) 423 | for index1 in indexs: 424 | idx1, code1 = index_to_code[index1] 425 | for index2 in indexs: 426 | idx2, code2 = index_to_code[index2] 427 | DFG.append((code1, idx1, 'computedFrom', [code2], [idx2])) 428 | states[code1] = [idx1] 429 | return sorted(DFG, key=lambda x: x[1]), states 430 | elif root_node.type in if_statement: 431 | DFG = [] 432 | current_states = states.copy() 433 
| others_states = []
434 |         flag = tag = False  # `tag` is read below even when there is no else-branch
435 |         if 'else' in root_node.type:
436 |             tag = True
437 |         for child in root_node.children:
438 |             if 'else' in child.type:
439 |                 tag = True
440 |             if child.type not in if_statement and flag is False:
441 |                 temp, current_states = DFG_csharp(child, index_to_code, current_states)
442 |                 DFG += temp
443 |             else:
444 |                 flag = True
445 |                 temp, new_states = DFG_csharp(child, index_to_code, states)
446 |                 DFG += temp
447 |                 others_states.append(new_states)
448 |         others_states.append(current_states)
449 |         if tag is False:
450 |             others_states.append(states)
451 |         new_states = {}
452 |         for dic in others_states:
453 |             for key in dic:
454 |                 if key not in new_states:
455 |                     new_states[key] = dic[key].copy()
456 |                 else:
457 |                     new_states[key] += dic[key]
458 |         for key in new_states:
459 |             new_states[key] = sorted(list(set(new_states[key])))
460 |         return sorted(DFG, key=lambda x: x[1]), new_states
461 |     elif root_node.type in for_statement:
462 |         DFG = []
463 |         for child in root_node.children:
464 |             temp, states = DFG_csharp(child, index_to_code, states)
465 |             DFG += temp
466 |         flag = False
467 |         for child in root_node.children:
468 |             if flag:
469 |                 temp, states = DFG_csharp(child, index_to_code, states)
470 |                 DFG += temp
471 |             elif child.type == "local_variable_declaration":
472 |                 flag = True
473 |         dic = {}
474 |         for x in DFG:
475 |             if (x[0], x[1], x[2]) not in dic:
476 |                 dic[(x[0], x[1], x[2])] = [x[3], x[4]]
477 |             else:
478 |                 dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
479 |                 dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
480 |         DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
481 |         return sorted(DFG, key=lambda x: x[1]), states
482 |     elif root_node.type in enhanced_for_statement:
483 |         name = root_node.child_by_field_name('left')
484 |         value = root_node.child_by_field_name('right')
485 |         body = root_node.child_by_field_name('body')
486 |         DFG = []
487 |         for i in range(2):
488 |             temp, states = DFG_csharp(value, index_to_code, states)
489 |             DFG += temp
490 |             name_indexs = tree_to_variable_index(name, index_to_code)
491 |             value_indexs = tree_to_variable_index(value, index_to_code)
492 |             for index1 in name_indexs:
493 |                 idx1, code1 = index_to_code[index1]
494 |                 for index2 in value_indexs:
495 |                     idx2, code2 = index_to_code[index2]
496 |                     DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
497 |                 states[code1] = [idx1]
498 |             temp, states = DFG_csharp(body, index_to_code, states)
499 |             DFG += temp
500 |         dic = {}
501 |         for x in DFG:
502 |             if (x[0], x[1], x[2]) not in dic:
503 |                 dic[(x[0], x[1], x[2])] = [x[3], x[4]]
504 |             else:
505 |                 dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
506 |                 dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
507 |         DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
508 |         return sorted(DFG, key=lambda x: x[1]), states
509 |     elif root_node.type in while_statement:
510 |         DFG = []
511 |         for i in range(2):
512 |             for child in root_node.children:
513 |                 temp, states = DFG_csharp(child, index_to_code, states)
514 |                 DFG += temp
515 |         dic = {}
516 |         for x in DFG:
517 |             if (x[0], x[1], x[2]) not in dic:
518 |                 dic[(x[0], x[1], x[2])] = [x[3], x[4]]
519 |             else:
520 |                 dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
521 |                 dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
522 |         DFG = [(x[0], x[1],
x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] 523 | return sorted(DFG, key=lambda x: x[1]), states 524 | else: 525 | DFG = [] 526 | for child in root_node.children: 527 | if child.type in do_first_statement: 528 | temp, states = DFG_csharp(child, index_to_code, states) 529 | DFG += temp 530 | for child in root_node.children: 531 | if child.type not in do_first_statement: 532 | temp, states = DFG_csharp(child, index_to_code, states) 533 | DFG += temp 534 | 535 | return sorted(DFG, key=lambda x: x[1]), states 536 | 537 | 538 | def DFG_ruby(root_node, index_to_code, states): 539 | assignment = ['assignment', 'operator_assignment'] 540 | if_statement = ['if', 'elsif', 'else', 'unless', 'when'] 541 | for_statement = ['for'] 542 | while_statement = ['while_modifier', 'until'] 543 | do_first_statement = [] 544 | def_statement = ['keyword_parameter'] 545 | if (len(root_node.children) == 0 or root_node.type == 'string') and root_node.type != 'comment': 546 | states = states.copy() 547 | idx, code = index_to_code[(root_node.start_point, root_node.end_point)] 548 | if root_node.type == code: 549 | return [], states 550 | elif code in states: 551 | return [(code, idx, 'comesFrom', [code], states[code].copy())], states 552 | else: 553 | if root_node.type == 'identifier': 554 | states[code] = [idx] 555 | return [(code, idx, 'comesFrom', [], [])], states 556 | elif root_node.type in def_statement: 557 | name = root_node.child_by_field_name('name') 558 | value = root_node.child_by_field_name('value') 559 | DFG = [] 560 | if value is None: 561 | indexs = tree_to_variable_index(name, index_to_code) 562 | for index in indexs: 563 | idx, code = index_to_code[index] 564 | DFG.append((code, idx, 'comesFrom', [], [])) 565 | states[code] = [idx] 566 | return sorted(DFG, key=lambda x: x[1]), states 567 | else: 568 | name_indexs = tree_to_variable_index(name, index_to_code) 569 | value_indexs = tree_to_variable_index(value, index_to_code) 570 | temp, states = DFG_ruby(value, index_to_code, states) 571 | DFG += temp 572 | for index1 in name_indexs: 573 | idx1, code1 = index_to_code[index1] 574 | for index2 in value_indexs: 575 | idx2, code2 = index_to_code[index2] 576 | DFG.append((code1, idx1, 'comesFrom', [code2], [idx2])) 577 | states[code1] = [idx1] 578 | return sorted(DFG, key=lambda x: x[1]), states 579 | elif root_node.type in assignment: 580 | left_nodes = [x for x in root_node.child_by_field_name('left').children if x.type != ','] 581 | right_nodes = [x for x in root_node.child_by_field_name('right').children if x.type != ','] 582 | if len(right_nodes) != len(left_nodes): 583 | left_nodes = [root_node.child_by_field_name('left')] 584 | right_nodes = [root_node.child_by_field_name('right')] 585 | if len(left_nodes) == 0: 586 | left_nodes = [root_node.child_by_field_name('left')] 587 | if len(right_nodes) == 0: 588 | right_nodes = [root_node.child_by_field_name('right')] 589 | if root_node.type == "operator_assignment": 590 | left_nodes = [root_node.children[0]] 591 | right_nodes = [root_node.children[-1]] 592 | 593 | DFG = [] 594 | for node in right_nodes: 595 | temp, states = DFG_ruby(node, index_to_code, states) 596 | DFG += temp 597 | 598 | for left_node, right_node in zip(left_nodes, right_nodes): 599 | left_tokens_index = tree_to_variable_index(left_node, index_to_code) 600 | right_tokens_index = tree_to_variable_index(right_node, index_to_code) 601 | temp = [] 602 | for token1_index in left_tokens_index: 603 | idx1, code1 = index_to_code[token1_index] 604 | 
temp.append((code1, idx1, 'computedFrom', [index_to_code[x][1] for x in right_tokens_index],
605 |                              [index_to_code[x][0] for x in right_tokens_index]))
606 |                 states[code1] = [idx1]
607 |             DFG += temp
608 |         return sorted(DFG, key=lambda x: x[1]), states
609 |     elif root_node.type in if_statement:
610 |         DFG = []
611 |         current_states = states.copy()
612 |         others_states, tag = [], False  # `tag` is read below even when there is no else-branch
613 |         if 'else' in root_node.type:
614 |             tag = True
615 |         for child in root_node.children:
616 |             if 'else' in child.type:
617 |                 tag = True
618 |             if child.type not in if_statement:
619 |                 temp, current_states = DFG_ruby(child, index_to_code, current_states)
620 |                 DFG += temp
621 |             else:
622 |                 temp, new_states = DFG_ruby(child, index_to_code, states)
623 |                 DFG += temp
624 |                 others_states.append(new_states)
625 |         others_states.append(current_states)
626 |         if tag is False:
627 |             others_states.append(states)
628 |         new_states = {}
629 |         for dic in others_states:
630 |             for key in dic:
631 |                 if key not in new_states:
632 |                     new_states[key] = dic[key].copy()
633 |                 else:
634 |                     new_states[key] += dic[key]
635 |         for key in new_states:
636 |             new_states[key] = sorted(list(set(new_states[key])))
637 |         return sorted(DFG, key=lambda x: x[1]), new_states
638 |     elif root_node.type in for_statement:
639 |         DFG = []
640 |         for i in range(2):
641 |             left_nodes = [root_node.child_by_field_name('pattern')]
642 |             right_nodes = [root_node.child_by_field_name('value')]
643 |             assert len(right_nodes) == len(left_nodes)
644 |             for node in right_nodes:
645 |                 temp, states = DFG_ruby(node, index_to_code, states)
646 |                 DFG += temp
647 |             for left_node, right_node in zip(left_nodes, right_nodes):
648 |                 left_tokens_index = tree_to_variable_index(left_node, index_to_code)
649 |                 right_tokens_index = tree_to_variable_index(right_node, index_to_code)
650 |                 temp = []
651 |                 for token1_index in left_tokens_index:
652 |                     idx1, code1 = index_to_code[token1_index]
653 |                     temp.append((code1, idx1, 'computedFrom', [index_to_code[x][1] for x in right_tokens_index],
654 |                                  [index_to_code[x][0] for x in right_tokens_index]))
655 |                     states[code1] = [idx1]
656 |                 DFG += temp
657 |             temp, states = DFG_ruby(root_node.child_by_field_name('body'), index_to_code, states)
658 |             DFG += temp
659 |         dic = {}
660 |         for x in DFG:
661 |             if (x[0], x[1], x[2]) not in dic:
662 |                 dic[(x[0], x[1], x[2])] = [x[3], x[4]]
663 |             else:
664 |                 dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
665 |                 dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
666 |         DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
667 |         return sorted(DFG, key=lambda x: x[1]), states
668 |     elif root_node.type in while_statement:
669 |         DFG = []
670 |         for i in range(2):
671 |             for child in root_node.children:
672 |                 temp, states = DFG_ruby(child, index_to_code, states)
673 |                 DFG += temp
674 |         dic = {}
675 |         for x in DFG:
676 |             if (x[0], x[1], x[2]) not in dic:
677 |                 dic[(x[0], x[1], x[2])] = [x[3], x[4]]
678 |             else:
679 |                 dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
680 |                 dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
681 |         DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
682 |         return sorted(DFG, key=lambda x: x[1]), states
683 |     else:
684 |         DFG = []
685 |         for child in root_node.children:
686 |             if child.type in do_first_statement:
687 |                 temp, states = DFG_ruby(child, index_to_code, states)
688 |                 DFG += temp
689 |         for child in
root_node.children:
690 |             if child.type not in do_first_statement:
691 |                 temp, states = DFG_ruby(child, index_to_code, states)
692 |                 DFG += temp
693 |
694 |         return sorted(DFG, key=lambda x: x[1]), states
695 |
696 |
697 | def DFG_go(root_node, index_to_code, states):
698 |     assignment = ['assignment_statement']
699 |     def_statement = ['var_spec']
700 |     increment_statement = ['inc_statement']
701 |     if_statement = ['if_statement', 'else']
702 |     for_statement = ['for_statement']
703 |     enhanced_for_statement = []
704 |     while_statement = []
705 |     do_first_statement = []
706 |     states = states.copy()
707 |     if (len(root_node.children) == 0 or root_node.type == 'string') and root_node.type != 'comment':
708 |         idx, code = index_to_code[(root_node.start_point, root_node.end_point)]
709 |         if root_node.type == code:
710 |             return [], states
711 |         elif code in states:
712 |             return [(code, idx, 'comesFrom', [code], states[code].copy())], states
713 |         else:
714 |             if root_node.type == 'identifier':
715 |                 states[code] = [idx]
716 |             return [(code, idx, 'comesFrom', [], [])], states
717 |     elif root_node.type in def_statement:
718 |         name = root_node.child_by_field_name('name')
719 |         value = root_node.child_by_field_name('value')
720 |         DFG = []
721 |         if value is None:
722 |             indexs = tree_to_variable_index(name, index_to_code)
723 |             for index in indexs:
724 |                 idx, code = index_to_code[index]
725 |                 DFG.append((code, idx, 'comesFrom', [], []))
726 |                 states[code] = [idx]
727 |             return sorted(DFG, key=lambda x: x[1]), states
728 |         else:
729 |             name_indexs = tree_to_variable_index(name, index_to_code)
730 |             value_indexs = tree_to_variable_index(value, index_to_code)
731 |             temp, states = DFG_go(value, index_to_code, states)
732 |             DFG += temp
733 |             for index1 in name_indexs:
734 |                 idx1, code1 = index_to_code[index1]
735 |                 for index2 in value_indexs:
736 |                     idx2, code2 = index_to_code[index2]
737 |                     DFG.append((code1, idx1, 'comesFrom', [code2], [idx2]))
738 |                 states[code1] = [idx1]
739 |             return sorted(DFG, key=lambda x: x[1]), states
740 |     elif root_node.type in assignment:
741 |         left_nodes = root_node.child_by_field_name('left')
742 |         right_nodes = root_node.child_by_field_name('right')
743 |         DFG = []
744 |         temp, states = DFG_go(right_nodes, index_to_code, states)
745 |         DFG += temp
746 |         name_indexs = tree_to_variable_index(left_nodes, index_to_code)
747 |         value_indexs = tree_to_variable_index(right_nodes, index_to_code)
748 |         for index1 in name_indexs:
749 |             idx1, code1 = index_to_code[index1]
750 |             for index2 in value_indexs:
751 |                 idx2, code2 = index_to_code[index2]
752 |                 DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
753 |             states[code1] = [idx1]
754 |         return sorted(DFG, key=lambda x: x[1]), states
755 |     elif root_node.type in increment_statement:
756 |         DFG = []
757 |         indexs = tree_to_variable_index(root_node, index_to_code)
758 |         for index1 in indexs:
759 |             idx1, code1 = index_to_code[index1]
760 |             for index2 in indexs:
761 |                 idx2, code2 = index_to_code[index2]
762 |                 DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
763 |             states[code1] = [idx1]
764 |         return sorted(DFG, key=lambda x: x[1]), states
765 |     elif root_node.type in if_statement:
766 |         DFG = []
767 |         current_states = states.copy()
768 |         others_states = []
769 |         flag = tag = False  # `tag` is read below even when there is no else-branch
770 |         if 'else' in root_node.type:
771 |             tag = True
772 |         for child in root_node.children:
773 |             if 'else' in child.type:
774 |                 tag = True
775 |             if child.type not in if_statement and flag is False:
776 |                 temp, current_states = DFG_go(child, index_to_code,
current_states) 777 | DFG += temp 778 | else: 779 | flag = True 780 | temp, new_states = DFG_go(child, index_to_code, states) 781 | DFG += temp 782 | others_states.append(new_states) 783 | others_states.append(current_states) 784 | if tag is False: 785 | others_states.append(states) 786 | new_states = {} 787 | for dic in others_states: 788 | for key in dic: 789 | if key not in new_states: 790 | new_states[key] = dic[key].copy() 791 | else: 792 | new_states[key] += dic[key] 793 | for key in states: 794 | if key not in new_states: 795 | new_states[key] = states[key] 796 | else: 797 | new_states[key] += states[key] 798 | for key in new_states: 799 | new_states[key] = sorted(list(set(new_states[key]))) 800 | return sorted(DFG, key=lambda x: x[1]), new_states 801 | elif root_node.type in for_statement: 802 | DFG = [] 803 | for child in root_node.children: 804 | temp, states = DFG_go(child, index_to_code, states) 805 | DFG += temp 806 | flag = False 807 | for child in root_node.children: 808 | if flag: 809 | temp, states = DFG_go(child, index_to_code, states) 810 | DFG += temp 811 | elif child.type == "for_clause": 812 | if child.child_by_field_name('update') is not None: 813 | temp, states = DFG_go(child.child_by_field_name('update'), index_to_code, states) 814 | DFG += temp 815 | flag = True 816 | dic = {} 817 | for x in DFG: 818 | if (x[0], x[1], x[2]) not in dic: 819 | dic[(x[0], x[1], x[2])] = [x[3], x[4]] 820 | else: 821 | dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) 822 | dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) 823 | DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] 824 | return sorted(DFG, key=lambda x: x[1]), states 825 | else: 826 | DFG = [] 827 | for child in root_node.children: 828 | if child.type in do_first_statement: 829 | temp, states = DFG_go(child, index_to_code, states) 830 | DFG += temp 831 | for child in root_node.children: 832 | if child.type not in do_first_statement: 833 | temp, states = DFG_go(child, index_to_code, states) 834 | DFG += temp 835 | 836 | return sorted(DFG, key=lambda x: x[1]), states 837 | 838 | 839 | def DFG_php(root_node, index_to_code, states): 840 | assignment = ['assignment_expression', 'augmented_assignment_expression'] 841 | def_statement = ['simple_parameter'] 842 | increment_statement = ['update_expression'] 843 | if_statement = ['if_statement', 'else_clause'] 844 | for_statement = ['for_statement'] 845 | enhanced_for_statement = ['foreach_statement'] 846 | while_statement = ['while_statement'] 847 | do_first_statement = [] 848 | states = states.copy() 849 | if (len(root_node.children) == 0 or root_node.type == 'string') and root_node.type != 'comment': 850 | idx, code = index_to_code[(root_node.start_point, root_node.end_point)] 851 | if root_node.type == code: 852 | return [], states 853 | elif code in states: 854 | return [(code, idx, 'comesFrom', [code], states[code].copy())], states 855 | else: 856 | if root_node.type == 'identifier': 857 | states[code] = [idx] 858 | return [(code, idx, 'comesFrom', [], [])], states 859 | elif root_node.type in def_statement: 860 | name = root_node.child_by_field_name('name') 861 | value = root_node.child_by_field_name('default_value') 862 | DFG = [] 863 | if value is None: 864 | indexs = tree_to_variable_index(name, index_to_code) 865 | for index in indexs: 866 | idx, code = index_to_code[index] 867 | DFG.append((code, idx, 'comesFrom', [], [])) 868 | states[code] = [idx] 869 | return 
sorted(DFG, key=lambda x: x[1]), states 870 | else: 871 | name_indexs = tree_to_variable_index(name, index_to_code) 872 | value_indexs = tree_to_variable_index(value, index_to_code) 873 | temp, states = DFG_php(value, index_to_code, states) 874 | DFG += temp 875 | for index1 in name_indexs: 876 | idx1, code1 = index_to_code[index1] 877 | for index2 in value_indexs: 878 | idx2, code2 = index_to_code[index2] 879 | DFG.append((code1, idx1, 'comesFrom', [code2], [idx2])) 880 | states[code1] = [idx1] 881 | return sorted(DFG, key=lambda x: x[1]), states 882 | elif root_node.type in assignment: 883 | left_nodes = root_node.child_by_field_name('left') 884 | right_nodes = root_node.child_by_field_name('right') 885 | DFG = [] 886 | temp, states = DFG_php(right_nodes, index_to_code, states) 887 | DFG += temp 888 | name_indexs = tree_to_variable_index(left_nodes, index_to_code) 889 | value_indexs = tree_to_variable_index(right_nodes, index_to_code) 890 | for index1 in name_indexs: 891 | idx1, code1 = index_to_code[index1] 892 | for index2 in value_indexs: 893 | idx2, code2 = index_to_code[index2] 894 | DFG.append((code1, idx1, 'computedFrom', [code2], [idx2])) 895 | states[code1] = [idx1] 896 | return sorted(DFG, key=lambda x: x[1]), states 897 | elif root_node.type in increment_statement: 898 | DFG = [] 899 | indexs = tree_to_variable_index(root_node, index_to_code) 900 | for index1 in indexs: 901 | idx1, code1 = index_to_code[index1] 902 | for index2 in indexs: 903 | idx2, code2 = index_to_code[index2] 904 | DFG.append((code1, idx1, 'computedFrom', [code2], [idx2])) 905 | states[code1] = [idx1] 906 | return sorted(DFG, key=lambda x: x[1]), states 907 | elif root_node.type in if_statement: 908 | DFG = [] 909 | current_states = states.copy() 910 | others_states = [] 911 | flag = False 912 | if 'else' in root_node.type: 913 | tag = True 914 | for child in root_node.children: 915 | if 'else' in child.type: 916 | tag = True 917 | if child.type not in if_statement and flag is False: 918 | temp, current_states = DFG_php(child, index_to_code, current_states) 919 | DFG += temp 920 | else: 921 | flag = True 922 | temp, new_states = DFG_php(child, index_to_code, states) 923 | DFG += temp 924 | others_states.append(new_states) 925 | others_states.append(current_states) 926 | new_states = {} 927 | for dic in others_states: 928 | for key in dic: 929 | if key not in new_states: 930 | new_states[key] = dic[key].copy() 931 | else: 932 | new_states[key] += dic[key] 933 | for key in states: 934 | if key not in new_states: 935 | new_states[key] = states[key] 936 | else: 937 | new_states[key] += states[key] 938 | for key in new_states: 939 | new_states[key] = sorted(list(set(new_states[key]))) 940 | return sorted(DFG, key=lambda x: x[1]), new_states 941 | elif root_node.type in for_statement: 942 | DFG = [] 943 | for child in root_node.children: 944 | temp, states = DFG_php(child, index_to_code, states) 945 | DFG += temp 946 | flag = False 947 | for child in root_node.children: 948 | if flag: 949 | temp, states = DFG_php(child, index_to_code, states) 950 | DFG += temp 951 | elif child.type == "assignment_expression": 952 | flag = True 953 | dic = {} 954 | for x in DFG: 955 | if (x[0], x[1], x[2]) not in dic: 956 | dic[(x[0], x[1], x[2])] = [x[3], x[4]] 957 | else: 958 | dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) 959 | dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) 960 | DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: 
t[0][1])] 961 | return sorted(DFG, key=lambda x: x[1]), states 962 | elif root_node.type in enhanced_for_statement: 963 | name = None 964 | value = None 965 | for child in root_node.children: 966 | if child.type == 'variable_name' and value is None: 967 | value = child 968 | elif child.type == 'variable_name' and name is None: 969 | name = child 970 | break 971 | body = root_node.child_by_field_name('body') 972 | DFG = [] 973 | for i in range(2): 974 | temp, states = DFG_php(value, index_to_code, states) 975 | DFG += temp 976 | name_indexs = tree_to_variable_index(name, index_to_code) 977 | value_indexs = tree_to_variable_index(value, index_to_code) 978 | for index1 in name_indexs: 979 | idx1, code1 = index_to_code[index1] 980 | for index2 in value_indexs: 981 | idx2, code2 = index_to_code[index2] 982 | DFG.append((code1, idx1, 'computedFrom', [code2], [idx2])) 983 | states[code1] = [idx1] 984 | temp, states = DFG_php(body, index_to_code, states) 985 | DFG += temp 986 | dic = {} 987 | for x in DFG: 988 | if (x[0], x[1], x[2]) not in dic: 989 | dic[(x[0], x[1], x[2])] = [x[3], x[4]] 990 | else: 991 | dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) 992 | dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) 993 | DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] 994 | return sorted(DFG, key=lambda x: x[1]), states 995 | elif root_node.type in while_statement: 996 | DFG = [] 997 | for i in range(2): 998 | for child in root_node.children: 999 | temp, states = DFG_php(child, index_to_code, states) 1000 | DFG += temp 1001 | dic = {} 1002 | for x in DFG: 1003 | if (x[0], x[1], x[2]) not in dic: 1004 | dic[(x[0], x[1], x[2])] = [x[3], x[4]] 1005 | else: 1006 | dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) 1007 | dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) 1008 | DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] 1009 | return sorted(DFG, key=lambda x: x[1]), states 1010 | else: 1011 | DFG = [] 1012 | for child in root_node.children: 1013 | if child.type in do_first_statement: 1014 | temp, states = DFG_php(child, index_to_code, states) 1015 | DFG += temp 1016 | for child in root_node.children: 1017 | if child.type not in do_first_statement: 1018 | temp, states = DFG_php(child, index_to_code, states) 1019 | DFG += temp 1020 | 1021 | return sorted(DFG, key=lambda x: x[1]), states 1022 | 1023 | 1024 | def DFG_javascript(root_node, index_to_code, states): 1025 | assignment = ['assignment_pattern', 'augmented_assignment_expression'] 1026 | def_statement = ['variable_declarator'] 1027 | increment_statement = ['update_expression'] 1028 | if_statement = ['if_statement', 'else'] 1029 | for_statement = ['for_statement'] 1030 | enhanced_for_statement = [] 1031 | while_statement = ['while_statement'] 1032 | do_first_statement = [] 1033 | states = states.copy() 1034 | if (len(root_node.children) == 0 or root_node.type == 'string') and root_node.type != 'comment': 1035 | idx, code = index_to_code[(root_node.start_point, root_node.end_point)] 1036 | if root_node.type == code: 1037 | return [], states 1038 | elif code in states: 1039 | return [(code, idx, 'comesFrom', [code], states[code].copy())], states 1040 | else: 1041 | if root_node.type == 'identifier': 1042 | states[code] = [idx] 1043 | return [(code, idx, 'comesFrom', [], [])], states 1044 | elif root_node.type in def_statement: 1045 | name = 
root_node.child_by_field_name('name')
1046 |         value = root_node.child_by_field_name('value')
1047 |         DFG = []
1048 |         if value is None:
1049 |             indexs = tree_to_variable_index(name, index_to_code)
1050 |             for index in indexs:
1051 |                 idx, code = index_to_code[index]
1052 |                 DFG.append((code, idx, 'comesFrom', [], []))
1053 |                 states[code] = [idx]
1054 |             return sorted(DFG, key=lambda x: x[1]), states
1055 |         else:
1056 |             name_indexs = tree_to_variable_index(name, index_to_code)
1057 |             value_indexs = tree_to_variable_index(value, index_to_code)
1058 |             temp, states = DFG_javascript(value, index_to_code, states)
1059 |             DFG += temp
1060 |             for index1 in name_indexs:
1061 |                 idx1, code1 = index_to_code[index1]
1062 |                 for index2 in value_indexs:
1063 |                     idx2, code2 = index_to_code[index2]
1064 |                     DFG.append((code1, idx1, 'comesFrom', [code2], [idx2]))
1065 |                 states[code1] = [idx1]
1066 |             return sorted(DFG, key=lambda x: x[1]), states
1067 |     elif root_node.type in assignment:
1068 |         left_nodes = root_node.child_by_field_name('left')
1069 |         right_nodes = root_node.child_by_field_name('right')
1070 |         DFG = []
1071 |         temp, states = DFG_javascript(right_nodes, index_to_code, states)
1072 |         DFG += temp
1073 |         name_indexs = tree_to_variable_index(left_nodes, index_to_code)
1074 |         value_indexs = tree_to_variable_index(right_nodes, index_to_code)
1075 |         for index1 in name_indexs:
1076 |             idx1, code1 = index_to_code[index1]
1077 |             for index2 in value_indexs:
1078 |                 idx2, code2 = index_to_code[index2]
1079 |                 DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
1080 |             states[code1] = [idx1]
1081 |         return sorted(DFG, key=lambda x: x[1]), states
1082 |     elif root_node.type in increment_statement:
1083 |         DFG = []
1084 |         indexs = tree_to_variable_index(root_node, index_to_code)
1085 |         for index1 in indexs:
1086 |             idx1, code1 = index_to_code[index1]
1087 |             for index2 in indexs:
1088 |                 idx2, code2 = index_to_code[index2]
1089 |                 DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
1090 |             states[code1] = [idx1]
1091 |         return sorted(DFG, key=lambda x: x[1]), states
1092 |     elif root_node.type in if_statement:
1093 |         DFG = []
1094 |         current_states = states.copy()
1095 |         others_states = []
1096 |         flag = tag = False  # `tag` is read below even when there is no else-branch
1097 |         if 'else' in root_node.type:
1098 |             tag = True
1099 |         for child in root_node.children:
1100 |             if 'else' in child.type:
1101 |                 tag = True
1102 |             if child.type not in if_statement and flag is False:
1103 |                 temp, current_states = DFG_javascript(child, index_to_code, current_states)
1104 |                 DFG += temp
1105 |             else:
1106 |                 flag = True
1107 |                 temp, new_states = DFG_javascript(child, index_to_code, states)
1108 |                 DFG += temp
1109 |                 others_states.append(new_states)
1110 |         others_states.append(current_states)
1111 |         if tag is False:
1112 |             others_states.append(states)
1113 |         new_states = {}
1114 |         for dic in others_states:
1115 |             for key in dic:
1116 |                 if key not in new_states:
1117 |                     new_states[key] = dic[key].copy()
1118 |                 else:
1119 |                     new_states[key] += dic[key]
1120 |         for key in states:
1121 |             if key not in new_states:
1122 |                 new_states[key] = states[key]
1123 |             else:
1124 |                 new_states[key] += states[key]
1125 |         for key in new_states:
1126 |             new_states[key] = sorted(list(set(new_states[key])))
1127 |         return sorted(DFG, key=lambda x: x[1]), new_states
1128 |     elif root_node.type in for_statement:
1129 |         DFG = []
1130 |         for child in root_node.children:
1131 |             temp, states = DFG_javascript(child, index_to_code, states)
1132 |             DFG += temp
1133 |         flag = False
1134 |         for child in root_node.children:
1135 |             if flag:
1136 |
temp, states = DFG_javascript(child, index_to_code, states) 1137 | DFG += temp 1138 | elif child.type == "variable_declaration": 1139 | flag = True 1140 | dic = {} 1141 | for x in DFG: 1142 | if (x[0], x[1], x[2]) not in dic: 1143 | dic[(x[0], x[1], x[2])] = [x[3], x[4]] 1144 | else: 1145 | dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) 1146 | dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) 1147 | DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] 1148 | return sorted(DFG, key=lambda x: x[1]), states 1149 | elif root_node.type in while_statement: 1150 | DFG = [] 1151 | for i in range(2): 1152 | for child in root_node.children: 1153 | temp, states = DFG_javascript(child, index_to_code, states) 1154 | DFG += temp 1155 | dic = {} 1156 | for x in DFG: 1157 | if (x[0], x[1], x[2]) not in dic: 1158 | dic[(x[0], x[1], x[2])] = [x[3], x[4]] 1159 | else: 1160 | dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) 1161 | dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) 1162 | DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] 1163 | return sorted(DFG, key=lambda x: x[1]), states 1164 | else: 1165 | DFG = [] 1166 | for child in root_node.children: 1167 | if child.type in do_first_statement: 1168 | temp, states = DFG_javascript(child, index_to_code, states) 1169 | DFG += temp 1170 | for child in root_node.children: 1171 | if child.type not in do_first_statement: 1172 | temp, states = DFG_javascript(child, index_to_code, states) 1173 | DFG += temp 1174 | 1175 | return sorted(DFG, key=lambda x: x[1]), states 1176 | --------------------------------------------------------------------------------
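A minimal end-to-end sketch of driving parser/DFG.py, assuming the GraphCodeBERT/CodeBLEU calling convention (the `index_to_code` construction mirrors dataflow_match.py; run from the repository root, with the pre-0.22 tree_sitter API used by the bundled my-languages.so):

    from tree_sitter import Language, Parser
    from parser.DFG import DFG_python
    from parser.utils import tree_to_token_index, index_to_code_token

    code = "x = 1\ny = x + 2\n"

    # Load the bundled grammar and parse the snippet.
    PY_LANGUAGE = Language('parser/my-languages.so', 'python')
    ts_parser = Parser()
    ts_parser.set_language(PY_LANGUAGE)
    root = ts_parser.parse(bytes(code, 'utf8')).root_node

    # Map each leaf span (start_point, end_point) to (token index, token text).
    tokens_index = tree_to_token_index(root)
    lines = code.split('\n')
    code_tokens = [index_to_code_token(x, lines) for x in tokens_index]
    index_to_code = {
        index: (i, tok) for i, (index, tok) in enumerate(zip(tokens_index, code_tokens))
    }

    dfg, _ = DFG_python(root, index_to_code, {})
    # Expect an edge like ('x', 5, 'comesFrom', ['x'], [0]): the 'x' read in
    # `y = x + 2` flows from the 'x' assigned in `x = 1`.
    print(dfg)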