├── .env.example ├── .gitignore ├── LICENSE ├── README.md ├── code_transformer ├── __init__.py ├── configuration │ ├── __init__.py │ ├── attention.py │ ├── code_transformer.py │ ├── configuration_utils.py │ ├── great_transformer.py │ ├── transformer_lm_decoder.py │ └── transformer_lm_encoder.py ├── env.py ├── experiments │ ├── __init__.py │ ├── code_transformer │ │ ├── __init__.py │ │ ├── code_summarization.py │ │ ├── code_summarization.yaml │ │ ├── language_modeling.py │ │ └── language_modeling.yaml │ ├── experiment.py │ ├── great │ │ ├── __init__.py │ │ ├── code_summarization.py │ │ └── code_summarization.yaml │ ├── log.py │ ├── mixins │ │ ├── __init__.py │ │ ├── code_summarization.py │ │ ├── code_summarization_great.py │ │ ├── code_trans_lm.py │ │ ├── code_trans_transformer.py │ │ ├── great_transformer.py │ │ ├── xl_net_lm.py │ │ └── xl_net_transformer.py │ ├── paper │ │ ├── __init__.py │ │ ├── ct_go.yaml │ │ ├── ct_java_pretrain_lm.yaml │ │ ├── ct_java_small.yaml │ │ ├── ct_java_small_ablation_only_ancestor_sp.yaml │ │ ├── ct_java_small_ablation_only_ppr.yaml │ │ ├── ct_java_small_ablation_only_shortest_paths.yaml │ │ ├── ct_java_small_ablation_only_sibling_sp.yaml │ │ ├── ct_java_small_only_ast.yaml │ │ ├── ct_java_small_pretrain.yaml │ │ ├── ct_javascript.yaml │ │ ├── ct_multilang.yaml │ │ ├── ct_multilang_go.yaml │ │ ├── ct_multilang_javascript.yaml │ │ ├── ct_multilang_lm.yaml │ │ ├── ct_multilang_lm_pretrain.yaml │ │ ├── ct_multilang_python.yaml │ │ ├── ct_multilang_ruby.yaml │ │ ├── ct_no_pointer_go.yaml │ │ ├── ct_no_pointer_java_small.yaml │ │ ├── ct_no_pointer_java_small_only_ast.yaml │ │ ├── ct_no_pointer_javascript.yaml │ │ ├── ct_no_pointer_multilang.yaml │ │ ├── ct_no_pointer_python.yaml │ │ ├── ct_no_pointer_ruby.yaml │ │ ├── ct_python.yaml │ │ ├── ct_ruby.yaml │ │ ├── great_go.yaml │ │ ├── great_java_small.yaml │ │ ├── great_javascript.yaml │ │ ├── great_multilang.yaml │ │ ├── great_python.yaml │ │ ├── great_ruby.yaml │ │ ├── xl_net_go.yaml │ │ ├── xl_net_java_small.yaml │ │ ├── xl_net_javascript.yaml │ │ ├── xl_net_multilang.yaml │ │ ├── xl_net_no_pointer_java_small.yaml │ │ ├── xl_net_python.yaml │ │ └── xl_net_ruby.yaml │ ├── preprocessing │ │ ├── __init__.py │ │ ├── preprocess-1-code2seq.yaml │ │ ├── preprocess-1-csn.yaml │ │ ├── preprocess-1.py │ │ ├── preprocess-2.py │ │ └── preprocess-2.yaml │ └── xl_net │ │ ├── __init__.py │ │ ├── base.py │ │ ├── code_summarization.py │ │ ├── code_summarization.yaml │ │ ├── language_modeling.py │ │ └── language_modeling.yaml ├── modeling │ ├── __init__.py │ ├── code_transformer │ │ ├── __init__.py │ │ ├── code_transformer.py │ │ ├── decoder.py │ │ ├── distance_embeddings.py │ │ └── lm.py │ ├── constants.py │ ├── data_utils.py │ ├── decoder │ │ ├── __init__.py │ │ ├── pointer.py │ │ └── transformer.py │ ├── great_transformer │ │ ├── __init__.py │ │ ├── great_transformer.py │ │ └── transformer.py │ ├── modelmanager │ │ ├── __init__.py │ │ ├── base.py │ │ ├── code_transformer.py │ │ ├── great.py │ │ └── xl_net.py │ └── xl_net │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── xl_net_language_model.py ├── preprocessing │ ├── README.md │ ├── __init__.py │ ├── datamanager │ │ ├── __init__.py │ │ ├── base.py │ │ ├── c2s │ │ │ ├── __init__.py │ │ │ └── raw.py │ │ ├── csn │ │ │ ├── __init__.py │ │ │ └── raw.py │ │ └── preprocessed.py │ ├── dataset │ │ ├── __init__.py │ │ ├── ablation.py │ │ ├── base.py │ │ ├── code_summarization.py │ │ └── lm.py │ ├── graph │ │ ├── __init__.py │ │ ├── alg.py │ │ ├── ast.py │ │ ├── binning.py │ │ ├── 
distances.py │ │ └── transform.py │ ├── nlp │ │ ├── __init__.py │ │ ├── javaparser.py │ │ ├── semantic.py │ │ ├── text.py │ │ ├── tokenization.py │ │ └── vocab.py │ └── pipeline │ │ ├── __init__.py │ │ ├── code2seq.py │ │ ├── filter.py │ │ ├── stage1.py │ │ └── stage2.py └── utils │ ├── __init__.py │ ├── data.py │ ├── inference.py │ ├── io.py │ ├── log.py │ ├── loss.py │ ├── metrics.py │ ├── sacred.py │ ├── timing.py │ └── vocab.py ├── figures ├── code_transformer_overview.png └── preprocessing_overview.png ├── notebooks ├── deduplicate_java_pretrain.ipynb └── interactive_prediction.ipynb ├── requirements.txt ├── scripts ├── code2seq │ ├── combine-vocabs-code2seq.sh │ ├── preprocess-code2seq-helper.py │ ├── preprocess-code2seq.py │ └── preprocess-code2seq.sh ├── deduplicate-java-pretrain.py ├── evaluate-multilanguage.py ├── evaluate.py ├── extract-java-methods.py ├── run-experiment.py └── run-preprocessing.py ├── setup.py ├── sub_modules ├── code2seq │ ├── .gitignore │ ├── CSharpExtractor │ │ ├── .gitattributes │ │ ├── .gitignore │ │ ├── CSharpExtractor │ │ │ ├── .nuget │ │ │ │ └── packages.config │ │ │ ├── CSharpExtractor.sln │ │ │ └── Extractor │ │ │ │ ├── Extractor.cs │ │ │ │ ├── Extractor.csproj │ │ │ │ ├── PathFinder.cs │ │ │ │ ├── Program.cs │ │ │ │ ├── Properties │ │ │ │ └── launchSettings.json │ │ │ │ ├── Temp.cs │ │ │ │ ├── Tree │ │ │ │ └── Tree.cs │ │ │ │ ├── Utilities.cs │ │ │ │ └── Variable.cs │ │ └── extract.py │ ├── Input.java │ ├── JavaExtractor │ │ ├── JPredict │ │ │ ├── .classpath │ │ │ ├── .gitignore │ │ │ └── src │ │ │ │ └── main │ │ │ │ └── java │ │ │ │ ├── JavaExtractor │ │ │ │ ├── App.java │ │ │ │ ├── Common │ │ │ │ │ ├── CommandLineValues.java │ │ │ │ │ ├── Common.java │ │ │ │ │ └── MethodContent.java │ │ │ │ ├── ExtractFeaturesTask.java │ │ │ │ ├── FeatureExtractor.java │ │ │ │ ├── FeaturesEntities │ │ │ │ │ ├── ProgramFeatures.java │ │ │ │ │ ├── ProgramRelation.java │ │ │ │ │ └── Property.java │ │ │ │ └── Visitors │ │ │ │ │ ├── FunctionVisitor.java │ │ │ │ │ └── LeavesCollectorVisitor.java │ │ │ │ └── Test.java │ │ └── extract.py │ ├── LICENSE │ ├── Python150kExtractor │ │ ├── README.md │ │ ├── extract.py │ │ └── preprocess.sh │ ├── README.md │ ├── __init__.py │ ├── baseline_tokenization │ │ ├── input_example.txt │ │ ├── javalang │ │ │ ├── __init__.py │ │ │ ├── ast.py │ │ │ ├── javadoc.py │ │ │ ├── parse.py │ │ │ ├── parser.py │ │ │ ├── test │ │ │ │ ├── __init__.py │ │ │ │ ├── source │ │ │ │ │ └── package-info │ │ │ │ │ │ ├── AnnotationJavadoc.java │ │ │ │ │ │ ├── AnnotationOnly.java │ │ │ │ │ │ ├── JavadocAnnotation.java │ │ │ │ │ │ ├── JavadocOnly.java │ │ │ │ │ │ └── NoAnnotationNoJavadoc.java │ │ │ │ ├── test_java_8_syntax.py │ │ │ │ ├── test_javadoc.py │ │ │ │ ├── test_package_declaration.py │ │ │ │ └── test_util.py │ │ │ ├── tokenizer.py │ │ │ ├── tree.py │ │ │ └── util.py │ │ └── subtokenize_nmt_baseline.py │ ├── code2seq.py │ ├── common.py │ ├── config.py │ ├── evaluate.sh │ ├── extractor.py │ ├── images │ │ └── network.png │ ├── interactive_predict.py │ ├── model.py │ ├── preprocess.py │ ├── preprocess.sh │ ├── preprocess_csharp.sh │ ├── reader.py │ ├── train.sh │ └── train_python150k.sh ├── java-method-extractor │ ├── .classpath │ ├── JavaExtractor (1).iml │ ├── JavaExtractor.iml │ ├── JavaMethodExtractor.iml │ ├── code-2-seq-java-extractor.iml │ ├── dependency-reduced-pom.xml │ ├── extract.py │ ├── pom.xml │ ├── src │ │ └── main │ │ │ ├── JavaExtractor │ │ │ ├── App.java │ │ │ ├── Common │ │ │ │ ├── CommandLineValues.java │ │ │ │ ├── Common.java │ │ │ │ 
└── MethodContent.java │ │ │ ├── ExtractFeaturesTask.java │ │ │ ├── FeatureExtractor.java │ │ │ ├── FeaturesEntities │ │ │ │ ├── ProgramFeatures.java │ │ │ │ ├── ProgramRelation.java │ │ │ │ └── Property.java │ │ │ └── Visitors │ │ │ │ ├── FunctionVisitor.java │ │ │ │ └── LeavesCollectorVisitor.java │ │ │ └── java │ │ │ ├── CommandLineValues.java │ │ │ ├── ExtractMethodsTask.java │ │ │ ├── JavaMethodExtractor.java │ │ │ ├── MethodContent.java │ │ │ └── MethodVisitor.java │ └── target │ │ └── JavaMethodExtractor-1.0.0-SNAPSHOT.jar └── java-parser │ ├── java-parser.iml │ ├── java-parser.iws │ ├── pom.xml │ ├── src │ └── main │ │ └── java │ │ ├── ASTNode.java │ │ ├── ASTParser.java │ │ └── META-INF │ │ └── MANIFEST.MF │ └── target │ └── java-parser-1.0-SNAPSHOT.jar └── tests ├── modeling ├── code_transformer │ └── test_code_transformer.py ├── great │ └── test_great.py └── xl_net │ └── test_xl_net.py ├── preprocessing └── binning.py ├── test_loss.py └── test_metrics.py /.env.example: -------------------------------------------------------------------------------- 1 | # Put .env into ${HOME}/.config/code_transformer/.env 2 | # 3 | # Per default, we assume the following folder structure: 4 | # CODE_TRANSFORMER_DATA_PATH 5 | # ├── raw 6 | # │ ├── csn 7 | # │ │ ├── python 8 | # │ │ │ └── final 9 | # │ │ │ └── ... 10 | # │ │ : 11 | # │ │ └── go 12 | # │ ├── code2seq 13 | # │ │ └── java-small 14 | # │ └── code2seq-methods 15 | # │ └── java-small 16 | # ├── stage1 17 | # └── stage2 18 | # ├── python 19 | # │ ├── train 20 | # │ ├── valid 21 | # │ ├── test 22 | # │ └── vocabularies.p.gzip 23 | # ├── java-small 24 | # : 25 | # └── python,javascript,go,ruby 26 | # 27 | # 28 | # CODE_TRANSFORMER_BINARY_PATH 29 | # ├── java-parser-1.0-SNAPSHOT.jar 30 | # ├── JavaMethodExtractor-1.0.0-SNAPSHOT.jar 31 | # └── semantic 32 | # 33 | # CODE_TRANSFORMER_MODELS_PATH 34 | # ├── ct_lm 35 | # ├── ct_code_summarization 36 | # │ ├── CT-1 37 | # │ │ ├── config.json 38 | # │ │ ├── model_10000.p 39 | # │ │ : 40 | # │ │ └── model_450000.p 41 | # │ : 42 | # │ └── CT-24 43 | # ├── great_code_summarization 44 | # └── xl_net_code_summarization 45 | 46 | export CODE_TRANSFORMER_DATA_PATH=<<>> 47 | export CODE_TRANSFORMER_BINARY_PATH=<<>> 48 | export CODE_TRANSFORMER_MODELS_PATH=<<>> 49 | export CODE_TRANSFORMER_LOGS_PATH=<<>> 50 | 51 | export CODE_TRANSFORMER_CSN_RAW_DATA_PATH=${CODE_TRANSFORMER_DATA_PATH}/raw/csn 52 | export CODE_TRANSFORMER_CODE2SEQ_RAW_DATA_PATH=${CODE_TRANSFORMER_DATA_PATH}/raw/code2seq 53 | export CODE_TRANSFORMER_CODE2SEQ_EXTRACTED_METHODS_DATA_PATH=${CODE_TRANSFORMER_DATA_PATH}/raw/code2seq-methods 54 | 55 | export CODE_TRANSFORMER_DATA_PATH_STAGE_1=${CODE_TRANSFORMER_DATA_PATH}/stage1 56 | export CODE_TRANSFORMER_DATA_PATH_STAGE_2=${CODE_TRANSFORMER_DATA_PATH}/stage2 57 | 58 | export CODE_TRANSFORMER_JAVA_EXECUTABLE=java 59 | export CODE_TRANSFORMER_JAVA_PARSER_EXECUTABLE=${CODE_TRANSFORMER_BINARY_PATH}/java-parser-1.0-SNAPSHOT.jar 60 | export CODE_TRANSFORMER_JAVA_METHOD_EXTRACTOR_EXECUTABLE=${CODE_TRANSFORMER_BINARY_PATH}/JavaMethodExtractor-1.0.0-SNAPSHOT.jar 61 | export CODE_TRANSFORMER_SEMANTIC_EXECUTABLE=${CODE_TRANSFORMER_BINARY_PATH}/semantic -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python cache 2 | __pycache__ 3 | *.py[cod] 4 | 5 | # IntelliJ/Jupyter 6 | .idea 7 | .ipynb_checkpoints 8 |
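The variables in .env.example above are resolved at runtime through the environs package (see code_transformer/env.py further below). A minimal sketch of that resolution: the ~/.config/code_transformer/.env location comes from the comment on line 1, the DATA_PATH_STAGE_2 fallback mirrors env.py, and only the four top-level paths have no default and must be set.

from pathlib import Path

from environs import Env  # same library used in code_transformer/env.py

env = Env(expand_vars=True)
env_file = Path.home() / ".config" / "code_transformer" / ".env"
if env_file.exists():
    env.read_env(env_file, recurse=False)

with env.prefixed("CODE_TRANSFORMER_"):
    data_path = env("DATA_PATH")  # required: environs raises if unset
    # Optional variables fall back to the folder layout sketched above.
    stage2_path = env("DATA_PATH_STAGE_2", f"{data_path}/stage2")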
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Daniel Zügner and Tobias Kirschstein, Technical University of Munich 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /code_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/__init__.py -------------------------------------------------------------------------------- /code_transformer/configuration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/configuration/__init__.py -------------------------------------------------------------------------------- /code_transformer/configuration/attention.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class AttentionType(Enum): 5 | SCALED_DOT_PRODUCT = "scaled_dot_product" 6 | ADDITIVE = "additive" 7 | MULTIHEAD = "multihead" 8 | -------------------------------------------------------------------------------- /code_transformer/configuration/code_transformer.py: -------------------------------------------------------------------------------- 1 | from code_transformer.configuration.configuration_utils import ModelConfiguration 2 | 3 | 4 | class CodeTransformerLayerConfig(ModelConfiguration): 5 | def __init__(self, 6 | d_model=256, 7 | nhead=8, 8 | dim_feedforward=1024, 9 | activation="gelu", 10 | dropout=0.1, 11 | 12 | num_relative_distances=1, 13 | use_token_distances=False, 14 | use_edge_embeddings=False, 15 | use_content_content=True, 16 | use_content_pos=True, 17 | use_pos_content=True, 18 | use_pos_pos=True, ): 19 | super(CodeTransformerLayerConfig, self).__init__() 20 | self.d_model = d_model 21 | self.nhead = nhead 22 | self.dim_feedforward = dim_feedforward 23 | self.activation = activation 24 | self.dropout = dropout 25 | self.num_relative_distances = num_relative_distances 26 | self.use_token_distances = use_token_distances 27 | self.use_edge_embeddings = use_edge_embeddings 28 | 
self.use_content_content = use_content_content 29 | self.use_content_pos = use_content_pos 30 | self.use_pos_content = use_pos_content 31 | self.use_pos_pos = use_pos_pos 32 | 33 | 34 | class CodeTransformerCoreConfig(ModelConfiguration): 35 | def __init__(self, 36 | encoder_layer: CodeTransformerLayerConfig, 37 | num_layers: int, 38 | positional_encoding=None, 39 | norm=None 40 | ): 41 | super(CodeTransformerCoreConfig, self).__init__() 42 | if isinstance(encoder_layer, CodeTransformerLayerConfig): 43 | self.encoder_layer = encoder_layer 44 | else: 45 | self.encoder_layer = CodeTransformerLayerConfig(**encoder_layer) 46 | self.num_layers = num_layers 47 | self.norm = norm 48 | self.positional_encoding = positional_encoding 49 | -------------------------------------------------------------------------------- /code_transformer/configuration/configuration_utils.py: -------------------------------------------------------------------------------- 1 | class DotDict(dict): 2 | """ 3 | Simple extension of Python's dict to support dot access. 4 | """ 5 | def __init__(self, *args, **kwargs): 6 | super(DotDict, self).__init__(*args, **kwargs) 7 | for arg in args: 8 | if isinstance(arg, dict): 9 | for k, v in arg.items(): 10 | self[k] = v 11 | 12 | if kwargs: 13 | for k, v in kwargs.items(): 14 | if isinstance(v, dict): 15 | self[k] = DotDict(**v) 16 | else: 17 | self[k] = v 18 | 19 | def __getattr__(self, attr): 20 | return self[attr] 21 | 22 | def __setattr__(self, key, value): 23 | self.__setitem__(key, value) 24 | 25 | def __setitem__(self, key, value): 26 | super(DotDict, self).__setitem__(key, value) 27 | self.__dict__.update({key: value}) 28 | 29 | def __delattr__(self, item): 30 | self.__delitem__(item) 31 | 32 | def __delitem__(self, key): 33 | super(DotDict, self).__delitem__(key) 34 | del self.__dict__[key] 35 | 36 | 37 | class ModelConfiguration(DotDict): 38 | pass 39 | 40 | -------------------------------------------------------------------------------- /code_transformer/configuration/great_transformer.py: -------------------------------------------------------------------------------- 1 | from code_transformer.configuration.configuration_utils import ModelConfiguration 2 | 3 | 4 | class GreatTransformerConfig(ModelConfiguration): 5 | def __init__(self, 6 | num_layers: int, 7 | positional_encoding=None, 8 | embed_dim=256, 9 | num_heads=8, 10 | ff_dim=1024, 11 | dropout_rate=0.1, 12 | is_encoder_decoder=False 13 | ): 14 | super(GreatTransformerConfig, self).__init__() 15 | 16 | self.num_layers = num_layers 17 | self.positional_encoding = positional_encoding 18 | 19 | self.embed_dim = embed_dim 20 | self.hidden_dim = embed_dim 21 | self.attention_dim = embed_dim 22 | self.bias_dim = embed_dim 23 | self.num_heads = num_heads 24 | self.ff_dim = ff_dim 25 | self.dropout_rate = dropout_rate 26 | self.is_encoder_decoder = is_encoder_decoder 27 | 28 | 29 | class GreatEncoderConfig(ModelConfiguration): 30 | 31 | def __init__(self, 32 | transformer_config: GreatTransformerConfig, 33 | vocab_size=32000, 34 | num_node_types=None, 35 | subtokens_per_token=5, 36 | num_languages=None 37 | ): 38 | super(GreatEncoderConfig, self).__init__() 39 | 40 | self.transformer_config = transformer_config 41 | self.vocab_size = vocab_size 42 | self.num_node_types = num_node_types 43 | self.subtokens_per_token = subtokens_per_token 44 | self.num_languages = num_languages 45 | -------------------------------------------------------------------------------- /code_transformer/configuration/transformer_lm_encoder.py: 
-------------------------------------------------------------------------------- 1 | from code_transformer.configuration.configuration_utils import ModelConfiguration 2 | 3 | 4 | class TransformerLMEncoderConfig(ModelConfiguration): 5 | 6 | def __init__(self, 7 | transformer, #: Union[CodeTransformer, CodeTransformerCoreConfig], 8 | vocab_size=32000, 9 | num_node_types=None, 10 | num_token_types=None, 11 | subtokens_per_token=5, 12 | input_nonlinearity=None, 13 | num_languages=None): 14 | super(TransformerLMEncoderConfig, self).__init__() 15 | 16 | self.transformer = transformer 17 | self.vocab_size = vocab_size 18 | self.num_token_types = num_token_types 19 | self.num_node_types = num_node_types 20 | self.subtokens_per_token = subtokens_per_token 21 | self.input_nonlinearity = input_nonlinearity 22 | self.num_languages = num_languages 23 | -------------------------------------------------------------------------------- /code_transformer/env.py: -------------------------------------------------------------------------------- 1 | """ 2 | Per default, we assume the following folder structure: 3 | CODE_TRANSFORMER_DATA_PATH 4 | ├── raw 5 | │ ├── csn 6 | │ │ ├── python 7 | │ │ │ └── final 8 | │ │ │ └── ... 9 | │ │ : 10 | │ │ └── go 11 | │ ├── code2seq 12 | │ │ └── java-small 13 | │ └── code2seq-methods 14 | │ └── java-small 15 | ├── stage1 16 | └── stage2 17 | ├── python 18 | │ ├── train 19 | │ ├── valid 20 | │ ├── test 21 | │ └── vocabularies.p.gzip 22 | ├── java-small 23 | : 24 | └── python,javascript,go,ruby 25 | 26 | 27 | CODE_TRANSFORMER_BINARY_PATH 28 | ├── java-parser-1.0-SNAPSHOT.jar 29 | ├── JavaMethodExtractor-1.0.0-SNAPSHOT.jar 30 | └── semantic 31 | 32 | CODE_TRANSFORMER_MODELS_PATH 33 | ├── ct_lm 34 | ├── ct_code_summarization 35 | │ ├── CT-1 36 | │ │ ├── config.json 37 | │ │ ├── model_10000.p 38 | │ │ : 39 | │ │ └── model_450000.p 40 | │ : 41 | │ └── CT-24 42 | ├── great_code_summarization 43 | └── xl_net_code_summarization 44 | """ 45 | 46 | from environs import Env 47 | from pathlib import Path 48 | 49 | env = Env(expand_vars=True) 50 | env_file_path = Path(f"{Path.home()}/.config/code_transformer/.env") 51 | if env_file_path.exists(): 52 | env.read_env(env_file_path, recurse=False) 53 | 54 | with env.prefixed("CODE_TRANSFORMER_"): 55 | 56 | _DATA_PATH = env("DATA_PATH") 57 | _BINARY_PATH = env("BINARY_PATH") 58 | MODELS_SAVE_PATH = env("MODELS_PATH") 59 | LOGS_PATH = env("LOGS_PATH") 60 | 61 | CSN_RAW_DATA_PATH = env("CSN_RAW_DATA_PATH", f"{_DATA_PATH}/raw/csn") 62 | CODE2SEQ_RAW_DATA_PATH = env("CODE2SEQ_RAW_DATA_PATH", f"{_DATA_PATH}/raw/code2seq") 63 | CODE2SEQ_EXTRACTED_METHODS_DATA_PATH = env("CODE2SEQ_EXTRACTED_METHODS_DATA_PATH", f"{_DATA_PATH}/raw/code2seq-methods") 64 | 65 | DATA_PATH_STAGE_1 = env("DATA_PATH_STAGE_1", f"{_DATA_PATH}/stage1") 66 | DATA_PATH_STAGE_2 = env("DATA_PATH_STAGE_2", f"{_DATA_PATH}/stage2") 67 | 68 | JAVA_EXECUTABLE = env("JAVA_EXECUTABLE", "java") 69 | JAVA_PARSER_EXECUTABLE = env("JAVA_PARSER_EXECUTABLE", f"{_BINARY_PATH}/java-parser-1.0-SNAPSHOT.jar") 70 | JAVA_METHOD_EXTRACTOR_EXECUTABLE = env("JAVA_METHOD_EXTRACTOR_EXECUTABLE", f"{_BINARY_PATH}/JavaMethodExtractor-1.0.0-SNAPSHOT.jar") 71 | SEMANTIC_EXECUTABLE = env("SEMANTIC_EXECUTABLE", f"{_BINARY_PATH}/semantic") 72 | -------------------------------------------------------------------------------- /code_transformer/experiments/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/experiments/__init__.py -------------------------------------------------------------------------------- /code_transformer/experiments/code_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/experiments/code_transformer/__init__.py -------------------------------------------------------------------------------- /code_transformer/experiments/code_transformer/code_summarization.py: -------------------------------------------------------------------------------- 1 | from code_transformer.experiments.experiment import ExperimentSetup, ex 2 | from code_transformer.experiments.mixins.code_summarization import CTCodeSummarizationMixin 3 | from code_transformer.experiments.mixins.code_trans_transformer import CodeTransformerDecoderMixin 4 | 5 | 6 | class CodeTransDecoderExperimentSetup(CodeTransformerDecoderMixin, 7 | CTCodeSummarizationMixin, 8 | ExperimentSetup): 9 | pass 10 | 11 | 12 | @ex.automain 13 | def main(): 14 | experiment = CodeTransDecoderExperimentSetup() 15 | experiment.train() 16 | 17 | 18 | @ex.command(unobserved=True) 19 | def recreate_experiment(): 20 | return CodeTransDecoderExperimentSetup() -------------------------------------------------------------------------------- /code_transformer/experiments/code_transformer/language_modeling.py: -------------------------------------------------------------------------------- 1 | from code_transformer.experiments.experiment import ExperimentSetup, ex 2 | from code_transformer.experiments.mixins.code_trans_transformer import CodeTransformerDecoderMixin 3 | from code_transformer.modeling.modelmanager import CodeTransformerLMModelManager 4 | 5 | 6 | class CodeTransLanguageModelingTransformerExperimentSetup(CodeTransformerDecoderMixin, 7 | ExperimentSetup): 8 | 9 | def __init__(self): 10 | super(CodeTransLanguageModelingTransformerExperimentSetup, self).__init__() 11 | self.model_manager = CodeTransformerLMModelManager() 12 | # Overwrite model manager to ensure saving language modeling experiments in a different folder even if 13 | # they share the exact same architecture as models trained for code summarization 14 | 15 | 16 | @ex.automain 17 | def main(): 18 | experiment = CodeTransLanguageModelingTransformerExperimentSetup() 19 | experiment.train() 20 | 21 | 22 | @ex.command(unobserved=True) 23 | def recreate_experiment(): 24 | return CodeTransLanguageModelingTransformerExperimentSetup() 25 | -------------------------------------------------------------------------------- /code_transformer/experiments/code_transformer/language_modeling.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/code_transformer/language_modeling.py' 3 | 4 | data_setup: 5 | language: 'java-small' 6 | num_predict: 2 7 | use_validation: True 8 | use_no_punctuation: True 9 | use_pointer_network: True 10 | num_sub_tokens: 5 11 | 12 | data_transforms: 13 | max_distance_mask: None 14 | relative_distances: 15 | - ppr 16 | - ancestor_sp 17 | - sibling_sp 18 | - shortest_paths 19 | distance_binning: 20 | type: 'exponential' 21 | growth_factor: 1.3 22 | n_fixed_bins: 9 23 | 24 | transfer_learning: 25 | use_pretrained_model: False 26 | model_type: 'ct_lm' 27 | run_id: 27 28 | 
snapshot_iteration: 'latest' 29 | 30 | model: 31 | with_cuda: True 32 | label_smoothing: 0.1 33 | lm_encoder: 34 | input_nonlinearity: 'tanh' 35 | num_languages: None 36 | transformer: 37 | num_layers: 3 38 | encoder_layer: 39 | d_model: 16 40 | nhead: 8 41 | dim_feedforward: 16 42 | dropout: 0 43 | activation: 'gelu' 44 | use_content_content: True 45 | use_content_pos: True 46 | use_pos_content: True 47 | use_pos_pos: True 48 | use_token_distances: True 49 | lm_decoder: 50 | output_nonlinearity: None 51 | n_layers: 1 52 | decoder_dropout: 0 53 | decoder_nhead: 8 54 | decoder_dim_feedforward: 16 55 | decoder_activation: 'gelu' 56 | use_teacher_forcing: True 57 | pointer_attention_type: 'additive' 58 | use_pointer_query_linear: False 59 | use_pointer_query_self_attention: False 60 | attend_cls_token: False 61 | 62 | optimizer: 63 | learning_rate: 8e-5 64 | reg_scale: 0 65 | 66 | # scheduler: 'OneCycleLR' 67 | # scheduler_params: 68 | # max_lr: 5e-5 69 | # steps_per_epoch: 2000 # 500000 / 256 70 | # epochs: 21 71 | # pct_start: 0.1 72 | 73 | #scheduler: 'MultiStepLR' 74 | #scheduler_params: 75 | # milestones: [50] 76 | # gamma: 0.1 77 | 78 | training: 79 | random_seed: 123 80 | batch_size: 2 81 | simulated_batch_size: 128 82 | simulated_batch_size_valid: 1280 83 | validate_every: 1000 84 | persistent_snapshot_every: 100 85 | max_validation_samples: 10000 86 | metrics: 87 | - top1_accuracy 88 | - top5_accuracy 89 | - non_trivial_accuracy 90 | - precision 91 | - recall 92 | - f1_score 93 | - rouge_2 94 | - rouge_l 95 | -------------------------------------------------------------------------------- /code_transformer/experiments/great/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/experiments/great/__init__.py -------------------------------------------------------------------------------- /code_transformer/experiments/great/code_summarization.py: -------------------------------------------------------------------------------- 1 | from code_transformer.experiments.experiment import ExperimentSetup, ex 2 | from code_transformer.experiments.mixins.code_summarization_great import CTCodeSummarizationGreatMixin 3 | from code_transformer.experiments.mixins.great_transformer import GreatTransformerDecoderMixin 4 | 5 | 6 | class GreatTransformerDecoderExperimentSetup(GreatTransformerDecoderMixin, 7 | CTCodeSummarizationGreatMixin, 8 | ExperimentSetup): 9 | pass 10 | 11 | 12 | @ex.automain 13 | def main(): 14 | experiment = GreatTransformerDecoderExperimentSetup() 15 | experiment.train() 16 | 17 | 18 | @ex.command(unobserved=True) 19 | def recreate_experiment(): 20 | return GreatTransformerDecoderExperimentSetup() 21 | -------------------------------------------------------------------------------- /code_transformer/experiments/great/code_summarization.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/great/code_summarization.py' 3 | 4 | data_setup: 5 | language: 'python,javascript,ruby,go' 6 | use_validation: True 7 | num_sub_tokens: 5 8 | num_subtokens_output: 6 9 | use_pointer_network: True 10 | 11 | data_transforms: 12 | max_distance_mask: None 13 | relative_distances: 14 | - ppr 15 | - ancestor_sp 16 | - sibling_sp 17 | - shortest_paths 18 | 19 | distance_binning: 20 | type: 'exponential' 21 | growth_factor: 1.3 22 | 
n_fixed_bins: 9 23 | 24 | model: 25 | with_cuda: False 26 | label_smoothing: 0.1 27 | lm_encoder: 28 | transformer_config: 29 | embed_dim: 16 30 | num_layers: 3 31 | num_heads: 8 32 | ff_dim: 16 33 | dropout_rate: 0.2 34 | lm_decoder: 35 | output_nonlinearity: None 36 | n_layers: 1 37 | decoder_dropout: 0 38 | decoder_nhead: 8 39 | decoder_dim_feedforward: 15 40 | decoder_activation: 'gelu' 41 | use_teacher_forcing: True 42 | pointer_attention_type: 'additive' 43 | use_pointer_query_linear: False 44 | use_pointer_query_self_attention: False 45 | attend_cls_token: False 46 | 47 | optimizer: 48 | optimizer: 'Adam' 49 | learning_rate: 8e-5 50 | reg_scale: 3e-5 51 | 52 | #scheduler: 'OneCycleLR' 53 | #scheduler_params: 54 | # max_lr: 1e-4 55 | # steps_per_epoch: 4000 56 | # epochs: 30 57 | # pct_start: 0.3 58 | 59 | #scheduler: 'MultiStepLR' 60 | #scheduler_params: 61 | # milestones: [1500, 5000] 62 | # gamma: 0.1 63 | 64 | training: 65 | random_seed: 456 66 | batch_size: 8 67 | simulated_batch_size: 128 68 | simulated_batch_size_valid: 1280 69 | accumulate_tokens_batch: False 70 | validate_every: 100 71 | persistent_snapshot_every: 100 72 | early_stopping_patience: 20 73 | max_validation_samples: 50000 74 | metrics: 75 | - top1_accuracy 76 | - top5_accuracy 77 | - non_trivial_accuracy 78 | - precision 79 | - recall 80 | - f1_score 81 | - micro_f1_score 82 | - rouge_2 83 | - rouge_l 84 | 85 | -------------------------------------------------------------------------------- /code_transformer/experiments/mixins/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/experiments/mixins/__init__.py -------------------------------------------------------------------------------- /code_transformer/experiments/mixins/code_trans_lm.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | from torch.nn import CrossEntropyLoss 4 | 5 | from code_transformer.experiments.experiment import ExperimentSetup, ex 6 | from code_transformer.modeling.code_transformer.lm import TransformerLanguageModel 7 | from code_transformer.modeling.modelmanager import CodeTransformerLMModelManager 8 | from code_transformer.utils.loss import LabelSmoothingLoss 9 | 10 | 11 | class CodeTransformerLanguageModelMixin(ExperimentSetup, ABC): 12 | 13 | @ex.capture(prefix="model") 14 | def _init_model(self, transformer_lm_encoder: dict, with_cuda: bool, output_nonlinearity=None, 15 | label_smoothing=None): 16 | config = self.generate_transformer_lm_encoder_config(transformer_lm_encoder) 17 | 18 | if label_smoothing is None: 19 | loss_fct = CrossEntropyLoss(ignore_index=-1) 20 | else: 21 | loss_fct = LabelSmoothingLoss(label_smoothing) 22 | 23 | self.model_lm = TransformerLanguageModel(config, output_nonlinearity=output_nonlinearity, loss_fct=loss_fct) 24 | self.model_manager = CodeTransformerLMModelManager() 25 | self.model_config = config 26 | if self.use_pretrained_model: 27 | self.model_lm.load_state_dict(self.pretrained_model_params) 28 | 29 | self.with_cuda = with_cuda 30 | -------------------------------------------------------------------------------- /code_transformer/experiments/mixins/xl_net_lm.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | from torch.nn import CrossEntropyLoss 4 | 5 | from 
code_transformer.configuration.transformer_lm_encoder import TransformerLMEncoderConfig 6 | from code_transformer.experiments.experiment import ex 7 | from code_transformer.experiments.xl_net.base import XLNetExperimentSetup 8 | from code_transformer.modeling.modelmanager.xl_net import XLNetLMModelManager 9 | from code_transformer.modeling.xl_net.xl_net_language_model import XLNetLMEncoder, XLNetLanguageModel 10 | from code_transformer.utils.loss import LabelSmoothingLoss 11 | 12 | 13 | class XLNetLanguageModelingMixin(XLNetExperimentSetup, ABC): 14 | 15 | @ex.capture(prefix="model") 16 | def _init_model(self, transformer_lm_encoder: dict, with_cuda: bool, output_nonlinearity, label_smoothing=None): 17 | config = TransformerLMEncoderConfig(**transformer_lm_encoder) 18 | 19 | if hasattr(self, "word_vocab"): 20 | config.vocab_size = len(self.word_vocab) 21 | if hasattr(self, "token_type_vocab"): 22 | config.num_token_types = len(self.token_type_vocab) 23 | if hasattr(self, "node_type_vocab"): 24 | config.num_node_types = len(self.node_type_vocab) 25 | 26 | xl_net_lm_encoder = XLNetLMEncoder(config) 27 | 28 | if label_smoothing is None: 29 | loss_fct = CrossEntropyLoss(ignore_index=-1) 30 | else: 31 | loss_fct = LabelSmoothingLoss(label_smoothing) 32 | 33 | if hasattr(self.dataset_train, 'num_sub_tokens_output'): 34 | num_sub_tokens_output = self.dataset_train.num_sub_tokens_output 35 | else: 36 | num_sub_tokens_output = 5 37 | 38 | self.model_manager = XLNetLMModelManager() 39 | self.model_lm = XLNetLanguageModel(xl_net_lm_encoder, output_nonlinearity=output_nonlinearity, 40 | loss_fct=loss_fct, output_sub_tokens_per_token=num_sub_tokens_output) 41 | 42 | self.with_cuda = with_cuda 43 | -------------------------------------------------------------------------------- /code_transformer/experiments/paper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/experiments/paper/__init__.py -------------------------------------------------------------------------------- /code_transformer/experiments/paper/ct_go.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/code_transformer/code_summarization.py' 3 | 4 | data_setup: 5 | language: 'go' 6 | filter_language: None 7 | use_validation: True 8 | num_sub_tokens: 5 9 | num_subtokens_output: 6 10 | use_only_ast: False 11 | mask_all_tokens: False 12 | use_no_punctuation: True 13 | use_pointer_network: True 14 | sort_by_length: False 15 | shuffle: False 16 | chunk_size: 32 17 | 18 | data_transforms: 19 | max_distance_mask: None 20 | 21 | relative_distances: 22 | - ppr 23 | - ancestor_sp 24 | - sibling_sp 25 | - shortest_paths 26 | 27 | distance_binning: 28 | type: 'exponential' 29 | growth_factor: 1.3 30 | n_fixed_bins: 9 31 | 32 | transfer_learning: 33 | use_pretrained_model: False 34 | model_type: 'ct_code_summarization' 35 | run_id: CT-23 36 | snapshot_iteration: 10 37 | cpu: False 38 | freeze_encoder_layers: None 39 | 40 | model: 41 | with_cuda: True 42 | label_smoothing: 0.1 43 | lm_encoder: 44 | input_nonlinearity: 'tanh' 45 | num_languages: None 46 | transformer: 47 | num_layers: 3 48 | encoder_layer: 49 | d_model: 1024 50 | nhead: 8 51 | dim_feedforward: 2048 52 | dropout: 0.2 53 | activation: 'gelu' 54 | use_content_content: True 55 | use_content_pos: True 56 | use_pos_content: True 57 
| use_pos_pos: True 58 | use_token_distances: True 59 | lm_decoder: 60 | output_nonlinearity: None 61 | n_layers: 1 62 | decoder_dropout: 0 63 | decoder_nhead: 8 64 | decoder_dim_feedforward: 2048 65 | decoder_activation: 'gelu' 66 | use_teacher_forcing: True 67 | pointer_attention_type: 'additive' 68 | use_pointer_query_linear: False 69 | use_pointer_query_self_attention: True 70 | concat_query_and_pointer: True 71 | attend_cls_token: False 72 | 73 | optimizer: 74 | optimizer: 'Adam' 75 | learning_rate: 8e-5 76 | reg_scale: 3e-5 77 | 78 | training: 79 | random_seed: 456 80 | batch_size: 8 81 | simulated_batch_size: 128 82 | simulated_batch_size_valid: 1280 83 | accumulate_tokens_batch: False 84 | validate_every: 100 85 | persistent_snapshot_every: 10000 86 | early_stopping_patience: 20 87 | max_validation_samples: 50000 88 | metrics: 89 | - top1_accuracy 90 | - top5_accuracy 91 | - non_trivial_accuracy 92 | - precision 93 | - recall 94 | - f1_score 95 | - micro_f1_score 96 | - rouge_2 97 | - rouge_l 98 | -------------------------------------------------------------------------------- /code_transformer/experiments/paper/ct_java_pretrain_lm.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/code_transformer/language_modeling.py' 3 | 4 | data_setup: 5 | language: 'java-pretrain' 6 | num_predict: 2 7 | use_validation: True 8 | use_no_punctuation: True 9 | use_pointer_network: True 10 | num_sub_tokens: 5 11 | 12 | data_transforms: 13 | max_distance_mask: None 14 | relative_distances: 15 | - ppr 16 | - ancestor_sp 17 | - sibling_sp 18 | - shortest_paths 19 | distance_binning: 20 | type: 'exponential' 21 | growth_factor: 1.3 22 | n_fixed_bins: 9 23 | 24 | transfer_learning: 25 | use_pretrained_model: False 26 | model_type: 'ct_lm' 27 | run_id: 27 28 | snapshot_iteration: 'latest' 29 | 30 | model: 31 | with_cuda: True 32 | label_smoothing: 0.1 33 | lm_encoder: 34 | input_nonlinearity: 'tanh' 35 | num_languages: None 36 | transformer: 37 | num_layers: 3 38 | encoder_layer: 39 | d_model: 1024 40 | nhead: 8 41 | dim_feedforward: 2048 42 | dropout: 0 43 | activation: 'gelu' 44 | use_content_content: True 45 | use_content_pos: True 46 | use_pos_content: True 47 | use_pos_pos: True 48 | use_token_distances: True 49 | lm_decoder: 50 | output_nonlinearity: None 51 | n_layers: 1 52 | decoder_dropout: 0 53 | decoder_nhead: 8 54 | decoder_dim_feedforward: 2048 55 | decoder_activation: 'gelu' 56 | use_teacher_forcing: True 57 | pointer_attention_type: 'additive' 58 | use_pointer_query_linear: False 59 | use_pointer_query_self_attention: False 60 | attend_cls_token: False 61 | 62 | optimizer: 63 | learning_rate: 8e-5 64 | reg_scale: 0 65 | 66 | training: 67 | random_seed: 123 68 | batch_size: 2 69 | simulated_batch_size: 128 70 | simulated_batch_size_valid: 1280 71 | validate_every: 1000 72 | persistent_snapshot_every: 50000 73 | max_validation_samples: 10000 74 | metrics: 75 | - top1_accuracy 76 | - top5_accuracy 77 | - non_trivial_accuracy 78 | - precision 79 | - recall 80 | - f1_score 81 | - rouge_2 82 | - rouge_l 83 | -------------------------------------------------------------------------------- /code_transformer/experiments/paper/ct_java_small.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/code_transformer/code_summarization.py' 3 | 4 | data_setup: 5 | language: 'java-small' 6 | 
filter_language: None 7 | use_validation: True 8 | num_sub_tokens: 5 9 | num_subtokens_output: 6 10 | use_only_ast: False 11 | mask_all_tokens: False 12 | use_no_punctuation: False 13 | use_pointer_network: True 14 | sort_by_length: False 15 | shuffle: False 16 | chunk_size: 32 17 | 18 | data_transforms: 19 | max_distance_mask: None 20 | 21 | relative_distances: 22 | - ppr 23 | - ancestor_sp 24 | - sibling_sp 25 | - shortest_paths 26 | 27 | distance_binning: 28 | type: 'exponential' 29 | growth_factor: 1.3 30 | n_fixed_bins: 9 31 | 32 | transfer_learning: 33 | use_pretrained_model: False 34 | model_type: 'ct_code_summarization' 35 | run_id: CT-23 36 | snapshot_iteration: 10 37 | cpu: False 38 | freeze_encoder_layers: None 39 | 40 | model: 41 | with_cuda: True 42 | label_smoothing: 0.1 43 | lm_encoder: 44 | input_nonlinearity: 'tanh' 45 | num_languages: None 46 | transformer: 47 | num_layers: 3 48 | encoder_layer: 49 | d_model: 1024 50 | nhead: 8 51 | dim_feedforward: 2048 52 | dropout: 0.2 53 | activation: 'gelu' 54 | use_content_content: True 55 | use_content_pos: True 56 | use_pos_content: True 57 | use_pos_pos: True 58 | use_token_distances: True 59 | lm_decoder: 60 | output_nonlinearity: None 61 | n_layers: 1 62 | decoder_dropout: 0.1 63 | decoder_nhead: 8 64 | decoder_dim_feedforward: 2048 65 | decoder_activation: 'gelu' 66 | use_teacher_forcing: True 67 | pointer_attention_type: 'additive' 68 | use_pointer_query_linear: False 69 | use_pointer_query_self_attention: False 70 | concat_query_and_pointer: True 71 | attend_cls_token: False 72 | 73 | optimizer: 74 | optimizer: 'Adam' 75 | learning_rate: 8e-5 76 | reg_scale: 3e-5 77 | 78 | training: 79 | random_seed: 456 80 | batch_size: 8 81 | simulated_batch_size: 128 82 | simulated_batch_size_valid: 1280 83 | accumulate_tokens_batch: False 84 | validate_every: 100 85 | persistent_snapshot_every: 10000 86 | early_stopping_patience: 50 87 | max_validation_samples: 50000 88 | metrics: 89 | - top1_accuracy 90 | - top5_accuracy 91 | - non_trivial_accuracy 92 | - precision 93 | - recall 94 | - f1_score 95 | - micro_f1_score 96 | - rouge_2 97 | - rouge_l 98 | -------------------------------------------------------------------------------- /code_transformer/experiments/paper/ct_java_small_ablation_only_ancestor_sp.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/code_transformer/code_summarization.py' 3 | 4 | data_setup: 5 | language: 'java-small' 6 | filter_language: None 7 | use_validation: True 8 | num_sub_tokens: 5 9 | num_subtokens_output: 6 10 | use_only_ast: True 11 | mask_all_tokens: False 12 | use_no_punctuation: True 13 | use_pointer_network: False 14 | sort_by_length: False 15 | shuffle: False 16 | chunk_size: 32 17 | 18 | data_transforms: 19 | max_distance_mask: None 20 | 21 | relative_distances: 22 | - ancestor_sp 23 | 24 | distance_binning: 25 | type: 'exponential' 26 | growth_factor: 1.3 27 | n_fixed_bins: 9 28 | 29 | transfer_learning: 30 | use_pretrained_model: False 31 | model_type: 'ct_code_summarization' 32 | run_id: CT-23 33 | snapshot_iteration: 10 34 | cpu: False 35 | freeze_encoder_layers: None 36 | 37 | model: 38 | with_cuda: True 39 | label_smoothing: 0.1 40 | lm_encoder: 41 | input_nonlinearity: 'tanh' 42 | num_languages: None 43 | transformer: 44 | num_layers: 3 45 | encoder_layer: 46 | d_model: 1024 47 | nhead: 8 48 | dim_feedforward: 2048 49 | dropout: 0.2 50 | activation: 'gelu' 51 | use_content_content: True 52 
| use_content_pos: True 53 | use_pos_content: True 54 | use_pos_pos: True 55 | use_token_distances: False 56 | lm_decoder: 57 | output_nonlinearity: None 58 | n_layers: 1 59 | decoder_dropout: 0 60 | decoder_nhead: 8 61 | decoder_dim_feedforward: 2048 62 | decoder_activation: 'gelu' 63 | use_teacher_forcing: True 64 | pointer_attention_type: 'additive' 65 | use_pointer_query_linear: False 66 | use_pointer_query_self_attention: False 67 | concat_query_and_pointer: True 68 | attend_cls_token: True 69 | 70 | optimizer: 71 | optimizer: 'Adam' 72 | learning_rate: 8e-5 73 | reg_scale: 3e-5 74 | 75 | training: 76 | random_seed: 456 77 | batch_size: 8 78 | simulated_batch_size: 128 79 | simulated_batch_size_valid: 1280 80 | accumulate_tokens_batch: False 81 | validate_every: 100 82 | persistent_snapshot_every: 10000 83 | early_stopping_patience: 20 84 | max_validation_samples: 50000 85 | metrics: 86 | - top1_accuracy 87 | - top5_accuracy 88 | - non_trivial_accuracy 89 | - precision 90 | - recall 91 | - f1_score 92 | - micro_f1_score 93 | - rouge_2 94 | - rouge_l 95 | -------------------------------------------------------------------------------- /code_transformer/experiments/paper/ct_java_small_ablation_only_ppr.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/code_transformer/code_summarization.py' 3 | 4 | data_setup: 5 | language: 'java-small' 6 | filter_language: None 7 | use_validation: True 8 | num_sub_tokens: 5 9 | num_subtokens_output: 6 10 | use_only_ast: True 11 | mask_all_tokens: False 12 | use_no_punctuation: True 13 | use_pointer_network: False 14 | sort_by_length: False 15 | shuffle: False 16 | chunk_size: 32 17 | 18 | data_transforms: 19 | max_distance_mask: None 20 | 21 | relative_distances: 22 | - ppr 23 | 24 | distance_binning: 25 | type: 'exponential' 26 | growth_factor: 1.3 27 | n_fixed_bins: 9 28 | 29 | transfer_learning: 30 | use_pretrained_model: False 31 | model_type: 'ct_code_summarization' 32 | run_id: CT-23 33 | snapshot_iteration: 10 34 | cpu: False 35 | freeze_encoder_layers: None 36 | 37 | model: 38 | with_cuda: True 39 | label_smoothing: 0.1 40 | lm_encoder: 41 | input_nonlinearity: 'tanh' 42 | num_languages: None 43 | transformer: 44 | num_layers: 3 45 | encoder_layer: 46 | d_model: 1024 47 | nhead: 8 48 | dim_feedforward: 2048 49 | dropout: 0.2 50 | activation: 'gelu' 51 | use_content_content: True 52 | use_content_pos: True 53 | use_pos_content: True 54 | use_pos_pos: True 55 | use_token_distances: False 56 | lm_decoder: 57 | output_nonlinearity: None 58 | n_layers: 1 59 | decoder_dropout: 0 60 | decoder_nhead: 8 61 | decoder_dim_feedforward: 2048 62 | decoder_activation: 'gelu' 63 | use_teacher_forcing: True 64 | pointer_attention_type: 'additive' 65 | use_pointer_query_linear: False 66 | use_pointer_query_self_attention: False 67 | concat_query_and_pointer: True 68 | attend_cls_token: True 69 | 70 | optimizer: 71 | optimizer: 'Adam' 72 | learning_rate: 8e-5 73 | reg_scale: 3e-5 74 | 75 | training: 76 | random_seed: 456 77 | batch_size: 8 78 | simulated_batch_size: 128 79 | simulated_batch_size_valid: 1280 80 | accumulate_tokens_batch: False 81 | validate_every: 100 82 | persistent_snapshot_every: 10000 83 | early_stopping_patience: 20 84 | max_validation_samples: 50000 85 | metrics: 86 | - top1_accuracy 87 | - top5_accuracy 88 | - non_trivial_accuracy 89 | - precision 90 | - recall 91 | - f1_score 92 | - micro_f1_score 93 | - rouge_2 94 | - rouge_l 95 | 
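All of these experiment configs share the same distance_binning block (type: 'exponential', growth_factor: 1.3, n_fixed_bins: 9), which discretizes the continuous relative distances (ppr, ancestor_sp, sibling_sp, shortest_paths) into a bounded set of indices for learned distance embeddings. A rough sketch of exponential binning, under the assumption that the first n_fixed_bins values keep unit-width bins and later bin widths grow geometrically; the authoritative implementation lives in code_transformer/preprocessing/graph/binning.py and may differ in detail:

import numpy as np

def exponential_bin_edges(max_value, n_fixed_bins=9, growth_factor=1.3):
    # Sketch only: unit-width bins up to n_fixed_bins, after which bin
    # widths grow by growth_factor until max_value is covered.
    edges = [float(i) for i in range(n_fixed_bins + 1)]
    width = 1.0
    while edges[-1] < max_value:
        width *= growth_factor
        edges.append(edges[-1] + width)
    return np.array(edges)

# Each raw distance becomes a bin index that selects a distance embedding.
raw_distances = np.array([0.3, 1.0, 4.0, 12.5, 40.0])
bin_indices = np.digitize(raw_distances, exponential_bin_edges(max_value=50.0))

Exponentially growing bins keep fine resolution for the short distances that dominate AST neighborhoods while still covering long-range pairs with a small embedding vocabulary.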
-------------------------------------------------------------------------------- /code_transformer/experiments/paper/ct_java_small_ablation_only_shortest_paths.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/code_transformer/code_summarization.py' 3 | 4 | data_setup: 5 | language: 'java-small' 6 | filter_language: None 7 | use_validation: True 8 | num_sub_tokens: 5 9 | num_subtokens_output: 6 10 | use_only_ast: True 11 | mask_all_tokens: False 12 | use_no_punctuation: True 13 | use_pointer_network: False 14 | sort_by_length: False 15 | shuffle: False 16 | chunk_size: 32 17 | 18 | data_transforms: 19 | max_distance_mask: None 20 | 21 | relative_distances: 22 | - shortest_paths 23 | 24 | distance_binning: 25 | type: 'exponential' 26 | growth_factor: 1.3 27 | n_fixed_bins: 9 28 | 29 | transfer_learning: 30 | use_pretrained_model: False 31 | model_type: 'ct_code_summarization' 32 | run_id: CT-23 33 | snapshot_iteration: 10 34 | cpu: False 35 | freeze_encoder_layers: None 36 | 37 | model: 38 | with_cuda: True 39 | label_smoothing: 0.1 40 | lm_encoder: 41 | input_nonlinearity: 'tanh' 42 | num_languages: None 43 | transformer: 44 | num_layers: 3 45 | encoder_layer: 46 | d_model: 1024 47 | nhead: 8 48 | dim_feedforward: 2048 49 | dropout: 0.2 50 | activation: 'gelu' 51 | use_content_content: True 52 | use_content_pos: True 53 | use_pos_content: True 54 | use_pos_pos: True 55 | use_token_distances: False 56 | lm_decoder: 57 | output_nonlinearity: None 58 | n_layers: 1 59 | decoder_dropout: 0 60 | decoder_nhead: 8 61 | decoder_dim_feedforward: 2048 62 | decoder_activation: 'gelu' 63 | use_teacher_forcing: True 64 | pointer_attention_type: 'additive' 65 | use_pointer_query_linear: False 66 | use_pointer_query_self_attention: False 67 | concat_query_and_pointer: True 68 | attend_cls_token: True 69 | 70 | optimizer: 71 | optimizer: 'Adam' 72 | learning_rate: 8e-5 73 | reg_scale: 3e-5 74 | 75 | training: 76 | random_seed: 456 77 | batch_size: 8 78 | simulated_batch_size: 128 79 | simulated_batch_size_valid: 1280 80 | accumulate_tokens_batch: False 81 | validate_every: 100 82 | persistent_snapshot_every: 10000 83 | early_stopping_patience: 20 84 | max_validation_samples: 50000 85 | metrics: 86 | - top1_accuracy 87 | - top5_accuracy 88 | - non_trivial_accuracy 89 | - precision 90 | - recall 91 | - f1_score 92 | - micro_f1_score 93 | - rouge_2 94 | - rouge_l 95 | -------------------------------------------------------------------------------- /code_transformer/experiments/paper/ct_java_small_ablation_only_sibling_sp.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/code_transformer/code_summarization.py' 3 | 4 | data_setup: 5 | language: 'java-small' 6 | filter_language: None 7 | use_validation: True 8 | num_sub_tokens: 5 9 | num_subtokens_output: 6 10 | use_only_ast: True 11 | mask_all_tokens: False 12 | use_no_punctuation: True 13 | use_pointer_network: False 14 | sort_by_length: False 15 | shuffle: False 16 | chunk_size: 32 17 | 18 | data_transforms: 19 | max_distance_mask: None 20 | 21 | relative_distances: 22 | - sibling_sp 23 | 24 | distance_binning: 25 | type: 'exponential' 26 | growth_factor: 1.3 27 | n_fixed_bins: 9 28 | 29 | transfer_learning: 30 | use_pretrained_model: False 31 | model_type: 'ct_code_summarization' 32 | run_id: CT-23 33 | snapshot_iteration: 10 34 | cpu: False 35 | 
freeze_encoder_layers: None 36 | 37 | model: 38 | with_cuda: True 39 | label_smoothing: 0.1 40 | lm_encoder: 41 | input_nonlinearity: 'tanh' 42 | num_languages: None 43 | transformer: 44 | num_layers: 3 45 | encoder_layer: 46 | d_model: 1024 47 | nhead: 8 48 | dim_feedforward: 2048 49 | dropout: 0.2 50 | activation: 'gelu' 51 | use_content_content: True 52 | use_content_pos: True 53 | use_pos_content: True 54 | use_pos_pos: True 55 | use_token_distances: False 56 | lm_decoder: 57 | output_nonlinearity: None 58 | n_layers: 1 59 | decoder_dropout: 0 60 | decoder_nhead: 8 61 | decoder_dim_feedforward: 2048 62 | decoder_activation: 'gelu' 63 | use_teacher_forcing: True 64 | pointer_attention_type: 'additive' 65 | use_pointer_query_linear: False 66 | use_pointer_query_self_attention: False 67 | concat_query_and_pointer: True 68 | attend_cls_token: True 69 | 70 | optimizer: 71 | optimizer: 'Adam' 72 | learning_rate: 8e-5 73 | reg_scale: 3e-5 74 | 75 | training: 76 | random_seed: 456 77 | batch_size: 8 78 | simulated_batch_size: 128 79 | simulated_batch_size_valid: 1280 80 | accumulate_tokens_batch: False 81 | validate_every: 100 82 | persistent_snapshot_every: 10000 83 | early_stopping_patience: 20 84 | max_validation_samples: 50000 85 | metrics: 86 | - top1_accuracy 87 | - top5_accuracy 88 | - non_trivial_accuracy 89 | - precision 90 | - recall 91 | - f1_score 92 | - micro_f1_score 93 | - rouge_2 94 | - rouge_l 95 | -------------------------------------------------------------------------------- /code_transformer/experiments/paper/ct_java_small_only_ast.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/code_transformer/code_summarization.py' 3 | 4 | data_setup: 5 | language: 'java-small' 6 | filter_language: None 7 | use_validation: True 8 | num_sub_tokens: 5 9 | num_subtokens_output: 6 10 | use_only_ast: True 11 | mask_all_tokens: False 12 | use_no_punctuation: True 13 | use_pointer_network: True 14 | sort_by_length: False 15 | shuffle: False 16 | chunk_size: 32 17 | 18 | data_transforms: 19 | max_distance_mask: None 20 | 21 | relative_distances: 22 | - ppr 23 | - ancestor_sp 24 | - sibling_sp 25 | - shortest_paths 26 | 27 | distance_binning: 28 | type: 'exponential' 29 | growth_factor: 1.3 30 | n_fixed_bins: 9 31 | 32 | transfer_learning: 33 | use_pretrained_model: False 34 | model_type: 'ct_code_summarization' 35 | run_id: CT-23 36 | snapshot_iteration: 10 37 | cpu: False 38 | freeze_encoder_layers: None 39 | 40 | model: 41 | with_cuda: True 42 | label_smoothing: 0.1 43 | lm_encoder: 44 | input_nonlinearity: 'tanh' 45 | num_languages: None 46 | transformer: 47 | num_layers: 3 48 | encoder_layer: 49 | d_model: 1024 50 | nhead: 8 51 | dim_feedforward: 2048 52 | dropout: 0.2 53 | activation: 'gelu' 54 | use_content_content: True 55 | use_content_pos: True 56 | use_pos_content: True 57 | use_pos_pos: True 58 | use_token_distances: True 59 | lm_decoder: 60 | output_nonlinearity: None 61 | n_layers: 1 62 | decoder_dropout: 0 63 | decoder_nhead: 8 64 | decoder_dim_feedforward: 2048 65 | decoder_activation: 'gelu' 66 | use_teacher_forcing: True 67 | pointer_attention_type: 'additive' 68 | use_pointer_query_linear: False 69 | use_pointer_query_self_attention: False 70 | concat_query_and_pointer: True 71 | attend_cls_token: False 72 | 73 | optimizer: 74 | optimizer: 'Adam' 75 | learning_rate: 8e-5 76 | reg_scale: 3e-5 77 | 78 | training: 79 | random_seed: 456 80 | batch_size: 8 81 | 
simulated_batch_size: 128 82 | simulated_batch_size_valid: 1280 83 | accumulate_tokens_batch: False 84 | validate_every: 100 85 | persistent_snapshot_every: 10000 86 | early_stopping_patience: 20 87 | max_validation_samples: 50000 88 | metrics: 89 | - top1_accuracy 90 | - top5_accuracy 91 | - non_trivial_accuracy 92 | - precision 93 | - recall 94 | - f1_score 95 | - micro_f1_score 96 | - rouge_2 97 | - rouge_l 98 | -------------------------------------------------------------------------------- /code_transformer/experiments/paper/ct_java_small_pretrain.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/code_transformer/code_summarization.py' 3 | 4 | data_setup: 5 | language: 'java-small-pretrain' 6 | filter_language: None 7 | use_validation: True 8 | num_sub_tokens: 5 9 | num_subtokens_output: 6 10 | use_only_ast: False 11 | mask_all_tokens: False 12 | use_no_punctuation: True 13 | use_pointer_network: True 14 | sort_by_length: True 15 | shuffle: True 16 | chunk_size: 32 17 | 18 | data_transforms: 19 | max_distance_mask: None 20 | 21 | relative_distances: 22 | - ppr 23 | - ancestor_sp 24 | - sibling_sp 25 | - shortest_paths 26 | 27 | distance_binning: 28 | type: 'exponential' 29 | growth_factor: 1.3 30 | n_fixed_bins: 9 31 | 32 | transfer_learning: 33 | use_pretrained_model: True 34 | model_type: 'ct_lm' 35 | run_id: CT-LM-1 36 | snapshot_iteration: "latest" 37 | cpu: False 38 | freeze_encoder_layers: None 39 | 40 | model: 41 | with_cuda: True 42 | label_smoothing: 0.1 43 | lm_encoder: 44 | input_nonlinearity: 'tanh' 45 | num_languages: None 46 | transformer: 47 | num_layers: 3 48 | encoder_layer: 49 | d_model: 1024 50 | nhead: 8 51 | dim_feedforward: 2048 52 | dropout: 0.2 53 | activation: 'gelu' 54 | use_content_content: True 55 | use_content_pos: True 56 | use_pos_content: True 57 | use_pos_pos: True 58 | use_token_distances: True 59 | lm_decoder: 60 | output_nonlinearity: None 61 | n_layers: 1 62 | decoder_dropout: 0 63 | decoder_nhead: 8 64 | decoder_dim_feedforward: 2048 65 | decoder_activation: 'gelu' 66 | use_teacher_forcing: True 67 | pointer_attention_type: 'additive' 68 | use_pointer_query_linear: False 69 | use_pointer_query_self_attention: False 70 | concat_query_and_pointer: True 71 | attend_cls_token: False 72 | 73 | optimizer: 74 | optimizer: 'Adam' 75 | learning_rate: 8e-5 76 | reg_scale: 3e-5 77 | 78 | training: 79 | random_seed: 456 80 | batch_size: 8 81 | simulated_batch_size: 128 82 | simulated_batch_size_valid: 1280 83 | accumulate_tokens_batch: False 84 | validate_every: 100 85 | persistent_snapshot_every: 10000 86 | early_stopping_patience: 20 87 | max_validation_samples: 50000 88 | metrics: 89 | - top1_accuracy 90 | - top5_accuracy 91 | - non_trivial_accuracy 92 | - precision 93 | - recall 94 | - f1_score 95 | - micro_f1_score 96 | - rouge_2 97 | - rouge_l 98 | -------------------------------------------------------------------------------- /code_transformer/experiments/paper/ct_javascript.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/code_transformer/code_summarization.py' 3 | 4 | data_setup: 5 | language: 'javascript' 6 | filter_language: None 7 | use_validation: True 8 | num_sub_tokens: 5 9 | num_subtokens_output: 6 10 | use_only_ast: False 11 | mask_all_tokens: False 12 | use_no_punctuation: True 13 | use_pointer_network: True 14 | sort_by_length: 
False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'ct_code_summarization'
  run_id: CT-23
  snapshot_iteration: 10
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: None
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    concat_query_and_pointer: True
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_multilang.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'python,javascript,go,ruby'
  filter_language: None
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: False
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: True
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'ct_code_summarization'
  run_id: CT-23
  snapshot_iteration: 10
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: 4
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    concat_query_and_pointer: True
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_multilang_go.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'python,javascript,go,ruby'
  filter_language: 'go'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: False
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: True
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: True
  model_type: 'ct_code_summarization'
  run_id: CT-10
  snapshot_iteration: 260000
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: 4
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    concat_query_and_pointer: True
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l
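A note on the distance_binning block shared by all of these configs: the real-valued relative distances (PPR scores, shortest-path lengths, ancestor and sibling distances) are bucketed before being embedded. With type 'exponential', the first n_fixed_bins buckets keep unit width and the remaining bucket widths grow geometrically by growth_factor. The following is a minimal illustrative sketch of such a scheme, not the repository's implementation (that lives in code_transformer/preprocessing/graph/binning.py):

import numpy as np

def exponential_bin_edges(max_value, num_bins=32, n_fixed_bins=9, growth_factor=1.3):
    # First n_fixed_bins bins have unit width; afterwards widths grow by 1.3x.
    edges = [float(i) for i in range(n_fixed_bins + 1)]
    width = 1.0
    while len(edges) < num_bins + 1 and edges[-1] < max_value:
        width *= growth_factor
        edges.append(edges[-1] + width)
    return np.asarray(edges)

edges = exponential_bin_edges(max_value=100)   # e.g. shortest-path hops in the AST
bin_ids = np.digitize([2, 15, 80], edges)      # nearby pairs get fine-grained bins

With growth_factor 1.3 this keeps fine resolution for nearby nodes while covering very distant node pairs with a bounded vocabulary of distance embeddings.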
--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_multilang_lm.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/language_modeling.py'

data_setup:
  language: 'python,javascript,go,ruby'
  num_predict: 2
  use_validation: True
  use_no_punctuation: True
  use_pointer_network: True
  num_sub_tokens: 5

data_transforms:
  max_distance_mask: None
  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths
  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'ct_lm'
  run_id: 27
  snapshot_iteration: 'latest'

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: 4
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    attend_cls_token: True

optimizer:
  learning_rate: 8e-5
  reg_scale: 0

training:
  random_seed: 123
  batch_size: 2
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  validate_every: 1000
  persistent_snapshot_every: 50000
  max_validation_samples: 10000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - rouge_2
    - rouge_l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_multilang_lm_pretrain.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'python,javascript,go,ruby'
  filter_language: None
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: False
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: True
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: True
  model_type: 'ct_lm'
  run_id: CT-LM-2
  snapshot_iteration: "latest"
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: 4
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    concat_query_and_pointer: True
    attend_cls_token: True

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_multilang_python.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'python,javascript,go,ruby'
  filter_language: 'python'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: False
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: True
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: True
  model_type: 'ct_code_summarization'
  run_id: CT-10
  snapshot_iteration: 260000
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: 4
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    concat_query_and_pointer: True
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l
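In every training block here, batch_size is the number of samples that fit on the GPU at once, while simulated_batch_size is the effective batch the optimizer should see: 128 // 8 = 16 forward/backward passes are accumulated per update. A hedged PyTorch sketch of that pattern (model, optimizer and loader are placeholders; the repository wires this up inside its experiment mixins):

def train_epoch(model, optimizer, loader, batch_size=8, simulated_batch_size=128):
    accumulation_steps = simulated_batch_size // batch_size   # 128 // 8 = 16
    optimizer.zero_grad()
    for step, batch in enumerate(loader):
        loss = model(batch) / accumulation_steps   # rescale so gradient magnitudes
        loss.backward()                            # match a true 128-sample batch
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()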
--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_multilang_ruby.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'python,javascript,go,ruby'
  filter_language: 'ruby'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: False
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: True
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: True
  model_type: 'ct_code_summarization'
  run_id: CT-10
  snapshot_iteration: 260000
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: 4
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    concat_query_and_pointer: True
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_no_pointer_go.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'go'
  filter_language: None
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: False
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: False
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'ct_code_summarization'
  run_id: CT-23
  snapshot_iteration: 10
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: None
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    concat_query_and_pointer: True
    attend_cls_token: True

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l
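The ct_no_pointer_* ablations flip use_pointer_network to False; the pointer_* keys under lm_decoder remain in the file but become inactive. What the pointer network adds when enabled is a copy mechanism over the input subtokens, roughly in the pointer-generator style. A sketch under that assumption (illustrative, not the repository's exact decoder in code_transformer/modeling/decoder/pointer.py):

import torch

def mix_vocab_and_copy(vocab_logits, copy_scores, source_token_ids, p_gen):
    # p_gen in (0, 1): probability of generating from the output vocabulary
    # rather than copying a subtoken that occurs in the input snippet.
    vocab_dist = torch.softmax(vocab_logits, dim=-1) * p_gen
    copy_dist = torch.softmax(copy_scores, dim=-1) * (1.0 - p_gen)
    # Scatter the copy mass onto the vocabulary ids of the source tokens.
    return vocab_dist.scatter_add(-1, source_token_ids, copy_dist)

Copying matters for method-name prediction because many target subtokens (parameter names, field names) literally occur in the method body.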
--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_no_pointer_java_small.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'java-small'
  filter_language: None
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: False
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: False
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'ct_code_summarization'
  run_id: CT-23
  snapshot_iteration: 10
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: None
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    concat_query_and_pointer: True
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_no_pointer_java_small_only_ast.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'java-small'
  filter_language: None
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: True
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: False
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'ct_code_summarization'
  run_id: CT-23
  snapshot_iteration: 10
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: None
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: False
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    concat_query_and_pointer: True
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_no_pointer_javascript.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'javascript'
  filter_language: None
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: False
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: False
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'ct_code_summarization'
  run_id: CT-23
  snapshot_iteration: 10
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: None
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    concat_query_and_pointer: True
    attend_cls_token: True

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_no_pointer_multilang.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'python,javascript,go,ruby'
  filter_language: None
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: False
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: False
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'ct_code_summarization'
  run_id: CT-23
  snapshot_iteration: 10
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: 4
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    concat_query_and_pointer: True
    attend_cls_token: True

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_no_pointer_python.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'python'
  filter_language: None
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: False
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: False
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'ct_code_summarization'
  run_id: CT-23
  snapshot_iteration: 10
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: None
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    concat_query_and_pointer: True
    attend_cls_token: True

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l
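All experiments train with label_smoothing: 0.1, i.e. the cross-entropy target keeps 0.9 of its mass on the gold subtoken and spreads 0.1 uniformly over the vocabulary. A standard formulation, shown as a sketch (the repository's own loss lives in code_transformer/utils/loss.py and may differ in details):

import torch.nn.functional as F

def label_smoothed_loss(logits, target, smoothing=0.1):
    log_probs = F.log_softmax(logits, dim=-1)
    nll = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)  # gold term
    uniform = -log_probs.mean(dim=-1)                              # smoothing term
    return ((1.0 - smoothing) * nll + smoothing * uniform).mean()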
--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_no_pointer_ruby.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'ruby'
  filter_language: None
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: False
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: False
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'ct_code_summarization'
  run_id: CT-23
  snapshot_iteration: 10
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: None
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    concat_query_and_pointer: True
    attend_cls_token: True

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_python.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'python'
  filter_language: None
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: False
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: True
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'ct_code_summarization'
  run_id: CT-23
  snapshot_iteration: 10
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: None
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: True
    concat_query_and_pointer: True
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/ct_ruby.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/code_transformer/code_summarization.py'

data_setup:
  language: 'ruby'
  filter_language: None
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_only_ast: False
  mask_all_tokens: False
  use_no_punctuation: True
  use_pointer_network: True
  sort_by_length: False
  shuffle: False
  chunk_size: 32

data_transforms:
  max_distance_mask: None

  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'ct_code_summarization'
  run_id: CT-23
  snapshot_iteration: 10
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    input_nonlinearity: 'tanh'
    num_languages: None
    transformer:
      num_layers: 3
      encoder_layer:
        d_model: 1024
        nhead: 8
        dim_feedforward: 2048
        dropout: 0.2
        activation: 'gelu'
        use_content_content: True
        use_content_pos: True
        use_pos_content: True
        use_pos_pos: True
        use_token_distances: True
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: True
    concat_query_and_pointer: True
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge_2
    - rouge_l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/great_go.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/great/code_summarization.py'

data_setup:
  language: 'go'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_pointer_network: True

data_transforms:
  max_distance_mask: None
  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    transformer_config:
      embed_dim: 1024
      num_layers: 3
      num_heads: 8
      ff_dim: 2048
      dropout_rate: 0.2
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge-2
    - rouge-l
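num_sub_tokens: 5 appears in every setup: each code token is split on camelCase/snake_case boundaries and represented by at most five subtokens, padded to a fixed width so tokens can be batched. A toy illustration of that shaping (the helper and pad_id are hypothetical, not the repository's vocabulary code):

def to_fixed_subtokens(subtoken_ids, num_sub_tokens=5, pad_id=0):
    # e.g. "getItemCount" -> ["get", "item", "count"] -> 3 ids + 2 padding ids
    ids = subtoken_ids[:num_sub_tokens]
    return ids + [pad_id] * (num_sub_tokens - len(ids))

to_fixed_subtokens([17, 52, 9])   # -> [17, 52, 9, 0, 0]

num_subtokens_output: 6 plays the same role on the decoder side, bounding the length of the predicted method name.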
--------------------------------------------------------------------------------
/code_transformer/experiments/paper/great_java_small.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/great/code_summarization.py'

data_setup:
  language: 'java-small'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_pointer_network: True

data_transforms:
  max_distance_mask: None
  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    transformer_config:
      embed_dim: 1024
      num_layers: 3
      num_heads: 8
      ff_dim: 2048
      dropout_rate: 0.2
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge-2
    - rouge-l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/great_javascript.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/great/code_summarization.py'

data_setup:
  language: 'javascript'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_pointer_network: True

data_transforms:
  max_distance_mask: None
  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    transformer_config:
      embed_dim: 1024
      num_layers: 3
      num_heads: 8
      ff_dim: 2048
      dropout_rate: 0.2
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge-2
    - rouge-l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/great_multilang.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/great/code_summarization.py'

data_setup:
  language: 'python,javascript,go,ruby'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_pointer_network: True

data_transforms:
  max_distance_mask: None
  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    transformer_config:
      embed_dim: 1024
      num_layers: 3
      num_heads: 8
      ff_dim: 2048
      dropout_rate: 0.2
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    attend_cls_token: True

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge-2
    - rouge-l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/great_python.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/great/code_summarization.py'

data_setup:
  language: 'python'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_pointer_network: True

data_transforms:
  max_distance_mask: None
  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    transformer_config:
      embed_dim: 1024
      num_layers: 3
      num_heads: 8
      ff_dim: 2048
      dropout_rate: 0.2
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge-2
    - rouge-l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/great_ruby.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/great/code_summarization.py'

data_setup:
  language: 'ruby'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_pointer_network: True

data_transforms:
  max_distance_mask: None
  relative_distances:
    - ppr
    - ancestor_sp
    - sibling_sp
    - shortest_paths

  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    transformer_config:
      embed_dim: 1024
      num_layers: 3
      num_heads: 8
      ff_dim: 2048
      dropout_rate: 0.2
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
    - rouge-2
    - rouge-l

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/xl_net_go.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/xl_net/code_summarization.py'

data_setup:
  language: 'go'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_no_punctuation: True
  use_pointer_network: True

data_transforms:
  max_distance_mask: None
  relative_distances: None
  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'xl_net_lm'
  run_id: 4
  snapshot_iteration: 'latest'
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    subtokens_per_token: 5
    num_languages: None
    input_nonlinearity: 'tanh'
    transformer:
      d_model: 1024
      n_layer: 3
      n_head: 8
      d_inner: 2048
      ff_activation: 'gelu'
      dropout: 0.2
      mem_len: 1024
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
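The model.lm_encoder.transformer block of the xl_net_* files mirrors the field names of Hugging Face's XLNetConfig, which the repository's XLNet baseline builds on (see code_transformer/modeling/xl_net). A sketch of instantiating an encoder from these values; whether the repository constructs it in exactly this way is an assumption:

from transformers import XLNetConfig, XLNetModel

config = XLNetConfig(d_model=1024, n_layer=3, n_head=8, d_inner=2048,
                     ff_activation='gelu', dropout=0.2, mem_len=1024)
encoder = XLNetModel(config)   # mem_len > 0 enables Transformer-XL style memory

Note also that the XLNet baselines set relative_distances: None; unlike the Code Transformer and GREAT configs, they consume only the token sequence, not structural AST distances.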
--------------------------------------------------------------------------------
/code_transformer/experiments/paper/xl_net_java_small.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/xl_net/code_summarization.py'

data_setup:
  language: 'java-small'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_no_punctuation: True
  use_pointer_network: True

data_transforms:
  max_distance_mask: None
  relative_distances: None
  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'xl_net_lm'
  run_id: 4
  snapshot_iteration: 'latest'
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    subtokens_per_token: 5
    num_languages: None
    input_nonlinearity: 'tanh'
    transformer:
      d_model: 1024
      n_layer: 3
      n_head: 8
      d_inner: 2048
      ff_activation: 'gelu'
      dropout: 0.2
      mem_len: 1024
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/xl_net_javascript.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/xl_net/code_summarization.py'

data_setup:
  language: 'javascript'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_no_punctuation: True
  use_pointer_network: True

data_transforms:
  max_distance_mask: None
  relative_distances: None
  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'xl_net_lm'
  run_id: 4
  snapshot_iteration: 'latest'
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    subtokens_per_token: 5
    num_languages: None
    input_nonlinearity: 'tanh'
    transformer:
      d_model: 1024
      n_layer: 3
      n_head: 8
      d_inner: 2048
      ff_activation: 'gelu'
      dropout: 0.2
      mem_len: 1024
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 4
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/xl_net_multilang.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/xl_net/code_summarization.py'

data_setup:
  language: 'python,javascript,go,ruby'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_no_punctuation: True
  use_pointer_network: True

data_transforms:
  max_distance_mask: None
  relative_distances: None
  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'xl_net_lm'
  run_id: 4
  snapshot_iteration: 'latest'
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    subtokens_per_token: 5
    num_languages: 4
    input_nonlinearity: 'tanh'
    transformer:
      d_model: 1024
      n_layer: 3
      n_head: 8
      d_inner: 2048
      ff_activation: 'gelu'
      dropout: 0.2
      mem_len: 1024
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    attend_cls_token: True

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/xl_net_no_pointer_java_small.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/xl_net/code_summarization.py'

data_setup:
  language: 'java-small'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_no_punctuation: True
  use_pointer_network: False

data_transforms:
  max_distance_mask: None
  relative_distances: None
  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'xl_net_lm'
  run_id: 4
  snapshot_iteration: 'latest'
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    subtokens_per_token: 5
    num_languages: None
    input_nonlinearity: 'tanh'
    transformer:
      d_model: 1024
      n_layer: 3
      n_head: 8
      d_inner: 2048
      ff_activation: 'gelu'
      dropout: 0.2
      mem_len: 1024
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 4
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score

--------------------------------------------------------------------------------
/code_transformer/experiments/paper/xl_net_python.yaml:
--------------------------------------------------------------------------------
experiment_setup:
  executable: 'code_transformer/experiments/xl_net/code_summarization.py'

data_setup:
  language: 'python'
  use_validation: True
  num_sub_tokens: 5
  num_subtokens_output: 6
  use_no_punctuation: True
  use_pointer_network: True

data_transforms:
  max_distance_mask: None
  relative_distances: None
  distance_binning:
    type: 'exponential'
    growth_factor: 1.3
    n_fixed_bins: 9

transfer_learning:
  use_pretrained_model: False
  model_type: 'xl_net_lm'
  run_id: 4
  snapshot_iteration: 'latest'
  cpu: False
  freeze_encoder_layers: None

model:
  with_cuda: True
  label_smoothing: 0.1
  lm_encoder:
    subtokens_per_token: 5
    num_languages: None
    input_nonlinearity: 'tanh'
    transformer:
      d_model: 1024
      n_layer: 3
      n_head: 8
      d_inner: 2048
      ff_activation: 'gelu'
      dropout: 0.2
      mem_len: 1024
  lm_decoder:
    output_nonlinearity: None
    n_layers: 1
    decoder_dropout: 0
    decoder_nhead: 8
    decoder_dim_feedforward: 2048
    decoder_activation: 'gelu'
    use_teacher_forcing: True
    pointer_attention_type: 'additive'
    use_pointer_query_linear: False
    use_pointer_query_self_attention: False
    attend_cls_token: False

optimizer:
  optimizer: 'Adam'
  learning_rate: 8e-5
  reg_scale: 3e-5

training:
  random_seed: 456
  batch_size: 8
  simulated_batch_size: 128
  simulated_batch_size_valid: 1280
  accumulate_tokens_batch: False
  validate_every: 100
  persistent_snapshot_every: 10000
  early_stopping_patience: 20
  max_validation_samples: 50000
  metrics:
    - top1_accuracy
    - top5_accuracy
    - non_trivial_accuracy
    - precision
    - recall
    - f1_score
    - micro_f1_score
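A note on the training block these files share: validation runs every validate_every = 100 simulated batches, and early_stopping_patience = 20 is naturally read as a count of validations, so training would stop after roughly 20 * 100 = 2000 updates without improvement. A sketch under that reading (the exact semantics of the repository's experiment loop are not shown here):

class EarlyStopping:
    def __init__(self, patience=20):
        self.patience, self.best, self.num_bad = patience, float('inf'), 0

    def step(self, val_loss):
        # True once `patience` consecutive validations saw no improvement.
        if val_loss < self.best:
            self.best, self.num_bad = val_loss, 0
        else:
            self.num_bad += 1
        return self.num_bad >= self.patience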
| 20 | transfer_learning: 21 | use_pretrained_model: False 22 | model_type: 'xl_net_lm' 23 | run_id: 4 24 | snapshot_iteration: 'latest' 25 | cpu: False 26 | freeze_encoder_layers: None 27 | 28 | model: 29 | with_cuda: True 30 | label_smoothing: 0.1 31 | lm_encoder: 32 | subtokens_per_token: 5 33 | num_languages: None 34 | input_nonlinearity: 'tanh' 35 | transformer: 36 | d_model: 1024 37 | n_layer: 3 38 | n_head: 8 39 | d_inner: 2048 40 | ff_activation: 'gelu' 41 | dropout: 0.2 42 | mem_len: 1024 43 | lm_decoder: 44 | output_nonlinearity: None 45 | n_layers: 1 46 | decoder_dropout: 0 47 | decoder_nhead: 8 48 | decoder_dim_feedforward: 2048 49 | decoder_activation: 'gelu' 50 | use_teacher_forcing: True 51 | pointer_attention_type: 'additive' 52 | use_pointer_query_linear: False 53 | use_pointer_query_self_attention: False 54 | attend_cls_token: False 55 | 56 | optimizer: 57 | optimizer: 'Adam' 58 | learning_rate: 8e-5 59 | reg_scale: 3e-5 60 | 61 | training: 62 | random_seed: 456 63 | batch_size: 8 64 | simulated_batch_size: 128 65 | simulated_batch_size_valid: 1280 66 | accumulate_tokens_batch: False 67 | validate_every: 100 68 | persistent_snapshot_every: 10000 69 | early_stopping_patience: 20 70 | max_validation_samples: 50000 71 | metrics: 72 | - top1_accuracy 73 | - top5_accuracy 74 | - non_trivial_accuracy 75 | - precision 76 | - recall 77 | - f1_score 78 | - micro_f1_score 79 | 80 | -------------------------------------------------------------------------------- /code_transformer/experiments/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/experiments/preprocessing/__init__.py -------------------------------------------------------------------------------- /code_transformer/experiments/preprocessing/preprocess-1-code2seq.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/preprocessing/preprocess-1.py' 3 | 4 | execution: 5 | num_processes: 15 6 | batch_size: 10 7 | save_every: 10000 # Processed data will be saved into zipped chunks of this size for easier handling 8 | random_seed: 123 9 | 10 | preprocessing: 11 | use_tokens_limiter: False # Whether input snippets should be discarded if they are too long, i.e., have too many tokens 12 | hard_num_tokens_limit: 10000 # hard tokens limit. 
Snippets with more tokens have to be dropped as generating an AST 13 | # would not be feasible for such long snippets 14 | allow_empty_methods: True # Sometimes, methods have no body 15 | separate_label_vocabulary: False # Whether a separate word counter should be computed that only contains words that 16 | # appeared in the method name -------------------------------------------------------------------------------- /code_transformer/experiments/preprocessing/preprocess-1-csn.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/preprocessing/preprocess-1.py' 3 | 4 | execution: 5 | num_processes: 10 6 | batch_size: 10 7 | save_every: 10000 # Processed data will be saved into zipped chunks of this size for easier handling 8 | random_seed: 123 9 | 10 | preprocessing: 11 | use_tokens_limiter: True # Whether input snippets should be discarded if they are too long, i.e., have too many tokens 12 | hard_num_tokens_limit: 10000 # hard tokens limit. Snippets with more tokens have to be dropped as generating an AST 13 | # would not be feasible for such long snippets 14 | allow_empty_methods: False # Sometimes, methods have no body 15 | separate_label_vocabulary: False # Whether a separate word counter should be computed that only contains words that 16 | # appeared in the method name -------------------------------------------------------------------------------- /code_transformer/experiments/preprocessing/preprocess-2.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/preprocessing/preprocess-2.py' 3 | 4 | execution: 5 | num_processes: 15 6 | batch_size: 100 7 | dataset_slice_size: 5000 8 | 9 | preprocessing: 10 | remove_punctuation: True 11 | max_num_tokens: 10000 12 | vocab_size: 32000 13 | min_vocabulary_frequency: 100 14 | separate_label_vocabulary: False 15 | vocab_size_labels: 5000 16 | min_vocabulary_frequency_labels: None 17 | 18 | distances: 19 | ppr_alpha: 0.15 20 | ppr_use_log: True 21 | ppr_threshold: 0.006737946999085467 # e^-5 22 | sp_threshold: None 23 | ancestor_sp_forward: True 24 | ancestor_sp_backward: True 25 | ancestor_sp_negative_reverse_dists: True 26 | ancestor_sp_threshold: None 27 | sibling_sp_forward: True 28 | sibling_sp_backward: True 29 | sibling_sp_negative_reverse_dists: True 30 | sibling_sp_threshold: None 31 | 32 | binning: 33 | num_bins: 32 34 | n_fixed_bins: 9 35 | exponential_binning: true 36 | exponential_binning_growth_factor: 1.3 37 | bin_padding: 0 38 | -------------------------------------------------------------------------------- /code_transformer/experiments/xl_net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/experiments/xl_net/__init__.py -------------------------------------------------------------------------------- /code_transformer/experiments/xl_net/base.py: -------------------------------------------------------------------------------- 1 | from code_transformer.experiments.experiment import ExperimentSetup 2 | 3 | 4 | class XLNetExperimentSetup(ExperimentSetup): 5 | 6 | def __init__(self): 7 | super(XLNetExperimentSetup, self).__init__() 8 | 9 | def _init_data_transforms(self, max_distance_mask=None, relative_distances=[], distance_binning={'type': 'regular', 'n_fixed_bins': 
0}): 10 | super(XLNetExperimentSetup, self)._init_data_transforms(max_distance_mask, relative_distances, distance_binning) -------------------------------------------------------------------------------- /code_transformer/experiments/xl_net/code_summarization.py: -------------------------------------------------------------------------------- 1 | from code_transformer.experiments.experiment import ex 2 | from code_transformer.experiments.mixins.code_summarization import CTCodeSummarizationMixin 3 | from code_transformer.experiments.mixins.xl_net_transformer import XLNetTransformerMixin 4 | 5 | 6 | class XLNetCodeSummarizationTransformerExperimentSetup(CTCodeSummarizationMixin, XLNetTransformerMixin): 7 | pass 8 | 9 | 10 | @ex.automain 11 | def main(): 12 | experiment = XLNetCodeSummarizationTransformerExperimentSetup() 13 | experiment.train() 14 | 15 | 16 | @ex.command(unobserved=True) 17 | def recreate_experiment(): 18 | return XLNetCodeSummarizationTransformerExperimentSetup() -------------------------------------------------------------------------------- /code_transformer/experiments/xl_net/code_summarization.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/xl_net/code_summarization.py' 3 | 4 | data_setup: 5 | language: 'python,javascript,ruby,go' 6 | use_validation: True 7 | num_sub_tokens: 5 8 | num_subtokens_output: 6 9 | use_no_punctuation: True 10 | use_pointer_network: True 11 | 12 | data_transforms: 13 | max_distance_mask: None 14 | relative_distances: None 15 | distance_binning: 16 | type: 'exponential' 17 | growth_factor: 1.3 18 | n_fixed_bins: 9 19 | 20 | transfer_learning: 21 | use_pretrained_model: False 22 | model_type: 'xl_net_lm' 23 | run_id: 4 24 | snapshot_iteration: 'latest' 25 | cpu: False 26 | freeze_encoder_layers: None 27 | 28 | model: 29 | with_cuda: False 30 | label_smoothing: 0.1 31 | lm_encoder: 32 | subtokens_per_token: 5 33 | num_languages: 4 34 | input_nonlinearity: 'tanh' 35 | transformer: 36 | d_model: 16 37 | n_layer: 3 38 | n_head: 8 39 | d_inner: 16 40 | ff_activation: 'gelu' 41 | dropout: 0.2 42 | mem_len: 16 43 | lm_decoder: 44 | output_nonlinearity: None 45 | n_layers: 1 46 | decoder_dropout: 0 47 | decoder_nhead: 8 48 | decoder_dim_feedforward: 2048 49 | decoder_activation: 'gelu' 50 | use_teacher_forcing: True 51 | pointer_attention_type: 'additive' 52 | use_pointer_query_linear: False 53 | use_pointer_query_self_attention: False 54 | 55 | optimizer: 56 | optimizer: 'Adam' 57 | learning_rate: 8e-5 58 | reg_scale: 3e-5 59 | 60 | training: 61 | random_seed: 456 62 | batch_size: 8 63 | simulated_batch_size: 128 64 | simulated_batch_size_valid: 1280 65 | accumulate_tokens_batch: False 66 | validate_every: 100 67 | persistent_snapshot_every: 100 68 | early_stopping_patience: 20 69 | max_validation_samples: 50000 70 | metrics: 71 | - top1_accuracy 72 | - top5_accuracy 73 | - non_trivial_accuracy 74 | - precision 75 | - recall 76 | - f1_score 77 | - micro_f1_score 78 | 79 | -------------------------------------------------------------------------------- /code_transformer/experiments/xl_net/language_modeling.py: -------------------------------------------------------------------------------- 1 | from code_transformer.experiments.experiment import ex 2 | from code_transformer.experiments.mixins.xl_net_lm import XLNetLanguageModelingMixin 3 | 4 | 5 | class XLNetLanguageModelingExperimentSetup(XLNetLanguageModelingMixin): 6 | pass 7 | 8 | 9 | 
@ex.automain 10 | def main(): 11 | experiment = XLNetLanguageModelingExperimentSetup() 12 | experiment.train() 13 | 14 | 15 | @ex.command(unobserved=True) 16 | def recreate_experiment(): 17 | return XLNetLanguageModelingExperimentSetup() 18 | -------------------------------------------------------------------------------- /code_transformer/experiments/xl_net/language_modeling.yaml: -------------------------------------------------------------------------------- 1 | experiment_setup: 2 | executable: 'code_transformer/experiments/xl_net/language_modeling.py' 3 | 4 | data_setup: 5 | language: 'java-small' 6 | dataset_name: 'stage2' 7 | use_validation: True 8 | num_predict: 5 9 | 10 | transfer_learning: 11 | use_pretrained_model: False 12 | model_type: 'xl_net_lm' 13 | run_id: None 14 | snapshot_name: None 15 | 16 | model: 17 | with_cuda: True 18 | output_nonlinearity: None 19 | transformer_lm_encoder: 20 | subtokens_per_token: 5 21 | input_nonlinearity: 'tanh' 22 | transformer: 23 | d_model: 1024 24 | n_layer: 3 25 | n_head: 8 26 | d_inner: 2048 27 | ff_activation: 'gelu' 28 | dropout: 0.1 29 | 30 | optimizer: 31 | learning_rate: 5e-5 32 | reg_scale: 0 33 | 34 | scheduler: 'OneCycleLR' 35 | scheduler_params: 36 | max_lr: 5e-5 37 | steps_per_epoch: 4000 # 500000 / 128 38 | epochs: 21 39 | pct_start: 0.1 40 | 41 | training: 42 | persistent_snapshot_every: 50000 43 | random_seed: 123 44 | batch_size: 8 45 | simulated_batch_size: 128 46 | validate_every: 10 47 | metrics: 48 | - top1_accuracy 49 | - top5_accuracy 50 | - non_trivial_accuracy 51 | - precision 52 | - recall 53 | - f1_score 54 | -------------------------------------------------------------------------------- /code_transformer/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/modeling/__init__.py -------------------------------------------------------------------------------- /code_transformer/modeling/code_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/modeling/code_transformer/__init__.py -------------------------------------------------------------------------------- /code_transformer/modeling/code_transformer/decoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from code_transformer.configuration.transformer_lm_decoder import TransformerLMDecoderConfig 4 | from code_transformer.configuration.transformer_lm_encoder import TransformerLMEncoderConfig 5 | from code_transformer.modeling.decoder.transformer import TransformerLMDecoder 6 | from code_transformer.modeling.code_transformer.lm import TransformerLMEncoder 7 | from code_transformer.preprocessing.datamanager.base import CTBatch 8 | 9 | 10 | class CodeTransformerDecoder(TransformerLMDecoder): 11 | 12 | def __init__(self, config: TransformerLMDecoderConfig): 13 | if not isinstance(config.lm_encoder, nn.Module): 14 | config.lm_encoder = TransformerLMEncoder( 15 | TransformerLMEncoderConfig(**config.lm_encoder)) 16 | 17 | super(CodeTransformerDecoder, self).__init__(config) 18 | 19 | def forward_batch(self, batch: CTBatch, need_weights=False): 20 | return self.forward(input_tokens=batch.tokens, input_node_types=batch.node_types, 21 | 
input_token_types=batch.token_types, 22 | relative_distances=batch.relative_distances, attention_mask=batch.perm_mask, 23 | pad_mask=1 - batch.pad_mask, target_mapping=batch.target_mapping, 24 | labels=batch.labels, 25 | need_weights=need_weights, 26 | extended_vocabulary_ids=batch.extended_vocabulary_ids, 27 | pointer_pad_mask=batch.pointer_pad_mask, 28 | languages=batch.languages) 29 | -------------------------------------------------------------------------------- /code_transformer/modeling/decoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/modeling/decoder/__init__.py -------------------------------------------------------------------------------- /code_transformer/modeling/great_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/modeling/great_transformer/__init__.py -------------------------------------------------------------------------------- /code_transformer/modeling/great_transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from code_transformer.configuration.great_transformer import GreatEncoderConfig 4 | from code_transformer.configuration.transformer_lm_decoder import TransformerLMDecoderConfig 5 | from code_transformer.modeling.decoder.transformer import TransformerLMDecoder 6 | from code_transformer.modeling.great_transformer.great_transformer import GreatEncoder 7 | from code_transformer.modeling.code_transformer.code_transformer import TransformerOutput 8 | from code_transformer.preprocessing.dataset.code_summarization import GreatBatch 9 | from torch import nn 10 | 11 | 12 | class GreatEncoderTransformerAdapter(GreatEncoder): 13 | def forward(self, **model_input): 14 | output = super(GreatEncoderTransformerAdapter, self).forward(**model_input) 15 | return TransformerOutput(output[:, 0, :].unsqueeze(1), None, 16 | [(output.transpose(0, 1), torch.zeros((1, output.shape[0], output.shape[2]), device=output.device))]) 17 | 18 | 19 | class GreatTransformerDecoder(TransformerLMDecoder): 20 | 21 | def __init__(self, config: TransformerLMDecoderConfig): 22 | if not isinstance(config.lm_encoder, nn.Module): 23 | config.lm_encoder = GreatEncoder(GreatEncoderConfig(**config.lm_encoder)) 24 | 25 | config.lm_encoder.d_model = config.lm_encoder.transformer.hidden_dim 26 | 27 | super(GreatTransformerDecoder, self).__init__(config) 28 | 29 | def forward_batch(self, batch: GreatBatch): 30 | return self.forward(input_tokens=batch.tokens, 31 | edge_ixs=batch.edge_ixs, 32 | attention_mask=batch.attention_mask, 33 | pad_mask=1 - batch.pad_mask, 34 | labels=batch.labels, 35 | pointer_pad_mask=batch.pointer_pad_mask, 36 | extended_vocabulary_ids=batch.extended_vocabulary_ids, 37 | languages=batch.languages) 38 | -------------------------------------------------------------------------------- /code_transformer/modeling/modelmanager/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ModelManager 2 | from .code_transformer import * 3 | from .xl_net import XLNetModelManager 4 | from .great import GreatModelManager 5 | --------------------------------------------------------------------------------
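One detail shared by the `forward_batch` wrappers in these decoder modules (and in the XLNet decoder below): they all pass `pad_mask=1 - batch.pad_mask`, which suggests the batches store padding masks with 1 marking real tokens while the decoders' `forward` expects the opposite convention. A minimal sketch of that inversion (the mask values are invented for illustration):

```python
import torch

# Assumed batch-side convention: 1 = real token, 0 = padding
batch_pad_mask = torch.tensor([[1, 1, 1, 0, 0]])

# Decoder-side convention after the inversion: 1 = padded position
decoder_pad_mask = 1 - batch_pad_mask
print(decoder_pad_mask)  # tensor([[0, 0, 0, 1, 1]])
```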
/code_transformer/modeling/xl_net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/modeling/xl_net/__init__.py -------------------------------------------------------------------------------- /code_transformer/modeling/xl_net/decoder.py: -------------------------------------------------------------------------------- 1 | from code_transformer.configuration.transformer_lm_decoder import TransformerLMDecoderConfig 2 | from code_transformer.configuration.transformer_lm_encoder import TransformerLMEncoderConfig 3 | from code_transformer.modeling.decoder.transformer import TransformerLMDecoder 4 | from torch import nn 5 | 6 | from code_transformer.modeling.xl_net.xl_net_language_model import XLNetLMEncoder 7 | from code_transformer.preprocessing.datamanager.base import CTBatch 8 | 9 | 10 | class XLNetTransformerDecoder(TransformerLMDecoder): 11 | def __init__(self, config: TransformerLMDecoderConfig): 12 | if not isinstance(config.lm_encoder, nn.Module): 13 | config.lm_encoder = XLNetLMEncoder(TransformerLMEncoderConfig(**config.lm_encoder)) 14 | 15 | super(XLNetTransformerDecoder, self).__init__(config) 16 | 17 | def forward_batch(self, batch: CTBatch): 18 | return self.forward(input_ids=batch.tokens, 19 | token_type_ids=batch.token_types, 20 | pad_mask=1 - batch.pad_mask, 21 | attention_mask=batch.perm_mask, 22 | target_mapping=batch.target_mapping, 23 | labels=batch.labels, 24 | extended_vocabulary_ids=batch.extended_vocabulary_ids, 25 | pointer_pad_mask=batch.pointer_pad_mask, 26 | languages=batch.languages) 27 | -------------------------------------------------------------------------------- /code_transformer/preprocessing/README.md: -------------------------------------------------------------------------------- 1 | Setup 2 | ===== 3 | 4 | Semantic executable 5 | ------------------- 6 | 7 | We generate ASTs from code using the [github/semantic](https://github.com/github/semantic) tool. 8 | You can either use one of the provided docker images or build your own executable by running `cabal v2-build --enable-executable-static` in the repository. 9 | The `--enable-executable-static` flag tells cabal to statically link all necessary libraries into the executable, which makes it easier to run semantic on other machines.
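For orientation, here is a hedged sketch of driving such a statically linked semantic binary from Python, in the spirit of how this preprocessing package shells out to external parsers (see javaparser.py further below). The binary path is hypothetical and the `parse --sexpression` invocation is an assumption about semantic's CLI at the time of writing; adjust it to your build:

```python
import subprocess

SEMANTIC_EXECUTABLE = "/path/to/semantic"  # hypothetical location of the static build

def parse_with_semantic(source_file: str) -> str:
    # Assumed CLI: `semantic parse --sexpression <file>` prints the AST as an
    # s-expression to stdout; check `semantic --help` for your version.
    result = subprocess.run(
        [SEMANTIC_EXECUTABLE, "parse", "--sexpression", source_file],
        capture_output=True, text=True, check=True)
    return result.stdout

print(parse_with_semantic("example.py"))
```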
10 | -------------------------------------------------------------------------------- /code_transformer/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/preprocessing/__init__.py -------------------------------------------------------------------------------- /code_transformer/preprocessing/datamanager/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/preprocessing/datamanager/__init__.py -------------------------------------------------------------------------------- /code_transformer/preprocessing/datamanager/c2s/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/preprocessing/datamanager/c2s/__init__.py -------------------------------------------------------------------------------- /code_transformer/preprocessing/datamanager/c2s/raw.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loader class to facilitate loading code2seq .json dataset files that were generated from raw java classes by using 3 | the JavaMethodExtractor.jar 4 | """ 5 | 6 | import glob 7 | import json 8 | import os 9 | import random 10 | 11 | from code_transformer.preprocessing.datamanager.base import DataManager, RawDataLoader 12 | from code_transformer.preprocessing.datamanager.csn.raw import CSNRawSample 13 | 14 | 15 | class C2SRawDataLoader(RawDataLoader): 16 | 17 | def __init__(self, data_location): 18 | self.data_location = data_location 19 | 20 | def get_available_datasets(self): 21 | return [name for name in os.listdir(self.data_location) if os.path.isdir(f"{self.data_location}/{name}")] 22 | 23 | def load_dataset(self, dataset, partition="train"): 24 | if os.path.isdir(f"{self.data_location}/{dataset}/{partition}"): 25 | # raw methods are separated in multiple dataset slices 26 | files = glob.glob(f"{self.data_location}/{dataset}/{partition}/dataset-*.json") 27 | self.samples = [] 28 | for file in files: 29 | with open(file, 'r') as f: 30 | self.samples.extend(json.load(f)) 31 | else: 32 | with open(f"{self.data_location}/{dataset}/{partition}.json", 'r') as f: 33 | self.samples = json.load(f) 34 | 35 | def read(self, batch_size=1, shuffle=False): 36 | if shuffle: 37 | lines = random.sample(self.samples, len(self.samples)) 38 | else: 39 | lines = self.samples 40 | 41 | reader = map(lambda sample: CSNRawSample(sample['name'], sample['doc'] if 'doc' in sample else None, 42 | sample['code']), lines) 43 | 44 | if batch_size > 1: 45 | return DataManager.to_batches(reader, batch_size) 46 | else: 47 | return reader 48 | 49 | def __len__(self): 50 | return len(self.samples) 51 | -------------------------------------------------------------------------------- /code_transformer/preprocessing/datamanager/csn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/preprocessing/datamanager/csn/__init__.py -------------------------------------------------------------------------------- 
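A minimal usage sketch of the `C2SRawDataLoader` defined above; the data directory is hypothetical, and the expected on-disk layout follows `load_dataset()` (either `{dataset}/{partition}.json` or slices `{dataset}/{partition}/dataset-*.json`):

```python
from code_transformer.preprocessing.datamanager.c2s.raw import C2SRawDataLoader

loader = C2SRawDataLoader("/data/code2seq-extracted")  # hypothetical location
print(loader.get_available_datasets())                 # e.g. ['java-small']

loader.load_dataset("java-small", partition="train")
print(len(loader))                                     # number of loaded methods

# read() yields CSNRawSample(func_name, docstring, code_snippet) namedtuples;
# with batch_size > 1 they are grouped by DataManager.to_batches (assumed to
# produce lists of samples).
for batch in loader.read(batch_size=32, shuffle=True):
    print(batch[0].func_name)
    break
```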
/code_transformer/preprocessing/datamanager/csn/raw.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataloader for the raw method snippets from the CodeSearchNet dataset stored as .jsonl.gz files. 3 | """ 4 | 5 | import gzip 6 | import os 7 | from collections import namedtuple 8 | import random 9 | 10 | import jsonlines 11 | 12 | from code_transformer.preprocessing.datamanager.base import DataManager, RawDataLoader 13 | 14 | CSNRawSample = namedtuple("CSNRawSample", ["func_name", "docstring", "code_snippet"]) 15 | 16 | 17 | class CSNRawDataLoader(RawDataLoader): 18 | """ 19 | Loads and unzips the code snippets from the CodeSearchNet dataset. 20 | Returns each sample as a 3-tuple containing the function name, its docstring and the code snippet. 21 | """ 22 | 23 | def __init__(self, data_location): 24 | self.data_location = data_location 25 | self.lines = [] 26 | 27 | def load(self, language, partition, seq): 28 | with gzip.GzipFile( 29 | f"{self.data_location}/{language}/final/jsonl/{partition}/{language}_{partition}_{seq}.jsonl.gz", 30 | 'r') as fin: 31 | json_bytes = fin.read() 32 | json_str = json_bytes.decode('utf-8') 33 | self.lines.extend(json_str.split("\n")) 34 | 35 | def load_all_for(self, language, partition=None): 36 | if partition is None: 37 | partitions = ["train", "valid", "test"] 38 | else: 39 | partitions = [partition] 40 | for part in partitions: 41 | for seq in range(self.get_num_files(language, part)): 42 | self.load(language, part, seq) 43 | 44 | def get_available_languages(self): 45 | return [name for name in os.listdir(self.data_location) if os.path.isdir(f"{self.data_location}/{name}")] 46 | 47 | def get_num_files(self, language, partition): 48 | return len(os.listdir(f"{self.data_location}/{language}/final/jsonl/{partition}")) 49 | 50 | def __len__(self): 51 | return len(self.lines) 52 | 53 | def read(self, batch_size=1, shuffle=False): 54 | if shuffle: 55 | lines = random.sample(self.lines, len(self.lines)) 56 | else: 57 | lines = self.lines 58 | reader = jsonlines.Reader(lines) 59 | reader = map(lambda line: CSNRawSample(line['func_name'], line['docstring'], line['code']), reader) 60 | 61 | if batch_size > 1: 62 | return DataManager.to_batches(reader, batch_size) 63 | else: 64 | return reader 65 | -------------------------------------------------------------------------------- /code_transformer/preprocessing/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/preprocessing/dataset/__init__.py -------------------------------------------------------------------------------- /code_transformer/preprocessing/graph/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/preprocessing/graph/__init__.py -------------------------------------------------------------------------------- /code_transformer/preprocessing/nlp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/preprocessing/nlp/__init__.py -------------------------------------------------------------------------------- /code_transformer/preprocessing/nlp/javaparser.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Uses the JavaParser .jar to obtain an AST from a Java method snippet. 3 | """ 4 | 5 | import json 6 | import subprocess 7 | 8 | from code_transformer.env import JAVA_EXECUTABLE, JAVA_PARSER_EXECUTABLE 9 | from code_transformer.utils.log import get_logger 10 | 11 | JAVA_PARSER_CMD = f"{JAVA_EXECUTABLE} -jar {JAVA_PARSER_EXECUTABLE}" 12 | 13 | logger = get_logger(__file__) 14 | 15 | 16 | def java_to_ast(*code_snippets): 17 | asts = [] 18 | idx_successful = [] 19 | for i, code_snippet in enumerate(code_snippets): 20 | java_parser_call = subprocess.Popen(JAVA_PARSER_CMD, stdin=subprocess.PIPE, stdout=subprocess.PIPE, 21 | stderr=subprocess.PIPE, 22 | text=True, shell=True) 23 | output, errors = java_parser_call.communicate(code_snippet) 24 | java_parser_call.wait() 25 | if not errors == "": 26 | logger.warn(errors) 27 | logger.warn(code_snippet) 28 | else: 29 | output = json.loads(output) 30 | asts.append(output) 31 | idx_successful.append(i) 32 | return asts, idx_successful 33 | -------------------------------------------------------------------------------- /code_transformer/preprocessing/pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/preprocessing/pipeline/__init__.py -------------------------------------------------------------------------------- /code_transformer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/code_transformer/utils/__init__.py -------------------------------------------------------------------------------- /code_transformer/utils/io.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import os 4 | import pickle 5 | from pathlib import Path 6 | 7 | 8 | def save_zipped(obj, file): 9 | file = _file_ending(file, "p.gzip") 10 | create_directories(file) 11 | with gzip.open(file, 'wb') as f: 12 | pickle.dump(obj, f) 13 | 14 | 15 | def load_zipped(file): 16 | file = _file_ending(file, "p.gzip") 17 | with gzip.open(file, 'rb') as f: 18 | return pickle.load(f) 19 | 20 | 21 | def save_pickled(obj, file): 22 | file = _file_ending(file, "p") 23 | create_directories(file) 24 | with open(f"{file}", 'wb') as f: 25 | pickle.dump(obj, f) 26 | 27 | 28 | def load_pickled(file): 29 | file = _file_ending(file, "p") 30 | with open(file, 'rb') as f: 31 | return pickle.load(f) 32 | 33 | 34 | def save_json(obj: dict, file): 35 | file = _file_ending(file, "json") 36 | create_directories(file) 37 | with open(file, 'w') as f: 38 | json.dump(obj, f, indent=4) 39 | 40 | 41 | def load_json(file): 42 | file = _file_ending(file, "json") 43 | with open(file, 'r') as f: 44 | return json.load(f) 45 | 46 | 47 | def _file_ending(file, ending): 48 | return f"{file}.{ending}" if f".{ending}" not in file else file 49 | 50 | 51 | def create_directories(path): 52 | Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True) 53 | -------------------------------------------------------------------------------- /code_transformer/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LabelSmoothingLoss(nn.Module): 6 | """ 7 | 
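Cross-entropy against smoothed targets: the target distribution puts 1 - smoothing on the gold class and spreads smoothing uniformly over the remaining classes (see forward below).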
Adapted from https://github.com/pytorch/pytorch/issues/7455#issuecomment-513062631 8 | """ 9 | def __init__(self, smoothing=0.0, dim=-1): 10 | super(LabelSmoothingLoss, self).__init__() 11 | self.confidence = 1.0 - smoothing 12 | self.smoothing = smoothing 13 | self.dim = dim 14 | 15 | def forward(self, pred, target): 16 | num_classes = pred.shape[1] 17 | pred = pred.log_softmax(dim=self.dim) 18 | with torch.no_grad(): 19 | # true_dist = pred.data.clone() 20 | true_dist = torch.zeros_like(pred) 21 | true_dist.fill_(self.smoothing / (num_classes - 1)) 22 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 23 | return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) 24 | -------------------------------------------------------------------------------- /code_transformer/utils/sacred.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | 4 | import jsonpickle 5 | import yaml 6 | 7 | 8 | def _restore(flat): 9 | """ 10 | Restore more complex data that Python's json can't handle (e.g. NumPy arrays). 11 | Copied from sacred.serializer for performance reasons. 12 | """ 13 | return jsonpickle.decode(json.dumps(flat), keys=True) 14 | 15 | 16 | def _convert_value(value): 17 | """ 18 | Parse a string as a Python literal if possible and fall back to the plain string. 19 | Copied from sacred.arg_parser for performance reasons. 20 | """ 21 | 22 | try: 23 | return _restore(ast.literal_eval(value)) 24 | except (ValueError, SyntaxError): 25 | # use as string if nothing else worked 26 | return value 27 | 28 | 29 | def _convert_values(val): 30 | if isinstance(val, dict): 31 | for key, inner_val in val.items(): 32 | val[key] = _convert_values(inner_val) 33 | elif isinstance(val, list): 34 | for i, inner_val in enumerate(val): 35 | val[i] = _convert_values(inner_val) 36 | elif isinstance(val, str): 37 | return _convert_value(val) 38 | return val 39 | 40 | 41 | def read_config(config_path): 42 | with open(config_path, 'r') as conf: 43 | config_dict = _convert_values(yaml.load(conf, Loader=yaml.FullLoader)) 44 | 45 | return config_dict 46 | 47 | 48 | def parse_command(config): 49 | config_strings = [f'{key}="{val}"' if type(val) != str else f'{key}="\'{val}\'"' for key, val in config.items() if 50 | not key == 'experiment_setup'] 51 | exe = config['experiment_setup']['executable'] 52 | cmd = f"PYTHONPATH=$(pwd) python {exe} with {' '.join(config_strings)}" 53 | return exe, cmd 54 | -------------------------------------------------------------------------------- /code_transformer/utils/timing.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from time import time 3 | 4 | 5 | @contextmanager 6 | def timing(descr: str): 7 | start = time() 8 | yield 9 | print(f"{descr}: {time() - start:0.2f} seconds") 10 | 11 | 12 | class Timing: 13 | """ 14 | Usage: 15 | 16 | with Timing() as t: 17 | ... 18 | print(f"took {t[0]:0.2} seconds") 19 | 20 | Within the with-statement we can take several measurements by calling t.measure(). 21 | After the with-block exits, every measurement is available in the list t.
22 | """ 23 | 24 | class TimeMeasurer: 25 | 26 | def __init__(self, start): 27 | self.start = start 28 | self.times = [] 29 | 30 | def measure(self): 31 | now = time() 32 | measured_time = now - self.start 33 | self.times.append(measured_time) 34 | self.start = now 35 | return measured_time 36 | 37 | def __len__(self): 38 | return len(self.times) 39 | 40 | def __getitem__(self, i): 41 | return self.times[i] 42 | 43 | def __str__(self): 44 | if len(self.times) == 1: 45 | return f"{self.times[0]:.2f}" 46 | return f"{self.times}" 47 | 48 | def __enter__(self): 49 | self.time_measurer = Timing.TimeMeasurer(time()) 50 | return self.time_measurer 51 | 52 | def __exit__(self, exc_type, exc_val, exc_tb): 53 | self.time_measurer.measure() 54 | -------------------------------------------------------------------------------- /code_transformer/utils/vocab.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | from code_transformer.modeling.constants import PAD_TOKEN 6 | from code_transformer.preprocessing.datamanager.preprocessed import CTPreprocessedDataManager 7 | 8 | 9 | def decode_tokens(tokens: torch.Tensor, data_manager: CTPreprocessedDataManager = None, word_vocab=None, 10 | config=None) -> List[List[str]]: 11 | assert data_manager is not None or word_vocab is not None and config is not None, "Either data_manager or word_vocab and config have to be provided" 12 | if word_vocab is None: 13 | word_vocab, _, _ = data_manager.load_vocabularies() 14 | if config is None: 15 | config = data_manager.load_config() 16 | pad_id = config['preprocessing']['special_symbols'][PAD_TOKEN] 17 | 18 | words = [] 19 | for token in tokens: 20 | if isinstance(token, list) or isinstance(token, torch.Tensor): 21 | words.append( 22 | [word_vocab.reverse_lookup(sub_token.item()) for sub_token in token if not sub_token == pad_id]) 23 | elif not token == pad_id: 24 | words.append(word_vocab.reverse_lookup(token)) 25 | 26 | return words 27 | -------------------------------------------------------------------------------- /figures/code_transformer_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/figures/code_transformer_overview.png -------------------------------------------------------------------------------- /figures/preprocessing_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/figures/preprocessing_overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jsonlines==1.2.0 2 | rouge==1.0.0 3 | tensorflow==1.15.0 4 | joblib==0.14.1 5 | scipy==1.4.1 6 | networkx==2.4 7 | Pygments==2.6.1 8 | torch==1.4.0 9 | numpy==1.18.1 10 | jsonpickle==1.3 11 | pandas==1.0.5 12 | tqdm==4.43.0 13 | transformers==3.1.0 14 | six==1.14.0 15 | pytest==6.2.2 16 | PyYAML==5.4.1 17 | Requests==2.23.0 18 | scikit_learn==0.24.1 19 | environs==9.3.1 20 | sacred==0.8.1 -------------------------------------------------------------------------------- /scripts/extract-java-methods.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python wrapper for the JavaMethodExtractor to extract 
code snippets containing Java methods from the code2seq datasets. 3 | After this script is run, the extracted methods can be further preprocessed with our stage1 and stage2 pipeline to eventually be 4 | fed into a CodeTransformer model. 5 | """ 6 | 7 | import argparse 8 | import subprocess 9 | 10 | from code_transformer.env import CODE2SEQ_RAW_DATA_PATH, CODE2SEQ_EXTRACTED_METHODS_DATA_PATH, JAVA_METHOD_EXTRACTOR_EXECUTABLE, \ 11 | JAVA_EXECUTABLE 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("language") 15 | args = parser.parse_args() 16 | 17 | cmd_train = f"{JAVA_EXECUTABLE} -jar {JAVA_METHOD_EXTRACTOR_EXECUTABLE} " \ 18 | f"--dir {CODE2SEQ_RAW_DATA_PATH}/{args.language}/training " \ 19 | f"--output_dir {CODE2SEQ_EXTRACTED_METHODS_DATA_PATH}/{args.language}/train" 20 | 21 | subprocess.check_call(cmd_train, shell=True) 22 | 23 | cmd_valid = f"{JAVA_EXECUTABLE} -jar {JAVA_METHOD_EXTRACTOR_EXECUTABLE} " \ 24 | f"--dir {CODE2SEQ_RAW_DATA_PATH}/{args.language}/validation " \ 25 | f"--output_dir {CODE2SEQ_EXTRACTED_METHODS_DATA_PATH}/{args.language}/valid" 26 | 27 | subprocess.check_call(cmd_valid, shell=True) 28 | 29 | cmd_test = f"{JAVA_EXECUTABLE} -jar {JAVA_METHOD_EXTRACTOR_EXECUTABLE} " \ 30 | f"--dir {CODE2SEQ_RAW_DATA_PATH}/{args.language}/test " \ 31 | f"--output_dir {CODE2SEQ_EXTRACTED_METHODS_DATA_PATH}/{args.language}/test" 32 | 33 | subprocess.check_call(cmd_test, shell=True) 34 | -------------------------------------------------------------------------------- /scripts/run-experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Starts a training run. 3 | Usage: python scripts/run-experiment.py {config_file} 4 | The {config_file} is a .yaml file that contains experiment-specific configuration. 5 | See code_transformer/experiments/code_transformer/code_summarization.yaml for an example. 6 | """ 7 | 8 | import argparse 9 | import subprocess 10 | 11 | from code_transformer.utils.sacred import read_config, parse_command 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("config_file") 16 | args = parser.parse_args() 17 | 18 | config = read_config(args.config_file) 19 | exe, cmd = parse_command(config) 20 | 21 | subprocess.check_call(cmd, shell=True) 22 | -------------------------------------------------------------------------------- /scripts/run-preprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Executes stage 1 and stage 2 data preprocessing of code snippets. 3 | Usage: python scripts/run-preprocessing.py {config_file} {language} {train|valid|test} 4 | The {config_file} is a .yaml file that contains preprocessing-specific configuration. 5 | See code_transformer/experiments/preprocessing/preprocess-1-csn.yaml for an example.
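Example (hypothetical invocation): python scripts/run-preprocessing.py code_transformer/experiments/preprocessing/preprocess-1-csn.yaml python train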
6 | """ 7 | 8 | import argparse 9 | import subprocess 10 | 11 | from code_transformer.utils.sacred import read_config, parse_command 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("config_file") 16 | parser.add_argument("language") 17 | parser.add_argument("partition", choices=["train", "valid", "test"]) 18 | args = parser.parse_args() 19 | 20 | config = read_config(args.config_file) 21 | config['data'] = dict(language=args.language, 22 | partition=args.partition) 23 | exe, cmd = parse_command(config) 24 | 25 | subprocess.check_call(cmd, shell=True) 26 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='code_transformer', 4 | version='0.1', 5 | description='Code Transformer', 6 | author='Daniel Zügner, Tobias Kirschstein, Michele Catasta, Jure Leskovec, Stephan Günnemann', 7 | author_email='zuegnerd@in.tum.de,kirschto@in.tum.de', 8 | packages=find_packages(), 9 | install_requires=['jsonlines==1.2.0', 'rouge==1.0.0', 'joblib==0.14.1', 10 | 'scipy==1.4.1', 'networkx==2.4', 'Pygments==2.6.1', 'torch==1.4.0', 11 | 'numpy==1.18.1', 'jsonpickle==1.3', 'pandas==1.0.5', 'tqdm==4.43.0', 12 | 'transformers==3.1.0', 'six==1.14.0', 'pytest==6.2.2', 'PyYAML==5.4.1', 13 | 'Requests==2.23.0', 'scikit_learn==0.24.1', 'environs==9.3.1', 'sacred==0.8.1'], 14 | zip_safe=False) 15 | -------------------------------------------------------------------------------- /sub_modules/code2seq/.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.lst 3 | .idea/* 4 | *.iml 5 | *.xml 6 | *.pyc 7 | 8 | -------------------------------------------------------------------------------- /sub_modules/code2seq/CSharpExtractor/CSharpExtractor/.nuget/packages.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /sub_modules/code2seq/CSharpExtractor/CSharpExtractor/Extractor/Extractor.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Exe 5 | netcoreapp2.2 6 | Extractor.Program 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /sub_modules/code2seq/CSharpExtractor/CSharpExtractor/Extractor/Program.cs: -------------------------------------------------------------------------------- 1 | using CommandLine; 2 | using CommandLine.Text; 3 | using System; 4 | using System.Collections.Generic; 5 | using System.IO; 6 | using System.Linq; 7 | 8 | namespace Extractor 9 | { 10 | class Program 11 | { 12 | static List ExtractSingleFile(string filename, Options opts) 13 | { 14 | string data = File.ReadAllText(filename); 15 | var extractor = new Extractor(data, opts); 16 | List result = extractor.Extract(); 17 | 18 | return result; 19 | } 20 | 21 | static void Main(string[] args) 22 | { 23 | Options options = new Options(); 24 | Parser.Default.ParseArguments(args) 25 | .WithParsed(opt => options = opt) 26 | .WithNotParsed(errors => 27 | { 28 | Console.WriteLine(errors); 29 | return; 30 | }); 31 | 32 | string path = options.Path; 33 | string[] files; 34 | if (Directory.Exists(path)) 35 | { 36 | files = Directory.GetFiles(path, 
"*.cs", SearchOption.AllDirectories); 37 | } 38 | else 39 | { 40 | files = new string[] { path }; 41 | } 42 | 43 | IEnumerable results = null; 44 | 45 | results = files.AsParallel().WithDegreeOfParallelism(options.Threads).SelectMany(filename => ExtractSingleFile(filename, options)); 46 | 47 | using (StreamWriter sw = new StreamWriter(options.OFileName, append: true)) 48 | { 49 | foreach (var res in results) 50 | { 51 | sw.WriteLine(res); 52 | } 53 | } 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /sub_modules/code2seq/CSharpExtractor/CSharpExtractor/Extractor/Properties/launchSettings.json: -------------------------------------------------------------------------------- 1 | { 2 | "profiles": { 3 | "Extractor": { 4 | "commandName": "Project", 5 | "commandLineArgs": "--path C:\\Users\\urial\\Source\\Repos\\CSharpExtractor\\CSharpExtractor\\Extractor\\bin\\ --no_hash" 6 | } 7 | } 8 | } -------------------------------------------------------------------------------- /sub_modules/code2seq/CSharpExtractor/CSharpExtractor/Extractor/Temp.cs: -------------------------------------------------------------------------------- 1 | namespace Extractor 2 | { 3 | class Temp 4 | { 5 | class NestedClass 6 | { 7 | void fooBar() 8 | { 9 | a.b = c; 10 | } 11 | } 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /sub_modules/code2seq/Input.java: -------------------------------------------------------------------------------- 1 | public String getName() { 2 | return name; 3 | } -------------------------------------------------------------------------------- /sub_modules/code2seq/JavaExtractor/JPredict/.classpath: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /sub_modules/code2seq/JavaExtractor/JPredict/.gitignore: -------------------------------------------------------------------------------- 1 | /target/ -------------------------------------------------------------------------------- /sub_modules/code2seq/JavaExtractor/JPredict/src/main/java/JavaExtractor/App.java: -------------------------------------------------------------------------------- 1 | package JavaExtractor; 2 | 3 | import JavaExtractor.Common.CommandLineValues; 4 | import org.kohsuke.args4j.CmdLineException; 5 | 6 | import java.io.IOException; 7 | import java.nio.file.Files; 8 | import java.nio.file.Paths; 9 | import java.util.LinkedList; 10 | import java.util.List; 11 | import java.util.concurrent.ExecutionException; 12 | import java.util.concurrent.Executors; 13 | import java.util.concurrent.Future; 14 | import java.util.concurrent.ThreadPoolExecutor; 15 | 16 | public class App { 17 | private static CommandLineValues s_CommandLineValues; 18 | 19 | public static void main(String[] args) { 20 | try { 21 | s_CommandLineValues = new CommandLineValues(args); 22 | } catch (CmdLineException e) { 23 | e.printStackTrace(); 24 | return; 25 | } 26 | 27 | if (s_CommandLineValues.File != null) { 28 | ExtractFeaturesTask extractFeaturesTask = new ExtractFeaturesTask(s_CommandLineValues, 29 | s_CommandLineValues.File.toPath()); 30 | extractFeaturesTask.processFile(); 31 | } else if (s_CommandLineValues.Dir != null) { 32 | extractDir(); 33 | } 34 | } 35 | 36 | private static void extractDir() { 37 | 
ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(s_CommandLineValues.NumThreads); 38 | LinkedList<ExtractFeaturesTask> tasks = new LinkedList<>(); 39 | try { 40 | Files.walk(Paths.get(s_CommandLineValues.Dir)).filter(Files::isRegularFile) 41 | .filter(p -> p.toString().toLowerCase().endsWith(".java")).forEach(f -> { 42 | ExtractFeaturesTask task = new ExtractFeaturesTask(s_CommandLineValues, f); 43 | tasks.add(task); 44 | }); 45 | } catch (IOException e) { 46 | e.printStackTrace(); 47 | return; 48 | } 49 | List<Future<Void>> tasksResults = null; 50 | try { 51 | tasksResults = executor.invokeAll(tasks); 52 | } catch (InterruptedException e) { 53 | e.printStackTrace(); 54 | } finally { 55 | executor.shutdown(); 56 | } 57 | tasksResults.forEach(f -> { 58 | try { 59 | f.get(); 60 | } catch (InterruptedException | ExecutionException e) { 61 | e.printStackTrace(); 62 | } 63 | }); 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /sub_modules/code2seq/JavaExtractor/JPredict/src/main/java/JavaExtractor/Common/CommandLineValues.java: -------------------------------------------------------------------------------- 1 | package JavaExtractor.Common; 2 | 3 | import org.kohsuke.args4j.CmdLineException; 4 | import org.kohsuke.args4j.CmdLineParser; 5 | import org.kohsuke.args4j.Option; 6 | 7 | import java.io.File; 8 | 9 | /** 10 | * This class handles the program's arguments. 11 | */ 12 | public class CommandLineValues { 13 | @Option(name = "--file", required = false) 14 | public File File = null; 15 | 16 | @Option(name = "--dir", required = false, forbids = "--file") 17 | public String Dir = null; 18 | 19 | @Option(name = "--max_path_length", required = true) 20 | public int MaxPathLength; 21 | 22 | @Option(name = "--max_path_width", required = true) 23 | public int MaxPathWidth; 24 | 25 | @Option(name = "--num_threads", required = false) 26 | public int NumThreads = 64; 27 | 28 | @Option(name = "--min_code_len", required = false) 29 | public int MinCodeLength = 1; 30 | 31 | @Option(name = "--max_code_len", required = false) 32 | public int MaxCodeLength = -1; 33 | 34 | @Option(name = "--max_file_len", required = false) 35 | public int MaxFileLength = -1; 36 | 37 | @Option(name = "--pretty_print", required = false) 38 | public boolean PrettyPrint = false; 39 | 40 | @Option(name = "--max_child_id", required = false) 41 | public int MaxChildId = 3; 42 | 43 | @Option(name = "--json_output", required = false) 44 | public boolean JsonOutput = false; 45 | 46 | public CommandLineValues(String...
args) throws CmdLineException { 47 | CmdLineParser parser = new CmdLineParser(this); 48 | try { 49 | parser.parseArgument(args); 50 | } catch (CmdLineException e) { 51 | System.err.println(e.getMessage()); 52 | parser.printUsage(System.err); 53 | throw e; 54 | } 55 | } 56 | 57 | public CommandLineValues() { 58 | 59 | } 60 | } -------------------------------------------------------------------------------- /sub_modules/code2seq/JavaExtractor/JPredict/src/main/java/JavaExtractor/Common/MethodContent.java: -------------------------------------------------------------------------------- 1 | package JavaExtractor.Common; 2 | 3 | import com.github.javaparser.ast.Node; 4 | 5 | import java.util.ArrayList; 6 | 7 | public class MethodContent { 8 | private final ArrayList<Node> leaves; 9 | private final String name; 10 | 11 | private final String content; 12 | 13 | public MethodContent(ArrayList<Node> leaves, String name, String content) { 14 | this.leaves = leaves; 15 | this.name = name; 16 | this.content = content; 17 | } 18 | 19 | public ArrayList<Node> getLeaves() { 20 | return leaves; 21 | } 22 | 23 | public String getName() { 24 | return name; 25 | } 26 | 27 | public String getContent() { 28 | return content; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /sub_modules/code2seq/JavaExtractor/JPredict/src/main/java/JavaExtractor/FeaturesEntities/ProgramFeatures.java: -------------------------------------------------------------------------------- 1 | package JavaExtractor.FeaturesEntities; 2 | 3 | import java.nio.file.Path; 4 | import java.util.ArrayList; 5 | import java.util.stream.Collectors; 6 | 7 | public class ProgramFeatures { 8 | String name; 9 | 10 | transient ArrayList<ProgramRelation> features = new ArrayList<>(); 11 | String textContent; 12 | 13 | String filePath; 14 | 15 | public ProgramFeatures(String name, Path filePath, String textContent) { 16 | 17 | this.name = name; 18 | this.filePath = filePath.toAbsolutePath().toString(); 19 | this.textContent = textContent; 20 | } 21 | 22 | @SuppressWarnings("StringBufferReplaceableByString") 23 | @Override 24 | public String toString() { 25 | StringBuilder stringBuilder = new StringBuilder(); 26 | stringBuilder.append(name).append(" "); 27 | stringBuilder.append(features.stream().map(ProgramRelation::toString).collect(Collectors.joining(" "))); 28 | 29 | return stringBuilder.toString(); 30 | } 31 | 32 | public void addFeature(Property source, String path, Property target) { 33 | ProgramRelation newRelation = new ProgramRelation(source, target, path); 34 | features.add(newRelation); 35 | } 36 | 37 | public boolean isEmpty() { 38 | return features.isEmpty(); 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /sub_modules/code2seq/JavaExtractor/JPredict/src/main/java/JavaExtractor/FeaturesEntities/ProgramRelation.java: -------------------------------------------------------------------------------- 1 | package JavaExtractor.FeaturesEntities; 2 | 3 | public class ProgramRelation { 4 | Property source; 5 | Property target; 6 | String path; 7 | 8 | public ProgramRelation(Property sourceName, Property targetName, String path) { 9 | source = sourceName; 10 | target = targetName; 11 | this.path = path; 12 | } 13 | 14 | public String toString() { 15 | return String.format("%s,%s,%s", source.getName(), path, 16 | target.getName()); 17 | } 18 | } 19 | --------------------------------------------------------------------------------
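Read together, `ProgramFeatures.toString` and `ProgramRelation.toString` above emit one line per extracted method: the target name followed by space-separated `source,path,target` context triples. A tiny Python illustration of that line format (the name and path strings are made-up sample values, not real extractor output):

```python
# Mirrors ProgramFeatures.toString()/ProgramRelation.toString() above.
name = "get|name"
contexts = [("get", "Nm0|Mth|Rt", "name"), ("name", "Rt|Mth|Nm1", "get")]

line = name + " " + " ".join(",".join(c) for c in contexts)
print(line)  # get|name get,Nm0|Mth|Rt,name name,Rt|Mth|Nm1,get
```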
/sub_modules/code2seq/JavaExtractor/JPredict/src/main/java/Test.java: -------------------------------------------------------------------------------- 1 | class Test { 2 | void fooBar() { 3 | System.out.println("http://github.com"); 4 | } 5 | } -------------------------------------------------------------------------------- /sub_modules/code2seq/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Technion 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /sub_modules/code2seq/Python150kExtractor/README.md: -------------------------------------------------------------------------------- 1 | # Python150k dataset 2 | 3 | ## Steps to reproduce 4 | 5 | 1. Download the parsed Python dataset from [here](https://www.sri.inf.ethz.ch/py150), 6 | unarchive and place under `PYTHON150K_DIR`: 7 | 8 | ```bash 9 | # Replace with desired path. 10 | >>> PYTHON150K_DIR=/path/to/data/dir 11 | >>> mkdir -p $PYTHON150K_DIR 12 | >>> cd $PYTHON150K_DIR 13 | >>> wget http://files.srl.inf.ethz.ch/data/py150.tar.gz 14 | ... 15 | >>> tar -xzvf py150.tar.gz 16 | ... 17 | ``` 18 | 19 | 2. Extract samples to `DATA_DIR`: 20 | 21 | ```bash 22 | # Replace with desired path. 23 | >>> DATA_DIR=$(pwd)/data/default 24 | >>> SEED=239 25 | >>> python extract.py \ 26 | --data_dir=$PYTHON150K_DIR \ 27 | --output_dir=$DATA_DIR \ 28 | --seed=$SEED 29 | ... 30 | ``` 31 | 32 | 3. Preprocess for training: 33 | 34 | ```bash 35 | >>> ./preprocess.sh $DATA_DIR 36 | ... 37 | ``` 38 | 39 | 4. Train: 40 | 41 | ```bash 42 | >>> cd .. 43 | >>> DESC=default 44 | >>> CUDA=0 45 | >>> ./train_python150k.sh $DATA_DIR $DESC $CUDA $SEED 46 | ...
47 | ``` 48 | 49 | ## Test results (seed=239) 50 | 51 | ### Best scores 52 | 53 | **setup#2**: `batch_size=64` 54 | **setup#3**: `embedding_size=256,use_momentum=False` 55 | **setup#4**: `batch_size=32,embedding_size=256,embeddings_dropout_keep_prob=0.5,use_momentum=False` 56 | 57 | | params | Precision | Recall | F1 | ROUGE-2 | ROUGE-L | 58 | |---|---|---|---|---|---| 59 | | default | 0.37 | 0.27 | 0.31 | 0.06 | 0.38 | 60 | | setup#2 | 0.40 | 0.31 | 0.34 | 0.08 | 0.41 | 61 | | setup#3 | 0.36 | 0.31 | 0.33 | 0.09 | 0.38 | 62 | | setup#4 | 0.33 | 0.25 | 0.28 | 0.05 | 0.34 | 63 | 64 | ### Ablation studies 65 | 66 | | params | Precision | Recall | F1 | ROUGE-2 | ROUGE-L | 67 | |---|---|---|---|---|---| 68 | | default | 0.37 | 0.27 | 0.31 | 0.06 | 0.38 | 69 | | no ast nodes (5th epoch) | 0.27 | 0.16 | 0.20 | 0.02 | 0.28 | 70 | | no token split (4th epoch) | 0.60 | 0.09 | 0.15 | 0.00 | 0.60 | -------------------------------------------------------------------------------- /sub_modules/code2seq/Python150kExtractor/preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MAX_CONTEXTS=200 4 | MAX_DATA_CONTEXTS=1000 5 | SUBTOKEN_VOCAB_SIZE=186277 6 | TARGET_VOCAB_SIZE=26347 7 | 8 | data_dir=${1:-data} 9 | mkdir -p "${data_dir}" 10 | train_data_file=$data_dir/train_output_file.txt 11 | valid_data_file=$data_dir/valid_output_file.txt 12 | test_data_file=$data_dir/test_output_file.txt 13 | 14 | echo "Creating histograms from the training data..." 15 | target_histogram_file=$data_dir/histo.tgt.c2s 16 | source_subtoken_histogram=$data_dir/histo.ori.c2s 17 | node_histogram_file=$data_dir/histo.node.c2s 18 | cut <"${train_data_file}" -d' ' -f1 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' >"${target_histogram_file}" 19 | cut <"${train_data_file}" -d' ' -f2- | tr ' ' '\n' | cut -d',' -f1,3 | tr ',|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' >"${source_subtoken_histogram}" 20 | cut <"${train_data_file}" -d' ' -f2- | tr ' ' '\n' | cut -d',' -f2 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' >"${node_histogram_file}" 21 | 22 | echo "Preprocessing..." 
23 | python ../preprocess.py \ 24 | --train_data "${train_data_file}" \ 25 | --val_data "${valid_data_file}" \ 26 | --test_data "${test_data_file}" \ 27 | --max_contexts ${MAX_CONTEXTS} \ 28 | --max_data_contexts ${MAX_DATA_CONTEXTS} \ 29 | --subtoken_vocab_size ${SUBTOKEN_VOCAB_SIZE} \ 30 | --target_vocab_size ${TARGET_VOCAB_SIZE} \ 31 | --target_histogram "${target_histogram_file}" \ 32 | --subtoken_histogram "${source_subtoken_histogram}" \ 33 | --node_histogram "${node_histogram_file}" \ 34 | --output_name "${data_dir}"/"$(basename "${data_dir}")" 35 | rm \ 36 | "${target_histogram_file}" \ 37 | "${source_subtoken_histogram}" \ 38 | "${node_histogram_file}" 39 | -------------------------------------------------------------------------------- /sub_modules/code2seq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/sub_modules/code2seq/__init__.py -------------------------------------------------------------------------------- /sub_modules/code2seq/baseline_tokenization/input_example.txt: -------------------------------------------------------------------------------- 1 | requires landscape|boolean (){ return false; } 2 | get parent key|Object (){ return new ContactsUiKey(); } 3 | get parent key|Object (){ return new ContactsUiKey(); } 4 | get layout id|int (){ return R.layout.loose_screen; } 5 | get parent key|Object (){ return new EditContactKey(contactId); } 6 | to contact|Contact (){ return new Contact(id, name, email); } 7 | to string|String (){ return "Welcome!\nClick to continue."; } 8 | get parent key|Object (){ return new EditContactKey(contactId); } 9 | tear down services|void (@NonNull Services services){ } 10 | get layout id|int (){ return R.layout.landscape_screen; } 11 | -------------------------------------------------------------------------------- /sub_modules/code2seq/baseline_tokenization/javalang/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from . import parser 3 | from . import parse 4 | from . import tokenizer 5 | from . 
import javadoc 6 | 7 | 8 | __version__ = "0.10.1" 9 | -------------------------------------------------------------------------------- /sub_modules/code2seq/baseline_tokenization/javalang/ast.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import six 4 | 5 | 6 | class MetaNode(type): 7 | def __new__(mcs, name, bases, dict): 8 | attrs = list(dict['attrs']) 9 | dict['attrs'] = list() 10 | 11 | for base in bases: 12 | if hasattr(base, 'attrs'): 13 | dict['attrs'].extend(base.attrs) 14 | 15 | dict['attrs'].extend(attrs) 16 | 17 | return type.__new__(mcs, name, bases, dict) 18 | 19 | 20 | @six.add_metaclass(MetaNode) 21 | class Node(object): 22 | attrs = () 23 | 24 | def __init__(self, **kwargs): 25 | values = kwargs.copy() 26 | 27 | for attr_name in self.attrs: 28 | value = values.pop(attr_name, None) 29 | setattr(self, attr_name, value) 30 | 31 | if values: 32 | raise ValueError('Extraneous arguments') 33 | 34 | def __equals__(self, other): 35 | if type(other) is not type(self): 36 | return False 37 | 38 | for attr in self.attrs: 39 | if getattr(other, attr) != getattr(self, attr): 40 | return False 41 | 42 | return True 43 | 44 | def __repr__(self): 45 | return type(self).__name__ 46 | 47 | def __iter__(self): 48 | return walk_tree(self) 49 | 50 | def filter(self, pattern): 51 | for path, node in self: 52 | if ((isinstance(pattern, type) and isinstance(node, pattern)) or 53 | (node == pattern)): 54 | yield path, node 55 | 56 | @property 57 | def children(self): 58 | return [getattr(self, attr_name) for attr_name in self.attrs] 59 | 60 | def walk_tree(root): 61 | children = None 62 | 63 | if isinstance(root, Node): 64 | yield (), root 65 | children = root.children 66 | else: 67 | children = root 68 | 69 | for child in children: 70 | if isinstance(child, (Node, list, tuple)): 71 | for path, node in walk_tree(child): 72 | yield (root,) + path, node 73 | 74 | def dump(ast, file): 75 | pickle.dump(ast, file) 76 | 77 | def load(file): 78 | return pickle.load(file) 79 | -------------------------------------------------------------------------------- /sub_modules/code2seq/baseline_tokenization/javalang/parse.py: -------------------------------------------------------------------------------- 1 | 2 | from .parser import Parser 3 | from .tokenizer import tokenize 4 | 5 | def parse_expression(exp): 6 | if not exp.endswith(';'): 7 | exp = exp + ';' 8 | 9 | tokens = tokenize(exp) 10 | parser = Parser(tokens) 11 | 12 | return parser.parse_expression() 13 | 14 | def parse_member_signature(sig): 15 | if not sig.endswith(';'): 16 | sig = sig + ';' 17 | 18 | tokens = tokenize(sig) 19 | parser = Parser(tokens) 20 | 21 | return parser.parse_member_declaration() 22 | 23 | def parse_constructor_signature(sig): 24 | # Add an empty body to the signature, replacing a ; if necessary 25 | if sig.endswith(';'): 26 | sig = sig[:-1] 27 | sig = sig + '{ }' 28 | 29 | tokens = tokenize(sig) 30 | parser = Parser(tokens) 31 | 32 | return parser.parse_member_declaration() 33 | 34 | def parse_type(s): 35 | tokens = tokenize(s) 36 | parser = Parser(tokens) 37 | 38 | return parser.parse_type() 39 | 40 | def parse_type_signature(sig): 41 | if sig.endswith(';'): 42 | sig = sig[:-1] 43 | sig = sig + '{ }' 44 | 45 | tokens = tokenize(sig) 46 | parser = Parser(tokens) 47 | 48 | return parser.parse_class_or_interface_declaration() 49 | 50 | def parse(s): 51 | tokens = tokenize(s) 52 | parser = Parser(tokens) 53 | return parser.parse() 54 | 
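The vendored `javalang` package above is what `subtokenize_nmt_baseline.py` (further below) drives through `javalang.tokenizer`. A minimal usage sketch; note that the `tree` module and its `MethodDeclaration` class belong to the same vendored package but are not shown in this excerpt:

```python
from javalang import parse, tokenizer, tree

source = 'class Test { void fooBar() { System.out.println("hi"); } }'

# Parse a full compilation unit into an AST of ast.Node subclasses.
ast = parse.parse(source)

# Node.filter walks the tree (via walk_tree) and yields (path, node) pairs
# whose node matches the given type or equals the given value.
for path, node in ast.filter(tree.MethodDeclaration):
    print(node.name)  # -> fooBar

# Single members can be parsed without a surrounding class; a trailing ';'
# is appended automatically if missing.
method = parse.parse_member_signature('int add(int a, int b)')

# The raw token stream used by the baseline tokenization script.
tokens = [tok.value for tok in tokenizer.tokenize('return a + b;')]
```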
-------------------------------------------------------------------------------- /sub_modules/code2seq/baseline_tokenization/javalang/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/sub_modules/code2seq/baseline_tokenization/javalang/test/__init__.py -------------------------------------------------------------------------------- /sub_modules/code2seq/baseline_tokenization/javalang/test/source/package-info/AnnotationJavadoc.java: -------------------------------------------------------------------------------- 1 | @Package 2 | /** 3 | Test that includes java doc first but no annotation 4 | */ 5 | package org.javalang.test; -------------------------------------------------------------------------------- /sub_modules/code2seq/baseline_tokenization/javalang/test/source/package-info/AnnotationOnly.java: -------------------------------------------------------------------------------- 1 | @Package 2 | package org.javalang.test; -------------------------------------------------------------------------------- /sub_modules/code2seq/baseline_tokenization/javalang/test/source/package-info/JavadocAnnotation.java: -------------------------------------------------------------------------------- 1 | /** 2 | Test that includes java doc first but no annotation 3 | */ 4 | @Package 5 | package org.javalang.test; -------------------------------------------------------------------------------- /sub_modules/code2seq/baseline_tokenization/javalang/test/source/package-info/JavadocOnly.java: -------------------------------------------------------------------------------- 1 | /** 2 | Test that includes java doc first but no annotation 3 | */ 4 | package org.javalang.test; -------------------------------------------------------------------------------- /sub_modules/code2seq/baseline_tokenization/javalang/test/source/package-info/NoAnnotationNoJavadoc.java: -------------------------------------------------------------------------------- 1 | package org.javalang.test; -------------------------------------------------------------------------------- /sub_modules/code2seq/baseline_tokenization/javalang/test/test_javadoc.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from .. import javadoc 4 | 5 | 6 | class TestJavadoc(unittest.TestCase): 7 | def test_empty_comment(self): 8 | javadoc.parse('/** */') 9 | javadoc.parse('/***/') 10 | javadoc.parse('/**\n *\n */') 11 | javadoc.parse('/**\n *\n *\n */') 12 | 13 | if __name__ == "__main__": 14 | unittest.main() 15 | -------------------------------------------------------------------------------- /sub_modules/code2seq/baseline_tokenization/javalang/test/test_package_declaration.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from pkg_resources import resource_string 4 | from .. 
import parse 5 | 6 | 7 | # From my reading of the spec (http://docs.oracle.com/javase/specs/jls/se7/html/jls-7.html) the 8 | # allowed order is javadoc, optional annotation, package declaration 9 | class PackageInfo(unittest.TestCase): 10 | def testPackageDeclarationOnly(self): 11 | source_file = "source/package-info/NoAnnotationNoJavadoc.java" 12 | ast = self.get_ast(source_file) 13 | 14 | self.failUnless(ast.package.name == "org.javalang.test") 15 | self.failIf(ast.package.annotations) 16 | self.failIf(ast.package.documentation) 17 | 18 | def testAnnotationOnly(self): 19 | source_file = "source/package-info/AnnotationOnly.java" 20 | ast = self.get_ast(source_file) 21 | 22 | self.failUnless(ast.package.name == "org.javalang.test") 23 | self.failUnless(ast.package.annotations) 24 | self.failIf(ast.package.documentation) 25 | 26 | def testJavadocOnly(self): 27 | source_file = "source/package-info/JavadocOnly.java" 28 | ast = self.get_ast(source_file) 29 | 30 | self.failUnless(ast.package.name == "org.javalang.test") 31 | self.failIf(ast.package.annotations) 32 | self.failUnless(ast.package.documentation) 33 | 34 | def testAnnotationThenJavadoc(self): 35 | source_file = "source/package-info/AnnotationJavadoc.java" 36 | ast = self.get_ast(source_file) 37 | 38 | self.failUnless(ast.package.name == "org.javalang.test") 39 | self.failUnless(ast.package.annotations) 40 | self.failIf(ast.package.documentation) 41 | 42 | def testJavadocThenAnnotation(self): 43 | source_file = "source/package-info/JavadocAnnotation.java" 44 | ast = self.get_ast(source_file) 45 | 46 | self.failUnless(ast.package.name == "org.javalang.test") 47 | self.failUnless(ast.package.annotations) 48 | self.failUnless(ast.package.documentation) 49 | 50 | def get_ast(self, filename): 51 | source = resource_string(__name__, filename) 52 | ast = parse.parse(source) 53 | 54 | return ast 55 | 56 | 57 | def main(): 58 | unittest.main() 59 | 60 | if __name__ == '__main__': 61 | main() 62 | -------------------------------------------------------------------------------- /sub_modules/code2seq/baseline_tokenization/javalang/test/test_util.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from ..util import LookAheadIterator 4 | 5 | 6 | class TestLookAheadIterator(unittest.TestCase): 7 | def test_usage(self): 8 | i = LookAheadIterator(list(range(0, 10000))) 9 | 10 | self.assertEqual(next(i), 0) 11 | self.assertEqual(next(i), 1) 12 | self.assertEqual(next(i), 2) 13 | 14 | self.assertEqual(i.last(), 2) 15 | 16 | self.assertEqual(i.look(), 3) 17 | self.assertEqual(i.last(), 3) 18 | 19 | self.assertEqual(i.look(1), 4) 20 | self.assertEqual(i.look(2), 5) 21 | self.assertEqual(i.look(3), 6) 22 | self.assertEqual(i.look(4), 7) 23 | 24 | self.assertEqual(i.last(), 7) 25 | 26 | i.push_marker() 27 | self.assertEqual(next(i), 3) 28 | self.assertEqual(next(i), 4) 29 | self.assertEqual(next(i), 5) 30 | i.pop_marker(True) # reset 31 | 32 | self.assertEqual(i.look(), 3) 33 | self.assertEqual(next(i), 3) 34 | 35 | i.push_marker() #1 36 | self.assertEqual(next(i), 4) 37 | self.assertEqual(next(i), 5) 38 | i.push_marker() #2 39 | self.assertEqual(next(i), 6) 40 | self.assertEqual(next(i), 7) 41 | i.push_marker() #3 42 | self.assertEqual(next(i), 8) 43 | self.assertEqual(next(i), 9) 44 | i.pop_marker(False) #3 45 | self.assertEqual(next(i), 10) 46 | i.pop_marker(True) #2 47 | self.assertEqual(next(i), 6) 48 | self.assertEqual(next(i), 7) 49 | self.assertEqual(next(i), 8) 50 | i.pop_marker(False) 
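# Semantics being exercised: push_marker() saves the current position;
# pop_marker(True) rewinds the iterator to the most recent marker, while
# pop_marker(False) discards the marker and keeps the position reached.
# Markers nest and unwind innermost-first; the numbered labels below pair
# each pop_marker with its push_marker. The `with i:` blocks further down
# behave like push_marker(), rolling back on an exception and committing
# on a clean exit.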
#1
51 |         self.assertEqual(next(i), 9)
52 | 
53 |         try:
54 |             with i:
55 |                 self.assertEqual(next(i), 10)
56 |                 self.assertEqual(next(i), 11)
57 |                 raise Exception()
58 |         except:
59 |             self.assertEqual(next(i), 10)
60 |             self.assertEqual(next(i), 11)
61 | 
62 |         with i:
63 |             self.assertEqual(next(i), 12)
64 |             self.assertEqual(next(i), 13)
65 |             self.assertEqual(next(i), 14)
66 | 
67 | 
68 | if __name__ == "__main__":
69 |     unittest.main()
70 | 
--------------------------------------------------------------------------------
/sub_modules/code2seq/baseline_tokenization/subtokenize_nmt_baseline.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | 
3 | import javalang
4 | import sys
5 | import re
6 | 
7 | 
8 | modifiers = ['public', 'private', 'protected', 'static']
9 | 
10 | RE_WORDS = re.compile(r'''
11 |     # Find words in a string. Order matters!
12 |     [A-Z]+(?=[A-Z][a-z]) |  # All upper case before a capitalized word
13 |     [A-Z]?[a-z]+ |          # Capitalized words / all lower case
14 |     [A-Z]+ |                # All upper case
15 |     \d+ |                   # Numbers
16 |     .+
17 | ''', re.VERBOSE)
18 | 
19 | def split_subtokens(s):
20 |     return [subtok for subtok in RE_WORDS.findall(s) if subtok != '_']
21 | 
22 | def tokenizeFile(file_path):
23 |     lines = 0
24 |     with open(file_path, 'r', encoding="utf-8") as file:
25 |         with open(file_path + 'method_names.txt', 'w') as method_names_file:
26 |             with open(file_path + 'method_subtokens_content.txt', 'w') as method_contents_file:
27 |                 for line in file:
28 |                     lines += 1
29 |                     line = line.rstrip()
30 |                     parts = line.split('|', 1)
31 |                     method_name = parts[0]
32 |                     method_content = parts[1]
33 |                     try:
34 |                         tokens = list(javalang.tokenizer.tokenize(method_content))
35 |                     except Exception:
36 |                         print('ERROR in tokenizing: ' + method_content)
37 |                         tokens = []  # was: tokens = method_content.split(' ')
38 |                     if len(method_name) > 0 and len(tokens) > 0:
39 |                         method_names_file.write(method_name + '\n')
40 |                         method_contents_file.write(' '.join([' '.join(split_subtokens(i.value)) for i in tokens if i.value not in modifiers]) + '\n')
41 |                     else:
42 |                         print('ERROR in len of: ' + method_name + ', tokens: ' + str(tokens))
43 |     print(str(lines))
44 | 
45 | 
46 | if __name__ == '__main__':
47 |     file = sys.argv[1]
48 |     tokenizeFile(file)
49 | 
50 | 
51 | 
--------------------------------------------------------------------------------
/sub_modules/code2seq/code2seq.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import numpy as np
3 | import tensorflow as tf
4 | 
5 | from config import Config
6 | from interactive_predict import InteractivePredictor
7 | from model import Model
8 | 
9 | if __name__ == '__main__':
10 |     parser = ArgumentParser()
11 |     parser.add_argument("-d", "--data", dest="data_path",
12 |                         help="path to preprocessed dataset", required=False)
13 |     parser.add_argument("-te", "--test", dest="test_path",
14 |                         help="path to test file", metavar="FILE", required=False)
15 | 
16 |     parser.add_argument("-s", "--save_prefix", dest="save_path_prefix",
17 |                         help="path to save file", metavar="FILE", required=False)
18 |     parser.add_argument("-l", "--load", dest="load_path",
19 |                         help="path to saved file", metavar="FILE", required=False)
20 |     parser.add_argument('--release', action='store_true',
21 |                         help='if specified and loading a trained model, release the loaded model for a smaller model '
22 |                              'size.')
23 |     parser.add_argument('--predict', action='store_true')
24 |     parser.add_argument('--debug', action='store_true')
25 | 
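# --data takes the dataset prefix produced by preprocess.py (train.sh passes
# ${data_dir}/${dataset_name} rather than a concrete file), while --test and
# --load point at actual files. The seed below is fed to both numpy and
# TensorFlow before the model is built.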
parser.add_argument('--seed', type=int, default=239) 26 | args = parser.parse_args() 27 | 28 | np.random.seed(args.seed) 29 | tf.set_random_seed(args.seed) 30 | 31 | if args.debug: 32 | config = Config.get_debug_config(args) 33 | else: 34 | config = Config.get_default_config(args) 35 | 36 | model = Model(config) 37 | print('Created model') 38 | if config.TRAIN_PATH: 39 | model.train() 40 | if config.TEST_PATH and not args.data_path: 41 | results, precision, recall, f1, rouge = model.evaluate() 42 | print('Accuracy: ' + str(results)) 43 | print('Precision: ' + str(precision) + ', recall: ' + str(recall) + ', F1: ' + str(f1)) 44 | print('Rouge: ', rouge) 45 | if args.predict: 46 | predictor = InteractivePredictor(config, model) 47 | predictor.predict() 48 | if args.release and args.load_path: 49 | model.evaluate(release=True) 50 | model.close_session() 51 | -------------------------------------------------------------------------------- /sub_modules/code2seq/evaluate.sh: -------------------------------------------------------------------------------- 1 | ########################################################### 2 | # Change the following values to train a new model. 3 | # type: the name of the new model, only affects the saved file name. 4 | # dataset: the name of the dataset, as was preprocessed using preprocess.sh 5 | # test_data: by default, points to the validation set, since this is the set that 6 | # will be evaluated after each training iteration. If you wish to test 7 | # on the final (held-out) test set, change 'val' to 'test'. 8 | 9 | dataset_name=$1 10 | partition=$2 11 | snapshot_name=$3 12 | data_dir=<<>>/${dataset_name} 13 | test_data=${data_dir}/${dataset_name}.${partition}.c2s 14 | parameters_file=<<>>/code2seq/${snapshot_name} 15 | output_dir=<<>>/code2seq-evaluation/${snapshot_name}/${partition} 16 | 17 | set -e 18 | python3 -u code2seq.py --test ${test_data} --load ${parameters_file} --save_prefix ${output_dir} 19 | -------------------------------------------------------------------------------- /sub_modules/code2seq/extractor.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import requests 4 | 5 | from common import PathContextInformation 6 | 7 | 8 | class Extractor: 9 | def __init__(self, config, extractor_api_url, max_path_length, max_path_width): 10 | self.config = config 11 | self.max_path_length = max_path_length 12 | self.max_path_width = max_path_width 13 | self.extractor_api_url = extractor_api_url 14 | self.bad_characters_table = str.maketrans('', '', '\t\r\n') 15 | 16 | @staticmethod 17 | def post_request(url, code_string): 18 | return requests.post(url, data=json.dumps({"code": code_string, "decompose": True}, separators=(',', ':'))) 19 | 20 | def extract_paths(self, code_string): 21 | response = self.post_request(self.extractor_api_url, code_string) 22 | response_array = json.loads(response.text) 23 | if 'errorType' in response_array: 24 | raise ValueError(response.text) 25 | if 'errorMessage' in response_array: 26 | raise TimeoutError(response.text) 27 | pc_info_dict = {} 28 | result = [] 29 | for single_method in response_array: 30 | method_name = single_method['target'] 31 | current_result_line_parts = [method_name] 32 | contexts = single_method['paths'] 33 | for context in contexts[:self.config.DATA_NUM_CONTEXTS]: 34 | pc_info = PathContextInformation(context) 35 | current_result_line_parts += [str(pc_info)] 36 | pc_info_dict[(pc_info.token1, pc_info.shortPath, pc_info.token2)] = pc_info 37 | 
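# Each assembled line is "<method name> <ctx_1> ... <ctx_k>", the same
# space-separated layout as the preprocessed .c2s files; the padding below
# (presumably) fixes the number of space-separated fields per line at
# DATA_NUM_CONTEXTS so a downstream reader can split lines uniformly.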
            space_padding = ' ' * (self.config.DATA_NUM_CONTEXTS - len(contexts))
38 |             result_line = ' '.join(current_result_line_parts) + space_padding
39 |             result.append(result_line)
40 |         return result, pc_info_dict
41 | 
--------------------------------------------------------------------------------
/sub_modules/code2seq/images/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/sub_modules/code2seq/images/network.png
--------------------------------------------------------------------------------
/sub_modules/code2seq/train.sh:
--------------------------------------------------------------------------------
1 | ###########################################################
2 | # Change the following values to train a new model.
3 | # type: the name of the new model, only affects the saved file name.
4 | # dataset: the name of the dataset, as was preprocessed using preprocess.sh
5 | # test_data: by default, points to the validation set, since this is the set that
6 | #            will be evaluated after each training iteration. If you wish to test
7 | #            on the final (held-out) test set, change 'val' to 'test'.
8 | type=python_test
9 | dataset_name=$1
10 | data_dir=<<>>/${dataset_name}
11 | data=${data_dir}/${dataset_name}
12 | test_data=${data_dir}/${dataset_name}.val.c2s
13 | model_dir=<<>>/code2seq/${dataset_name}
14 | 
15 | mkdir -p ${model_dir}
16 | set -e
17 | python3 -u code2seq.py --data ${data} --test ${test_data} --save_prefix ${model_dir}/model
18 | 
--------------------------------------------------------------------------------
/sub_modules/code2seq/train_python150k.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | data_dir=$1
4 | data_name=$(basename "${data_dir}")
5 | data=${data_dir}/${data_name}
6 | test=${data_dir}/${data_name}.val.c2s
7 | run_name=$2
8 | model_dir=models/python150k-${run_name}
9 | save_prefix=${model_dir}/model
10 | cuda=${3:-0}
11 | seed=${4:-239}
12 | 
13 | mkdir -p "${model_dir}"
14 | set -e
15 | CUDA_VISIBLE_DEVICES=$cuda python -u code2seq.py \
16 |     --data="${data}" \
17 |     --test="${test}" \
18 |     --save_prefix="${save_prefix}" \
19 |     --seed="${seed}"
20 | 
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/.classpath:
--------------------------------------------------------------------------------
[Eclipse .classpath file; XML content not preserved in this export]
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/JavaExtractor (1).iml:
--------------------------------------------------------------------------------
[IntelliJ module file; XML content not preserved in this export]
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/JavaExtractor.iml:
--------------------------------------------------------------------------------
[IntelliJ module file; XML content not preserved in this export]
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/JavaMethodExtractor.iml:
--------------------------------------------------------------------------------
[IntelliJ module file; XML content not preserved in this export]
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/code-2-seq-java-extractor.iml:
--------------------------------------------------------------------------------
[IntelliJ module file; XML content not preserved in this export]
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/dependency-reduced-pom.xml:
--------------------------------------------------------------------------------
[Maven dependency-reduced POM; XML markup not preserved in this export. Surviving values: modelVersion 4.0.0, JavaExtractor / JPredict 0.0.1-SNAPSHOT, http://maven.apache.org, maven-compiler-plugin 3.2 (source/target 1.8, excludes Test.java), maven-shade-plugin 2.1 (shade goal at package phase), encoding UTF-8]
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/src/main/JavaExtractor/App.java:
--------------------------------------------------------------------------------
1 | package JavaExtractor;
2 | 
3 | import JavaExtractor.Common.CommandLineValues;
4 | import org.kohsuke.args4j.CmdLineException;
5 | 
6 | import java.io.IOException;
7 | import java.nio.file.Files;
8 | import java.nio.file.Paths;
9 | import java.util.LinkedList;
10 | import java.util.List;
11 | import java.util.concurrent.ExecutionException;
12 | import java.util.concurrent.Executors;
13 | import java.util.concurrent.Future;
14 | import java.util.concurrent.ThreadPoolExecutor;
15 | 
16 | public class App {
17 |     private static CommandLineValues s_CommandLineValues;
18 | 
19 |     public static void main(String[] args) {
20 |         try {
21 |             s_CommandLineValues = new CommandLineValues(args);
22 |         } catch (CmdLineException e) {
23 |             e.printStackTrace();
24 |             return;
25 |         }
26 | 
27 |         if (s_CommandLineValues.File != null) {
28 |             ExtractFeaturesTask extractFeaturesTask = new ExtractFeaturesTask(s_CommandLineValues,
29 |                     s_CommandLineValues.File.toPath());
30 |             extractFeaturesTask.processFile();
31 |         } else if (s_CommandLineValues.Dir != null) {
32 |             extractDir();
33 |         }
34 |     }
35 | 
36 |     private static void extractDir() {
37 |         ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(s_CommandLineValues.NumThreads);
38 |         LinkedList<ExtractFeaturesTask> tasks = new LinkedList<>();
39 |         try {
40 |             Files.walk(Paths.get(s_CommandLineValues.Dir)).filter(Files::isRegularFile)
41 |                     .filter(p -> p.toString().toLowerCase().endsWith(".java")).forEach(f -> {
42 |                 ExtractFeaturesTask task = new ExtractFeaturesTask(s_CommandLineValues, f);
43 |                 tasks.add(task);
44 |             });
45 |         } catch (IOException e) {
46 |             e.printStackTrace();
47 |             return;
48 |         }
49 |         List<Future<Void>> tasksResults = null;
50 |         try {
51 |             tasksResults = executor.invokeAll(tasks);
52 |         } catch (InterruptedException e) {
53 |             e.printStackTrace();
54 |         } finally {
55 |             executor.shutdown();
56 |         }
57 |         tasksResults.forEach(f -> {
58 |             try {
59 |                 f.get();
60 |             } catch (InterruptedException | ExecutionException e) {
61 |                 e.printStackTrace();
62 |             }
63 |         });
64 |     }
65 | }
66 | 
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/src/main/JavaExtractor/Common/CommandLineValues.java:
--------------------------------------------------------------------------------
1 | package JavaExtractor.Common;
2 | 
3 | import org.kohsuke.args4j.CmdLineException;
4 | import org.kohsuke.args4j.CmdLineParser;
5 | import org.kohsuke.args4j.Option;
6 | 
7 | import java.io.File;
8 | 
9 | /**
10 |  * This class handles the program's arguments.
11 |  */
12 | public class CommandLineValues {
13 |     @Option(name = "--file", required = false)
14 |     public File File = null;
15 | 
16 |     @Option(name = "--dir", required = false, forbids = "--file")
17 |     public String Dir = null;
18 | 
19 |     @Option(name = "--num_threads", required = false)
20 |     public int NumThreads = 64;
21 | 
22 |     @Option(name = "--min_code_len", required = false)
23 |     public int MinCodeLength = 1;
24 | 
25 |     @Option(name = "--max_code_len", required = false)
26 |     public int MaxCodeLength = -1;
27 | 
28 |     @Option(name = "--max_file_len", required = false)
29 |     public int MaxFileLength = -1;
30 | 
31 |     @Option(name = "--pretty_print", required = false)
32 |     public boolean PrettyPrint = false;
33 | 
34 |     @Option(name = "--max_child_id", required = false)
35 |     public int MaxChildId = 3;
36 | 
37 |     public CommandLineValues(String... args) throws CmdLineException {
38 |         CmdLineParser parser = new CmdLineParser(this);
39 |         try {
40 |             parser.parseArgument(args);
41 |         } catch (CmdLineException e) {
42 |             System.err.println(e.getMessage());
43 |             parser.printUsage(System.err);
44 |             throw e;
45 |         }
46 |     }
47 | 
48 |     public CommandLineValues() {
49 | 
50 |     }
51 | }
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/src/main/JavaExtractor/Common/MethodContent.java:
--------------------------------------------------------------------------------
1 | package JavaExtractor.Common;
2 | 
3 | import com.github.javaparser.ast.Node;
4 | 
5 | import java.util.ArrayList;
6 | 
7 | public class MethodContent {
8 |     private final ArrayList<Node> leaves;
9 |     private final String name;
10 | 
11 |     public MethodContent(ArrayList<Node> leaves, String name) {
12 |         this.leaves = leaves;
13 |         this.name = name;
14 |     }
15 | 
16 |     public ArrayList<Node> getLeaves() {
17 |         return leaves;
18 |     }
19 | 
20 |     public String getName() {
21 |         return name;
22 |     }
23 | }
24 | 
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/src/main/JavaExtractor/FeatureExtractor.java:
--------------------------------------------------------------------------------
1 | package JavaExtractor;
2 | 
3 | import JavaExtractor.Common.CommandLineValues;
4 | import JavaExtractor.Common.MethodContent;
5 | import JavaExtractor.FeaturesEntities.ProgramFeatures;
6 | import JavaExtractor.Visitors.FunctionVisitor;
7 | import com.github.javaparser.JavaParser;
8 | import com.github.javaparser.ParseProblemException;
9 | import com.github.javaparser.ast.CompilationUnit;
10 | 
11 | import java.util.ArrayList;
12 | 
13 | class FeatureExtractor {
14 |     private final CommandLineValues m_CommandLineValues;
15 | 
16 |     public FeatureExtractor(CommandLineValues commandLineValues) {
17 |         this.m_CommandLineValues = commandLineValues;
18 |     }
19 | 
20 | 
21 |     public ArrayList<ProgramFeatures> extractFeatures(String code) {
22 |         CompilationUnit m_CompilationUnit = parseFileWithRetries(code);
23 |         FunctionVisitor functionVisitor = new FunctionVisitor(m_CommandLineValues);
24 | 
25 |         functionVisitor.visit(m_CompilationUnit, null);
26 | 
27 |         ArrayList<MethodContent> methods = functionVisitor.getMethodContents();
28 | 
29 |         return null;
30 |     }
31 | 
32 |     private CompilationUnit parseFileWithRetries(String code) {
33 |         final String classPrefix = "public class Test {";
34 |         final String classSuffix = "}";
35 |         final String methodPrefix = "SomeUnknownReturnType f() {";
36 |         final String methodSuffix = "return noSuchReturnValue; }";
37 | 
38 |         String content = code;
39 |         CompilationUnit parsed;
40 |         try {
41 |             parsed = JavaParser.parse(content);
42 |         } catch (ParseProblemException e1) {
43 |             // Wrap with a class and method
44 |             try {
45 |                 content = classPrefix + methodPrefix + code + methodSuffix + classSuffix;
46 |                 parsed = JavaParser.parse(content);
47 |             } catch (ParseProblemException e2) {
48 |                 // Wrap with a class only
49 |                 content = classPrefix + code + classSuffix;
50 |                 parsed = JavaParser.parse(content);
51 |             }
52 |         }
53 | 
54 |         return parsed;
55 |     }
56 | }
57 | 
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/src/main/JavaExtractor/FeaturesEntities/ProgramFeatures.java:
--------------------------------------------------------------------------------
1 | package JavaExtractor.FeaturesEntities;
2 | 
3 | import com.fasterxml.jackson.annotation.JsonIgnore;
4 | 
5 | import java.util.ArrayList;
6 | import java.util.stream.Collectors;
7 | 
8 | public class ProgramFeatures {
9 |     private final String name;
10 | 
11 |     private final ArrayList<ProgramRelation> features = new ArrayList<>();
12 | 
13 |     public ProgramFeatures(String name) {
14 |         this.name = name;
15 |     }
16 | 
17 |     @SuppressWarnings("StringBufferReplaceableByString")
18 |     @Override
19 |     public String toString() {
20 |         StringBuilder stringBuilder = new StringBuilder();
21 |         stringBuilder.append(name).append(" ");
22 |         stringBuilder.append(features.stream().map(ProgramRelation::toString).collect(Collectors.joining(" ")));
23 | 
24 |         return stringBuilder.toString();
25 |     }
26 | 
27 |     public void addFeature(Property source, String path, Property target) {
28 |         ProgramRelation newRelation = new ProgramRelation(source, target, path);
29 |         features.add(newRelation);
30 |     }
31 | 
32 |     @JsonIgnore
33 |     public boolean isEmpty() {
34 |         return features.isEmpty();
35 |     }
36 | }
37 | 
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/src/main/JavaExtractor/FeaturesEntities/ProgramRelation.java:
--------------------------------------------------------------------------------
1 | package JavaExtractor.FeaturesEntities;
2 | 
3 | public class ProgramRelation {
4 |     private final Property m_Source;
5 |     private final Property m_Target;
6 |     private final String m_Path;
7 | 
8 |     public ProgramRelation(Property sourceName, Property targetName, String path) {
9 |         m_Source = sourceName;
10 |         m_Target = targetName;
11 |         m_Path = path;
12 |     }
13 | 
14 |     public String toString() {
15 |         return String.format("%s,%s,%s", m_Source.getName(), m_Path,
16 |                 m_Target.getName());
17 |     }
18 | }
19 | 
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/src/main/java/CommandLineValues.java:
--------------------------------------------------------------------------------
1 | import org.kohsuke.args4j.CmdLineException;
2 | import org.kohsuke.args4j.CmdLineParser;
3 | import org.kohsuke.args4j.Option;
4 | 
5 | public class CommandLineValues {
6 |     @Option(name = "--file", required = false)
7 |     public java.io.File File = null;
8 | 
9 |     @Option(name = "--dir", required = false, forbids = "--file")
10 |     public String Dir = null;
11 | 
12 |     @Option(name = "--output_dir", required = false)
13 |     public String OutputDir = null;
14 | 
15 |     @Option(name = "--num_threads", required = false)
16 |     public int NumThreads = 64;
17 | 
18 |     @Option(name = "--max_file_len", required = false)
19 |     public int MaxFileLength = -1;
20 | 
21 |     @Option(name = "--min_code_len", required = false)
22 |     public int MinCodeLength = 1;
23 | 
24 |     @Option(name = "--max_code_len", required = false)
25 |     public int MaxCodeLength = -1;
26 | 
27 |     public CommandLineValues(String... args) throws CmdLineException {
28 |         CmdLineParser parser = new CmdLineParser(this);
29 |         try {
30 |             parser.parseArgument(args);
31 |         } catch (CmdLineException e) {
32 |             System.err.println(e.getMessage());
33 |             parser.printUsage(System.err);
34 |             throw e;
35 |         }
36 |     }
37 | 
38 |     public CommandLineValues() {
39 | 
40 |     }
41 | }
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/src/main/java/MethodContent.java:
--------------------------------------------------------------------------------
1 | public class MethodContent {
2 | 
3 |     private final String code;
4 |     private final String name;
5 |     private final String doc;
6 | 
7 |     public MethodContent(String code, String name, String doc) {
8 |         this.code = code;
9 |         this.name = name;
10 |         this.doc = doc;
11 |     }
12 | 
13 |     public String getCode() {
14 |         return code;
15 |     }
16 | 
17 |     public String getName() {
18 |         return name;
19 |     }
20 | 
21 |     public String getDoc() {
22 |         return doc;
23 |     }
24 | }
25 | 
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/src/main/java/MethodVisitor.java:
--------------------------------------------------------------------------------
1 | import com.github.javaparser.ast.body.MethodDeclaration;
2 | import com.github.javaparser.ast.visitor.VoidVisitorAdapter;
3 | 
4 | import java.util.ArrayList;
5 | import java.util.Arrays;
6 | 
7 | public class MethodVisitor extends VoidVisitorAdapter<Object> {
8 | 
9 |     private final ArrayList<MethodContent> methods = new ArrayList<>();
10 |     private final CommandLineValues commandLineValues;
11 | 
12 |     public MethodVisitor(CommandLineValues commandLineValues) {
13 |         this.commandLineValues = commandLineValues;
14 |     }
15 | 
16 |     @Override
17 |     public void visit(MethodDeclaration node, Object arg) {
18 |         String methodCode = node.toString();
19 |         String methodName = node.getName();
20 |         String doc = null;
21 |         if (node.getJavaDoc() != null) {
22 |             doc = node.getJavaDoc().getContent();
23 |         } else if (node.getComment() != null) {
24 |             doc = node.getComment().getContent();
25 |         } else if (node.getParentNode().getComment() != null) {
26 |             doc = node.getParentNode().getComment().getContent();
27 |         }
28 | 
29 |         if (node.getBody() != null) {
30 |             long methodLength = getMethodLength(node.getBody().toString());
31 |             if (commandLineValues.MaxCodeLength > 0) {
32 |                 if (methodLength >= commandLineValues.MinCodeLength && methodLength <= commandLineValues.MaxCodeLength) {
33 |                     methods.add(new MethodContent(methodCode, methodName, doc));
34 |                 }
35 |             } else {
36 |                 methods.add(new MethodContent(methodCode, methodName, doc));
37 |             }
38 |         }
39 | 
40 |         super.visit(node, arg);
41 |     }
42 | 
43 |     private long getMethodLength(String code) {
44 |         String cleanCode = code.replaceAll("\r\n", "\n").replaceAll("\t", " ");
45 |         if (cleanCode.startsWith("{\n"))
46 |             cleanCode = cleanCode.substring(2).trim();
47 |         if (cleanCode.endsWith("\n}"))
48 |             cleanCode = cleanCode.substring(0, cleanCode.length() - 2).trim();
49 |         if (cleanCode.length() == 0) {
50 |             return 0;
51 |         }
52 |         return Arrays.stream(cleanCode.split("\n"))
53 |                 .filter(line -> (!line.trim().equals("{") && !line.trim().equals("}") && !line.trim().isEmpty()))
54 |                 .filter(line -> !line.trim().startsWith("/") && !line.trim().startsWith("*")).count();
55 |     }
56 | 
57 |     public ArrayList<MethodContent> getMethods() {
58 |         return methods;
59 |     }
60 | }
61 | 
--------------------------------------------------------------------------------
/sub_modules/java-method-extractor/target/JavaMethodExtractor-1.0.0-SNAPSHOT.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/sub_modules/java-method-extractor/target/JavaMethodExtractor-1.0.0-SNAPSHOT.jar
--------------------------------------------------------------------------------
/sub_modules/java-parser/java-parser.iml:
--------------------------------------------------------------------------------
[IntelliJ module file; XML content not preserved in this export]
--------------------------------------------------------------------------------
/sub_modules/java-parser/src/main/java/ASTNode.java:
--------------------------------------------------------------------------------
1 | import com.github.javaparser.Position;
2 | import com.github.javaparser.Range;
3 | import com.github.javaparser.ast.Node;
4 | 
5 | import java.util.ArrayList;
6 | import java.util.List;
7 | 
8 | public class ASTNode {
9 |     private String type;
10 |     private Range sourceRange;
11 |     private List<ASTNode> childNodes;
12 | 
13 | 
14 |     public ASTNode(String type, Range sourceRange) {
15 |         this.type = type;
16 | 
17 |         if (sourceRange == null) {
18 |             this.sourceRange = null;
19 |         } else {
20 |             this.sourceRange = new Range(
21 |                     new Position(sourceRange.begin.line - 1, sourceRange.begin.column),
22 |                     new Position(sourceRange.end.line - 1, sourceRange.end.column + 1));
23 |         }
24 | 
25 |         this.childNodes = new ArrayList<>();
26 |     }
27 | 
28 |     public static ASTNode fromNode(Node node) {
29 |         Range range = null;
30 |         if (node.getRange().isPresent()) {
31 |             range = node.getRange().get();
32 |         }
33 |         ASTNode parsedNode = new ASTNode(node.getClass().getSimpleName(), range);
34 |         for (Node childNode : node.getChildNodes()) {
35 |             parsedNode.addChildNode(ASTNode.fromNode(childNode));
36 |         }
37 |         return parsedNode;
38 |     }
39 | 
40 |     public void addChildNode(ASTNode childNode) {
41 |         this.childNodes.add(childNode);
42 |     }
43 | }
44 | 
--------------------------------------------------------------------------------
/sub_modules/java-parser/src/main/java/META-INF/MANIFEST.MF:
--------------------------------------------------------------------------------
1 | Manifest-Version: 1.0
2 | Main-Class: ASTParser
3 | 
4 | 
--------------------------------------------------------------------------------
/sub_modules/java-parser/target/java-parser-1.0-SNAPSHOT.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/danielzuegner/code-transformer/c7eb56e895cd70307cf4a69cb6c5d8495d17b469/sub_modules/java-parser/target/java-parser-1.0-SNAPSHOT.jar
--------------------------------------------------------------------------------
/tests/test_loss.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | 
3 | import torch
4 | from torch.nn import CrossEntropyLoss
5 | 
6 | from code_transformer.utils.loss import LabelSmoothingLoss
7 | 
8 | VOCAB_SIZE = 37
9 | NUM_SUB_TOKENS = 5
10 | PAD_ID = 3
11 | 
12 | 
13 | class TestLoss(TestCase):
14 | 
15 | 
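# Helper semantics (see below): create_prediction builds a NUM_SUB_TOKENS x
# VOCAB_SIZE score matrix whose per-position argmax is the requested word id
# (passing a list of ids yields a ranked distribution over them); unused
# positions predict PAD_ID. create_label builds the matching PAD-padded id
# sequence.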
@staticmethod 16 | def create_prediction(*desired_words): 17 | token = [] 18 | for desired_word in desired_words: 19 | vocab_distribution = [0 for _ in range(VOCAB_SIZE)] 20 | if isinstance(desired_word, list): 21 | for i, w in enumerate(desired_word): 22 | vocab_distribution[w] = len(desired_word) - i 23 | else: 24 | vocab_distribution[desired_word] = 1 25 | token.append(vocab_distribution) 26 | for _ in range(NUM_SUB_TOKENS - len(desired_words)): 27 | vocab_distribution = [0 for _ in range(VOCAB_SIZE)] 28 | vocab_distribution[PAD_ID] = 1 29 | token.append(vocab_distribution) 30 | return token 31 | 32 | @staticmethod 33 | def create_label(*desired_words): 34 | label = [w for w in desired_words] 35 | for i in range(NUM_SUB_TOKENS - len(desired_words)): 36 | label.append(PAD_ID) 37 | return label 38 | 39 | def test_label_smoothing(self): 40 | label_smoothing = LabelSmoothingLoss() 41 | label_smoothing_01 = LabelSmoothingLoss(0.1) 42 | cross_entropy = CrossEntropyLoss() 43 | 44 | logits = torch.tensor([[TestLoss.create_prediction(10), 45 | TestLoss.create_prediction(11)], 46 | [TestLoss.create_prediction(20, 21), 47 | TestLoss.create_prediction()]], dtype=torch.float32) 48 | labels = torch.tensor([[TestLoss.create_label(10), 49 | TestLoss.create_label(10)], 50 | [TestLoss.create_label(20), 51 | TestLoss.create_label(20)]]) 52 | 53 | self.assertAlmostEqual(label_smoothing(logits.view(-1, VOCAB_SIZE), labels.view(-1)), 54 | cross_entropy(logits.view(-1, VOCAB_SIZE), labels.view(-1))) 55 | 56 | logits = torch.tensor(TestLoss.create_prediction(10), dtype=torch.float32) 57 | labels = torch.tensor(TestLoss.create_label(10)) 58 | self.assertTrue(label_smoothing_01(logits, labels) > cross_entropy(logits, labels)) 59 | 60 | --------------------------------------------------------------------------------
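The two assertions above pin down `LabelSmoothingLoss`'s observable behaviour: with smoothing 0 it must coincide with `CrossEntropyLoss`, and with smoothing 0.1 a confident correct prediction must cost strictly more. The sketch below is not the repository's implementation (which lives in `code_transformer/utils/loss.py` and is not shown here) but a standard label-smoothing formulation, assuming the smoothing mass is spread uniformly over the non-target vocabulary, that satisfies both properties:

```python
import torch
import torch.nn.functional as F


def label_smoothing_loss(logits, labels, smoothing=0.0):
    # Target distribution: 1 - smoothing on the gold id, with the remaining
    # mass spread uniformly over the other vocabulary entries.
    vocab_size = logits.size(-1)
    log_probs = F.log_softmax(logits, dim=-1)
    true_dist = torch.full_like(log_probs, smoothing / (vocab_size - 1))
    true_dist.scatter_(-1, labels.unsqueeze(-1), 1.0 - smoothing)
    # Cross-entropy between the smoothed target distribution and the model.
    return -(true_dist * log_probs).sum(dim=-1).mean()
```

With `smoothing=0.0` the target distribution is one-hot and the expression reduces exactly to `CrossEntropyLoss`, which is the first equality asserted; any `smoothing > 0` moves probability mass away from the gold id, so a sharply peaked correct prediction is penalized more, matching the second assertion.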