├── .gitignore ├── LICENSE ├── README.md ├── SECURITY.md ├── azure-pipelines.yml ├── dotnet ├── CSharpSourceGraphExtraction │ ├── GraphBuilders │ │ ├── ASTGraphBuilder.cs │ │ ├── DataFlowGraphBuilder.cs │ │ ├── GuardedByGraphBuilder.cs │ │ ├── ReturnToGraphBuilder.cs │ │ └── VariableUseGraphBuilder.cs │ ├── MSRC.DPU.CSharpSourceGraphExtraction.csproj │ ├── MethodUseInformationCollector.cs │ ├── SourceGraph.cs │ └── Utils │ │ ├── IntVocabulary.cs │ │ ├── MethodUtils.cs │ │ ├── RoslynUtils.cs │ │ └── TypeHierarchy.cs ├── DPU.Utils.sln └── Utils │ ├── BidirectionalMap.cs │ ├── ChunkedJsonWriter.cs │ ├── DirectedGraph.cs │ ├── ExtensionUtils.cs │ ├── MSRC.DPU.Utils.csproj │ ├── Multimap.cs │ └── RichPath.cs └── python ├── MANIFEST.in ├── dpu_utils ├── __init__.py ├── codeutils │ ├── __init__.py │ ├── deduplication │ │ ├── __init__.py │ │ ├── deduplication.py │ │ └── deduplicationcli │ ├── filesuffix.py │ ├── identifiersplitting.py │ ├── keywords │ │ ├── __init__.py │ │ ├── c.txt │ │ ├── cpp.txt │ │ ├── csharp.txt │ │ ├── go.txt │ │ ├── java.txt │ │ ├── javascript.txt │ │ ├── keywordlist.py │ │ ├── php.txt │ │ ├── ruby.txt │ │ └── typescript.txt │ ├── lattice │ │ ├── __init__.py │ │ ├── csharplattice.py │ │ └── lattice.py │ ├── text.py │ └── treesitter │ │ ├── __init__.py │ │ └── parser.py ├── mlutils │ ├── __init__.py │ ├── bpevocabulary.py │ ├── chartensorizer.py │ └── vocabulary.py ├── ptutils │ ├── __init__.py │ ├── basecomponent.py │ └── componenttrainer.py ├── py.typed ├── tf2utils │ ├── __init__.py │ ├── activation.py │ ├── constants.py │ ├── mlp.py │ └── unsorted_segment_ops.py ├── tfmodels │ ├── __init__.py │ ├── asyncgnn.py │ └── sparsegnn.py ├── tfutils │ ├── __init__.py │ ├── activation.py │ ├── gradratiologgingoptimizer.py │ ├── pick_indices.py │ ├── tfvariablesaver.py │ └── unsortedsegmentops.py └── utils │ ├── __init__.py │ ├── chunkwriter.py │ ├── dataloading.py │ ├── debughelper.py │ ├── gitlog.py │ ├── iterators.py │ ├── msgpackloading.py │ └── richpath.py ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── codeutils ├── __init__.py ├── test_code_range.py └── test_identifiersplitting.py ├── mlutils ├── __init__.py └── test_bpevocabulary.py ├── ptutils ├── __init__.py ├── test_component.py ├── testdata.py └── testmodel.py └── utils ├── __init__.py ├── test_chunkwriter.py ├── test_iterators.py └── test_richpath.py /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 
3 | 4 | # User-specific files 5 | *.suo 6 | *.user 7 | *.userosscache 8 | *.sln.docstates 9 | 10 | # User-specific files (MonoDevelop/Xamarin Studio) 11 | *.userprefs 12 | 13 | # Build results 14 | [Dd]ebug/ 15 | [Dd]ebugPublic/ 16 | [Rr]elease/ 17 | [Rr]eleases/ 18 | [Xx]64/ 19 | [Xx]86/ 20 | [Bb]uild/ 21 | bld/ 22 | [Bb]in/ 23 | [Oo]bj/ 24 | 25 | # Visual Studio 2015 cache/options directory 26 | .vs/ 27 | # Uncomment if you have tasks that create the project's static files in wwwroot 28 | #wwwroot/ 29 | 30 | # MSTest test Results 31 | [Tt]est[Rr]esult*/ 32 | [Bb]uild[Ll]og.* 33 | 34 | # NUNIT 35 | *.VisualState.xml 36 | TestResult.xml 37 | 38 | # Build Results of an ATL Project 39 | [Dd]ebugPS/ 40 | [Rr]eleasePS/ 41 | dlldata.c 42 | 43 | # DNX 44 | project.lock.json 45 | artifacts/ 46 | 47 | *_i.c 48 | *_p.c 49 | *_i.h 50 | *.ilk 51 | *.meta 52 | *.obj 53 | *.pch 54 | *.pdb 55 | *.pgc 56 | *.pgd 57 | *.rsp 58 | *.sbr 59 | *.tlb 60 | *.tli 61 | *.tlh 62 | *.tmp 63 | *.tmp_proj 64 | *.log 65 | *.vspscc 66 | *.vssscc 67 | .builds 68 | *.pidb 69 | *.svclog 70 | *.scc 71 | 72 | # Chutzpah Test files 73 | _Chutzpah* 74 | 75 | # Visual C++ cache files 76 | ipch/ 77 | *.aps 78 | *.ncb 79 | *.opendb 80 | *.opensdf 81 | *.sdf 82 | *.cachefile 83 | *.VC.db 84 | 85 | # Visual Studio profiler 86 | *.psess 87 | *.vsp 88 | *.vspx 89 | *.sap 90 | 91 | # TFS 2012 Local Workspace 92 | $tf/ 93 | 94 | # Guidance Automation Toolkit 95 | *.gpState 96 | 97 | # ReSharper is a .NET coding add-in 98 | _ReSharper*/ 99 | *.[Rr]e[Ss]harper 100 | *.DotSettings.user 101 | 102 | # JustCode is a .NET coding add-in 103 | .JustCode 104 | 105 | # TeamCity is a build add-in 106 | _TeamCity* 107 | 108 | # DotCover is a Code Coverage Tool 109 | *.dotCover 110 | 111 | # NCrunch 112 | _NCrunch_* 113 | .*crunch*.local.xml 114 | nCrunchTemp_* 115 | 116 | # MightyMoose 117 | *.mm.* 118 | AutoTest.Net/ 119 | 120 | # Web workbench (sass) 121 | .sass-cache/ 122 | 123 | # Installshield output folder 124 | [Ee]xpress/ 125 | 126 | # DocProject is a documentation generator add-in 127 | DocProject/buildhelp/ 128 | DocProject/Help/*.HxT 129 | DocProject/Help/*.HxC 130 | DocProject/Help/*.hhc 131 | DocProject/Help/*.hhk 132 | DocProject/Help/*.hhp 133 | DocProject/Help/Html2 134 | DocProject/Help/html 135 | 136 | # Click-Once directory 137 | publish/ 138 | 139 | # Publish Web Output 140 | *.[Pp]ublish.xml 141 | *.azurePubxml 142 | 143 | # TODO: Un-comment the next line if you do not want to checkin 144 | # your web deploy settings because they may include unencrypted 145 | # passwords 146 | #*.pubxml 147 | *.publishproj 148 | 149 | # NuGet Packages 150 | *.nupkg 151 | # The packages folder can be ignored because of Package Restore 152 | **/packages/* 153 | # except build/, which is used as an MSBuild target. 
154 | !**/packages/build/ 155 | # Uncomment if necessary however generally it will be regenerated when needed 156 | #!**/packages/repositories.config 157 | # NuGet v3's project.json files produces more ignoreable files 158 | *.nuget.props 159 | *.nuget.targets 160 | 161 | # Microsoft Azure Build Output 162 | csx/ 163 | *.build.csdef 164 | 165 | # Microsoft Azure Emulator 166 | ecf/ 167 | rcf/ 168 | 169 | # Windows Store app package directory 170 | AppPackages/ 171 | BundleArtifacts/ 172 | 173 | # Visual Studio cache files 174 | # files ending in .cache can be ignored 175 | *.[Cc]ache 176 | # but keep track of directories ending in .cache 177 | !*.[Cc]ache/ 178 | 179 | # Others 180 | ClientBin/ 181 | [Ss]tyle[Cc]op.* 182 | ~$* 183 | *~ 184 | *.dbmdl 185 | *.dbproj.schemaview 186 | *.pfx 187 | *.publishsettings 188 | node_modules/ 189 | orleans.codegen.cs 190 | 191 | # RIA/Silverlight projects 192 | Generated_Code/ 193 | 194 | # Backup & report files from converting an old project file 195 | # to a newer Visual Studio version. Backup files are not needed, 196 | # because we have git ;-) 197 | _UpgradeReport_Files/ 198 | Backup*/ 199 | UpgradeLog*.XML 200 | UpgradeLog*.htm 201 | 202 | # SQL Server files 203 | *.mdf 204 | *.ldf 205 | 206 | # Business Intelligence projects 207 | *.rdl.data 208 | *.bim.layout 209 | *.bim_*.settings 210 | 211 | # Microsoft Fakes 212 | FakesAssemblies/ 213 | 214 | # GhostDoc plugin setting file 215 | *.GhostDoc.xml 216 | 217 | # Node.js Tools for Visual Studio 218 | .ntvs_analysis.dat 219 | 220 | # Visual Studio 6 build log 221 | *.plg 222 | 223 | # Visual Studio 6 workspace options file 224 | *.opt 225 | 226 | # Visual Studio LightSwitch build output 227 | **/*.HTMLClient/GeneratedArtifacts 228 | **/*.DesktopClient/GeneratedArtifacts 229 | **/*.DesktopClient/ModelManifest.xml 230 | **/*.Server/GeneratedArtifacts 231 | **/*.Server/ModelManifest.xml 232 | _Pvt_Extensions 233 | 234 | # LightSwitch generated files 235 | GeneratedArtifacts/ 236 | ModelManifest.xml 237 | 238 | # Paket dependency manager 239 | .paket/paket.exe 240 | 241 | # FAKE - F# Make 242 | .fake/ 243 | 244 | # Profiling 245 | *.pyperf 246 | *.zip 247 | /allprojects.txt 248 | 249 | # Python: 250 | *.pyc 251 | *.so 252 | *.gz 253 | *.pkl 254 | .mypy_cache/ 255 | .vscode/ 256 | 257 | 258 | /latex/*.pdf 259 | /latex/*.out 260 | /latex/*.fls 261 | /latex/*.fdb_latexmk 262 | /latex/*.blg 263 | /latex/*.aux 264 | *.bbl 265 | 266 | .idea/ 267 | 268 | # Setuptools distribution folder. 269 | /dist/ 270 | 271 | # Python egg metadata, regenerated from source files by setuptools. 272 | /*.egg-info 273 | .eggs/ 274 | 275 | # Files generated by Coverage.py. 276 | .coverage 277 | htmlcov/ 278 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. All rights reserved. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DPU Utilities ![PyPI - Python Version](https://img.shields.io/pypi/v/dpu-utils)![Anaconda](https://anaconda.org/conda-forge/dpu-utils/badges/version.svg) 2 | [![Build Status](https://deepproceduralintelligence.visualstudio.com/dpu-utils/_apis/build/status/Microsoft.dpu-utils?branchName=master)](https://deepproceduralintelligence.visualstudio.com/dpu-utils/_build/latest?definitionId=3) 3 | 4 | 5 | This contains a set of utilities used across projects of the [DPU team](https://www.microsoft.com/en-us/research/project/program/). 6 | 7 | ## Python 8 | 9 | Stored in the `python` subdirectory, published as the `dpu-utils` package. 10 | 11 | ### Installation 12 | 13 | ```bash 14 | pip install dpu-utils 15 | ``` 16 | OR via the community-maintained Conda recipe: 17 | ```bash 18 | conda install -c conda-forge dpu-utils 19 | ``` 20 | 21 | ### Overview 22 | Below you can find an overview of the utilities included. Detailed documentation 23 | is provided in the docstring of each class. 24 | 25 | ##### Generic Utilities `dpu_utils.utils` 26 | * [`ChunkWriter`](python/dpu_utils/utils/chunkwriter.py) provides a convenient API for writing output in multiple parts (chunks). 27 | * [`RichPath`](python/dpu_utils/utils/richpath.py) an API that abstracts over local and Azure Blob paths in your code (see the usage sketch below). 28 | * [`*Iterator`](python/dpu_utils/utils/iterators.py) Wrappers that can parallelize and shuffle iterators. 29 | * [`{load,save}_json[l]_gz`](python/dpu_utils/utils/dataloading.py) convenience API for loading and writing `.json[l].gz` files. 30 | * [`git_tag_run`](python/dpu_utils/utils/gitlog.py) uses git to tag the current state of the code in the working directory. 31 | * [`run_and_debug`](python/dpu_utils/utils/debughelper.py) starts a debug session when an exception happens. Usually used as a wrapper around `__main__`. 32 | 33 | ##### General Machine Learning Utilities `dpu_utils.mlutils` 34 | * [`Vocabulary`](python/dpu_utils/mlutils/vocabulary.py) maps elements to unique integer ids and back. 35 | Commonly used in machine learning models that work over discrete data (e.g. 36 | words in NLP). Contains methods for converting a list of tokens into their 37 | "tensorized" form of integer ids (see the usage sketch below).
38 | * [`BpeVocabulary`](python/dpu_utils/mlutils/bpevocabulary.py) a vocabulary for machine learning models that employs BPE (via `sentencepiece`). 39 | * [`CharTensorizer`](python/dpu_utils/mlutils/chartensorizer.py) converts character sequences into tensors, commonly used 40 | in machine learning models whose input is a list of characters. 41 | 42 | ##### Code-related Utilities `dpu_utils.codeutils` 43 | * [`split_identifier_into_parts()`](python/dpu_utils/codeutils/identifiersplitting.py) splits identifiers into subtokens on CamelCase and snake_case. 44 | * [`Lattice`](python/dpu_utils/codeutils/lattice/lattice.py), [`CSharpLattice`](python/dpu_utils/codeutils/lattice/csharplattice.py) represent lattices and useful operations on lattices in Python. 45 | * [`get_language_keywords()`](python/dpu_utils/codeutils/keywords/keywordlist.py) an API to retrieve the keyword tokens for many programming languages. 46 | * [`language_candidates_from_suffix()`](python/dpu_utils/codeutils/filesuffix.py) a function to retrieve the candidate languages for a given file suffix. 47 | * [`deduplication.DuplicateDetector`](python/dpu_utils/codeutils/deduplication/deduplication.py) API to detect (near-)duplicates in codebases. 48 | See also [here](#approximate-duplicate-code-detection) for a command line tool. 49 | * [`treesitter.parser_for`](python/dpu_utils/codeutils/treesitter/parser.py) get a [Tree-sitter](https://tree-sitter.github.io/tree-sitter/) parser by language name. 50 | 51 | ##### TensorFlow 1.x Utilities `dpu_utils.tfutils` 52 | * [`get_activation`](python/dpu_utils/tfutils/activation.py) retrieve activation functions by name. 53 | * [`GradRatioLoggingOptimizer`](python/dpu_utils/tfutils/gradratiologgingoptimizer.py) a wrapper around optimizers that logs the ratios of grad norms to parameter norms. 54 | * [`TFVariableSaver`](python/dpu_utils/tfutils/tfvariablesaver.py) save TF variables in an object that can be pickled. 55 | 56 | Unsorted segment operations following TensorFlow's [`unsorted_segment_sum`](https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_sum) operations: 57 | * [`unsorted_segment_logsumexp`](python/dpu_utils/tfutils/unsortedsegmentops.py) 58 | * [`unsorted_segment_log_softmax`](python/dpu_utils/tfutils/unsortedsegmentops.py) 59 | * [`unsorted_segment_softmax`](python/dpu_utils/tfutils/unsortedsegmentops.py) 60 | 61 | ##### TensorFlow 2.x Utilities `dpu_utils.tf2utils` 62 | * [`get_activation_function_by_name`](python/dpu_utils/tf2utils/activation.py) retrieve activation functions by name. 63 | * [`gelu`](python/dpu_utils/tf2utils/activation.py) The GeLU activation function. 64 | * [`MLP`](python/dpu_utils/tf2utils/mlp.py) An MLP layer. 65 | 66 | Unsorted segment operations following TensorFlow's [`unsorted_segment_sum`](https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_sum) operations: 67 | * [`unsorted_segment_logsumexp`](python/dpu_utils/tf2utils/unsorted_segment_ops.py) 68 | * [`unsorted_segment_log_softmax`](python/dpu_utils/tf2utils/unsorted_segment_ops.py) 69 | * [`unsorted_segment_softmax`](python/dpu_utils/tf2utils/unsorted_segment_ops.py) 70 | 71 | 72 | ##### TensorFlow Models `dpu_utils.tfmodels` 73 | * [`SparseGGNN`](python/dpu_utils/tfmodels/sparsegnn.py) a sparse GGNN implementation. 74 | * [`AsyncGGNN`](python/dpu_utils/tfmodels/asyncgnn.py) an asynchronous GGNN implementation. 75 | 76 | These models have not been tested with TF 2.0.
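As a quick orientation, here is a minimal sketch of tensorizing tokens with `Vocabulary` (referenced in the `dpu_utils.mlutils` overview above). The method and parameter names used here (`create_vocabulary`, `count_threshold`, `get_id_or_unk`, `get_name_for_id`) should be checked against the class docstrings; treat this as an illustrative sketch rather than canonical usage:

```python
# Sketch only: method/parameter names assumed; consult the Vocabulary docstrings.
from dpu_utils.mlutils import Vocabulary

tokens = ["the", "cat", "sat", "on", "the", "mat"]

# Build a vocabulary over a token stream, keeping at most `max_size` entries
# and dropping tokens rarer than `count_threshold`.
vocab = Vocabulary.create_vocabulary(tokens, max_size=10000, count_threshold=1)

# "Tensorize" a token sequence into integer ids (unseen tokens map to UNK)...
token_ids = [vocab.get_id_or_unk(t) for t in tokens]

# ...and map the ids back to their string form.
print(token_ids)
print([vocab.get_name_for_id(i) for i in token_ids])
```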
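Similarly, a hedged sketch of `RichPath` from `dpu_utils.utils`: the same code is intended to work for local paths and, given an authentication info file, Azure Blob paths. The `azure://account/container/...` form and the `azure_info_path` argument are assumptions based on the class documentation; adjust to your setup:

```python
# Sketch only: the Azure path convention below is an assumption; check the
# RichPath docstring for the authoritative format.
from dpu_utils.utils import RichPath

# Local path; a blob path would instead be created as e.g.
#   RichPath.create("azure://account/container/dir", azure_info_path="azure_auth.json")
out_file = RichPath.create("/tmp").join("dpu_demo.jsonl.gz")

# The (de)serialization format is chosen from the file suffix
# (.json.gz, .jsonl.gz, .pkl.gz, ...).
out_file.save_as_compressed_file([{"x": 1}, {"x": 2}])
for record in out_file.read_by_file_suffix():
    print(record)
```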
77 | 78 | ##### PyTorch Utilities `dpu_utils.ptutils` 79 | * [`BaseComponent`](python/dpu_utils/ptutils/basecomponent.py) a wrapper abstract class around `nn.Module` that 80 | takes care of essential elements of most neural network components. 81 | * [`ComponentTrainer`](python/dpu_utils/ptutils/componenttrainer.py) a training loop for `BaseComponent`s. 82 | 83 | 84 | ### Command-line tools 85 | 86 | #### Approximate Duplicate Code Detection 87 | You can use the `deduplicationcli` command to detect duplicates in pre-processed source code, by invoking 88 | ```bash 89 | deduplicationcli DATA_PATH OUT_JSON 90 | ``` 91 | where `DATA_PATH` is a folder containing tokenized `.jsonl.gz` files and `OUT_JSON` is the target output file. 92 | For more options look at `--help`. 93 | 94 | An exact (but usually slower) version of this can be found [here](https://github.com/Microsoft/near-duplicate-code-detector) 95 | along with code to tokenize Java, C#, Python and JavaScript into the relevant formats. 96 | 97 | ### Tests 98 | 99 | #### Run the unit tests 100 | 101 | ```bash 102 | python setup.py test 103 | ``` 104 | 105 | #### Generate code coverage reports 106 | 107 | ```bash 108 | # pip install coverage 109 | coverage run --source dpu_utils/ setup.py test && \ 110 | coverage html 111 | ``` 112 | 113 | The resulting HTML file will be in `htmlcov/index.html`. 114 | 115 | ## .NET 116 | 117 | Stored in the `dotnet` subdirectory. 118 | 119 | Generic Utilities: 120 | * `Microsoft.Research.DPU.Utils.RichPath`: a convenient way of using both local paths and Azure paths in your code. 121 | 122 | Code-related Utilities: 123 | * `Microsoft.Research.DPU.CSharpSourceGraphExtraction`: infrastructure to extract Program Graphs from C# projects. 124 | 125 | # Contributing 126 | 127 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 128 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 129 | the rights to use your contribution. For details, visit https://cla.microsoft.com. 130 | 131 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide 132 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions 133 | provided by the bot. You will only need to do this once across all repos using our CLA. 134 | 135 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 136 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 137 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 138 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | jobs: 2 | 3 | - job: 'Python_Test' 4 | pool: 5 | vmImage: 'Ubuntu 20.04' 6 | strategy: 7 | matrix: 8 | Python36: 9 | python.version: '3.6' 10 | Python37: 11 | python.version: '3.7' 12 | Python38: 13 | python.version: '3.8' 14 | maxParallel: 4 15 | 16 | steps: 17 | - task: UsePythonVersion@0 18 | inputs: 19 | versionSpec: '$(python.version)' 20 | architecture: 'x64' 21 | 22 | - script: python -m pip install --upgrade pip && pip install -r requirements.txt && pip install torch==1.6.0 23 | displayName: 'Install dependencies' 24 | workingDirectory: 'python/' 25 | 26 | - task: Bash@3 27 | displayName: "Install Azurite" 28 | inputs: 29 | targetType: 'inline' 30 | script: 'sudo npm install -g azurite' 31 | 32 | - task: Bash@3 33 | displayName: "Run Azurite" 34 | inputs: 35 | targetType: 'inline' 36 | script: 'sudo azurite --silent -l /tmp --loose &' 37 | 38 | - script: | 39 | pip install pytest 40 | pytest tests --doctest-modules --junitxml=junit/test-results.xml 41 | displayName: 'pytest' 42 | workingDirectory: 'python/' 43 | 44 | - task: PublishTestResults@2 45 | inputs: 46 | testResultsFiles: 'python/**/test-results.xml' 47 | testRunTitle: 'Python $(python.version)' 48 | condition: succeededOrFailed() 49 | 50 | - job: 'Python_Publish' 51 | dependsOn: 'Python_Test' 52 | pool: 53 | vmImage: 'Ubuntu 20.04' 54 | 55 | steps: 56 | - task: UsePythonVersion@0 57 | inputs: 58 | versionSpec: '3.x' 59 | architecture: 'x64' 60 | 61 | - script: python setup.py sdist 62 | displayName: 'Build sdist' 63 | workingDirectory: 'python/' 64 | 65 | - task: PublishBuildArtifacts@1 66 | displayName: 'Publish artifact: dist' 67 | condition: and(succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/master')) 68 | inputs: 69 | pathtoPublish: './' 70 | artifactName: 'dist' 71 | 72 | - job: 'DotNet_Build' 73 | pool: 74 | vmImage: 'Ubuntu 20.04' 75 | 76 | steps: 77 | - task: UseDotNet@2 78 | displayName: 'Use .NET Core sdk' 79 | inputs: 80 | packageType: sdk 81 | version: 3.x 82 | installationPath: $(Agent.ToolsDirectory)/dotnet 83 | 84 | - script: dotnet restore 85 | workingDirectory: 'dotnet/' 86 | 87 | - task: DotNetCoreCLI@2 88 | displayName: Build 89 | inputs: 90 | command: build 91 | projects: 'dotnet/**/*.csproj' 92 | arguments: '--configuration Release' 93 | 94 | # One of these days, we should add tests: 95 | #- task: DotNetCoreCLI@2 96 | # inputs: 97 | # command: test 98 | # projects: 'dotnet/**/*Tests/*.csproj' 99 | # arguments: '--configuration $(buildConfiguration)' 100 | 101 | - job: 'DotNet_Publish' 102 | dependsOn: 'DotNet_Build' 103 | pool: 104 | vmImage: 'Ubuntu 20.04' 105 | 106 | steps: 107 | - script: dotnet pack 108 | workingDirectory: 'dotnet/' 109 | 110 | # - task: NuGetCommand@2 111 | # condition: and(succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/master')) 112 | # inputs: 113 | # command: push 114 | # nuGetFeedType: external 115 | # publishFeedCredentials: 'mabrocks-DPU.Utils-NuGet' 116 | # packagesToPush: 'dotnet/**/*.nupkg' 117 | -------------------------------------------------------------------------------- /dotnet/CSharpSourceGraphExtraction/GraphBuilders/ASTGraphBuilder.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.CodeAnalysis; 2 | using Microsoft.CodeAnalysis.CSharp; 3 | using Microsoft.CodeAnalysis.CSharp.Syntax; 
4 | using System.Linq; 5 | 6 | namespace MSRC.DPU.CSharpSourceGraphExtraction.GraphBuilders 7 | { 8 | internal class ASTGraphBuilder : CSharpSyntaxWalker 9 | { 10 | private readonly SourceGraph _graph; 11 | private SyntaxToken? _lastAddedToken; 12 | 13 | private ASTGraphBuilder(SourceGraph graph) 14 | { 15 | _graph = graph; 16 | } 17 | 18 | public static void AddASTGraph(SourceGraph sourceGraph) 19 | { 20 | new ASTGraphBuilder(sourceGraph).Visit(sourceGraph.SemanticModel.SyntaxTree.GetRoot()); 21 | } 22 | 23 | private void AddToken(SyntaxToken token) 24 | { 25 | if (_lastAddedToken.HasValue) 26 | { 27 | _graph.AddEdge(_lastAddedToken.Value, SourceGraphEdge.NextToken, token); 28 | } 29 | _lastAddedToken = token; 30 | } 31 | 32 | public override void Visit(SyntaxNode node) 33 | { 34 | foreach (var child in node.ChildNodesAndTokens()) 35 | { 36 | if (!node.DescendantNodes().Any(n => n is BaseMethodDeclarationSyntax || n is PropertyDeclarationSyntax)) 37 | { 38 | _graph.AddEdge(node, SourceGraphEdge.Child, child); 39 | } 40 | if (child.IsNode) 41 | { 42 | Visit(child.AsNode()); 43 | } 44 | else 45 | { 46 | AddToken(child.AsToken()); 47 | } 48 | } 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /dotnet/CSharpSourceGraphExtraction/GraphBuilders/DataFlowGraphBuilder.cs: -------------------------------------------------------------------------------- 1 | using System.Collections.Generic; 2 | using Microsoft.CodeAnalysis; 3 | using Microsoft.CodeAnalysis.CSharp; 4 | using MSRC.DPU.Utils; 5 | 6 | namespace MSRC.DPU.CSharpSourceGraphExtraction.GraphBuilders 7 | { 8 | internal static class DataFlowGraphBuilder 9 | { 10 | public static void AddDataFlowEdges(SourceGraph sourceGraph, SyntaxNodeOrToken tokenOfInterest, 11 | ICollection<SyntaxNodeOrToken> forbiddenNodes = null, 12 | ICollection<Edge<SyntaxNodeOrToken, SourceGraphEdge>> addedEdges = null) 13 | { 14 | var semanticModel = sourceGraph.SemanticModel; 15 | 16 | //There's nothing before the declaration, so we don't need to bother: 17 | if (sourceGraph.VariableDeclarationNodes.Contains(tokenOfInterest)) 18 | { 19 | return; 20 | } 21 | 22 | //We only ever need to visit each node once, so collect visited nodes here: 23 | var visitedNodes = new HashSet<(SyntaxNodeOrToken, bool)>(); 24 | 25 | //Start from all predecessors of the token of interest: 26 | var toVisit = new Stack<(SyntaxNodeOrToken node, bool haveFoundUse)>(); 27 | foreach (var (_, label, target) in sourceGraph.GetOutEdges(tokenOfInterest)) 28 | { 29 | if (label != SourceGraphEdge.LastUsedVariable || (forbiddenNodes?.Contains(target) ?? false)) 30 | { 31 | continue; 32 | } 33 | if (visitedNodes.Add((target, false))) 34 | { 35 | toVisit.Push((target, false)); 36 | } 37 | } 38 | 39 | var nodeOfInterest = tokenOfInterest.IsToken ? tokenOfInterest.AsToken().Parent : tokenOfInterest.AsNode(); 40 | ISymbol symbolToLookFor = nodeOfInterest != null ? semanticModel.GetSymbolInfo(nodeOfInterest).Symbol?.OriginalDefinition : null; 41 | string nodeLabelToLookFor = tokenOfInterest.ToString(); 42 | 43 | while (toVisit.Count > 0) 44 | { 45 | var (node, haveFoundUse) = toVisit.Pop(); 46 | var nodeSyntaxNode = node.IsToken ? node.AsToken().Parent : node.AsNode(); 47 | var nodeSymbol = nodeSyntaxNode != null ?
semanticModel.GetSymbolInfo(nodeSyntaxNode).Symbol?.OriginalDefinition : null; 48 | 49 | bool matches; 50 | if (symbolToLookFor == null || nodeSymbol == null) 51 | { 52 | // This may happen in cases where Roslyn doesn't have symbol info 53 | // or when one of the nodes is a dummy node (and thus doesn't belong to the SyntaxTree) 54 | matches = node.ToString().Equals(nodeLabelToLookFor); 55 | } 56 | else 57 | { 58 | matches = nodeSymbol.Equals(symbolToLookFor); 59 | } 60 | 61 | if (matches) 62 | { 63 | if (!haveFoundUse) 64 | { 65 | var lastUseEdge = new Edge<SyntaxNodeOrToken, SourceGraphEdge>(tokenOfInterest, SourceGraphEdge.LastUse, node); 66 | if (sourceGraph.AddEdge(lastUseEdge)) 67 | { 68 | addedEdges?.Add(lastUseEdge); 69 | } 70 | haveFoundUse = true; 71 | } 72 | 73 | if (sourceGraph.VariableWriteNodes.Contains(node)) 74 | { 75 | var lastWriteEdge = new Edge<SyntaxNodeOrToken, SourceGraphEdge>(tokenOfInterest, SourceGraphEdge.LastWrite, node); 76 | if (sourceGraph.AddEdge(lastWriteEdge)) 77 | { 78 | addedEdges?.Add(lastWriteEdge); 79 | } 80 | //We are done with this path -- we found a use and a write! 81 | continue; 82 | } 83 | 84 | //There's nothing before the declaration, so we don't need to bother to recurse further: 85 | if (sourceGraph.VariableDeclarationNodes.Contains(node)) 86 | { 87 | continue; 88 | } 89 | } 90 | 91 | foreach (var (_, label, target) in sourceGraph.GetOutEdges(node)) 92 | { 93 | if (label != SourceGraphEdge.LastUsedVariable || (forbiddenNodes?.Contains(target) ?? false)) 94 | { 95 | continue; 96 | } 97 | if (visitedNodes.Add((target, haveFoundUse))) { 98 | toVisit.Push((target, haveFoundUse)); 99 | } 100 | } 101 | } 102 | } 103 | 104 | /// <summary> 105 | /// Adds LastUse/LastWrite dataflow edges to SourceGraph. 106 | /// Requires LastUsedVariables in the graph. 107 | /// </summary> 108 | public static void AddDataFlowGraph(SourceGraph sourceGraph) 109 | { 110 | foreach (var tokenOfInterest in sourceGraph.VariableUseNodes) 111 | { 112 | AddDataFlowEdges(sourceGraph, tokenOfInterest); 113 | } 114 | } 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /dotnet/CSharpSourceGraphExtraction/GraphBuilders/GuardedByGraphBuilder.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.CodeAnalysis; 2 | using Microsoft.CodeAnalysis.CSharp; 3 | using Microsoft.CodeAnalysis.CSharp.Syntax; 4 | using System.Collections.Generic; 5 | using System.Linq; 6 | 7 | namespace MSRC.DPU.CSharpSourceGraphExtraction.GraphBuilders 8 | { 9 | internal class BlockGuardInformation 10 | { 11 | /// <summary> 12 | /// Guard information for enclosing block, potentially null if there is none. 13 | /// </summary> 14 | public readonly BlockGuardInformation EnclosingBlockInformation; 15 | 16 | private readonly SourceGraph _graph; 17 | 18 | private readonly List<(SyntaxNode guardNode, ISet<string> usedVariables)> _validatedGuards 19 | = new List<(SyntaxNode guardNode, ISet<string> usedVariables)>(); 20 | 21 | private readonly List<(SyntaxNode guardNode, ISet<string> usedVariables)> _invalidatedGuards 22 | = new List<(SyntaxNode guardNode, ISet<string> usedVariables)>(); 23 | 24 | /// <summary> 25 | /// SyntaxNodes of guards that were checked to hold for this block. 26 | /// </summary> 27 | public IEnumerable<(SyntaxNode guardNode, ISet<string> usedVariables)> ValidatedGuards => _validatedGuards; 28 | 29 | /// <summary> 30 | /// SyntaxNodes of guards that were checked to not hold in this block.
31 | /// 32 | public IEnumerable<(SyntaxNode guardNode, ISet usedVariables)> InvalidatedGuards => _invalidatedGuards; 33 | 34 | public BlockGuardInformation(SourceGraph graph, BlockGuardInformation enclosingBlockInfo = null, SyntaxNode validatedGuard = null, SyntaxNode invalidatedGuard = null) 35 | { 36 | _graph = graph; 37 | EnclosingBlockInformation = enclosingBlockInfo; 38 | 39 | if (validatedGuard != null) 40 | { 41 | RecordValidatedGuard(validatedGuard); 42 | } 43 | 44 | if (invalidatedGuard != null) 45 | { 46 | RecordInvalidatedGuard(invalidatedGuard); 47 | } 48 | } 49 | 50 | private ISet GetUsedVariables(SyntaxNode node) 51 | => new HashSet(node.DescendantTokens().Where(tok => _graph.VariableUseNodes.Contains(tok)).Select(tok => tok.Text)); 52 | 53 | public void RecordValidatedGuard(SyntaxNode guardNode) 54 | => _validatedGuards.Add((guardNode, usedVariables: GetUsedVariables(guardNode))); 55 | 56 | public void RecordValidatedGuard((SyntaxNode guardNode, ISet usedVariables) validatedGuardInfo) 57 | => _validatedGuards.Add(validatedGuardInfo); 58 | 59 | public void RecordInvalidatedGuard(SyntaxNode guardNode) 60 | => _invalidatedGuards.Add((guardNode, usedVariables: GetUsedVariables(guardNode))); 61 | 62 | public void RecordInvalidatedGuard((SyntaxNode guardNode, ISet usedVariables) invalidatedGuardInfo) 63 | => _invalidatedGuards.Add(invalidatedGuardInfo); 64 | } 65 | 66 | internal class GuardedByGraphBuilder : CSharpSyntaxWalker 67 | { 68 | private readonly SourceGraph _graph; 69 | 70 | private BlockGuardInformation _blockGuardInformation; 71 | 72 | private GuardedByGraphBuilder(SourceGraph graph) 73 | { 74 | _graph = graph; 75 | } 76 | 77 | public static void AddGuardedByGraph(SourceGraph graph) 78 | { 79 | new GuardedByGraphBuilder(graph).Visit(graph.SemanticModel.SyntaxTree.GetRoot()); 80 | } 81 | 82 | private void VisitVariable(SyntaxToken identifierToken) 83 | { 84 | var curGuardInformation = _blockGuardInformation; 85 | while (curGuardInformation != null) 86 | { 87 | foreach (var validatedGuard in curGuardInformation.ValidatedGuards) 88 | { 89 | _graph.AddEdge(identifierToken, SourceGraphEdge.GuardedBy, validatedGuard.guardNode); 90 | } 91 | 92 | foreach (var invalidatedGuard in curGuardInformation.InvalidatedGuards) 93 | { 94 | _graph.AddEdge(identifierToken, SourceGraphEdge.GuardedByNegation, invalidatedGuard.guardNode); 95 | } 96 | 97 | curGuardInformation = curGuardInformation.EnclosingBlockInformation; 98 | } 99 | } 100 | 101 | public override void VisitIdentifierName(IdentifierNameSyntax node) 102 | { 103 | if (_graph.VariableUseNodes.Contains(node.Identifier)) 104 | { 105 | VisitVariable(node.Identifier); 106 | } 107 | } 108 | 109 | #region Method (and related) declarations 110 | private void HandleBaseMethodDeclaration(BaseMethodDeclarationSyntax node) 111 | { 112 | // Roughly, set up a new block context, descent and then pop that context again. 
113 | _blockGuardInformation = new BlockGuardInformation(_graph); 114 | Visit(node.Body); 115 | _blockGuardInformation = null; 116 | } 117 | 118 | public override void VisitConversionOperatorDeclaration(ConversionOperatorDeclarationSyntax node) 119 | => HandleBaseMethodDeclaration(node); 120 | 121 | public override void VisitConstructorDeclaration(ConstructorDeclarationSyntax node) 122 | => HandleBaseMethodDeclaration(node); 123 | 124 | public override void VisitMethodDeclaration(MethodDeclarationSyntax node) 125 | => HandleBaseMethodDeclaration(node); 126 | 127 | public override void VisitOperatorDeclaration(OperatorDeclarationSyntax node) 128 | => HandleBaseMethodDeclaration(node); 129 | 130 | public override void VisitDestructorDeclaration(DestructorDeclarationSyntax node) 131 | => HandleBaseMethodDeclaration(node); 132 | #endregion 133 | 134 | #region Statements that create new conditional blocks (if, loops, ...) 135 | private void HandleConditionalBlock(IEnumerable<SyntaxNode> bodyNodes, SyntaxNode validatedGuard = null, SyntaxNode invalidatedGuard = null) 136 | { 137 | // Roughly, set up a new block context with condition, descend and then pop that context again. 138 | _blockGuardInformation = new BlockGuardInformation(_graph, enclosingBlockInfo: _blockGuardInformation, 139 | validatedGuard: validatedGuard, invalidatedGuard: invalidatedGuard); 140 | foreach (var bodyNode in bodyNodes) 141 | { 142 | Visit(bodyNode); 143 | } 144 | _blockGuardInformation = _blockGuardInformation.EnclosingBlockInformation; 145 | } 146 | 147 | public override void VisitIfStatement(IfStatementSyntax node) 148 | { 149 | Visit(node.Condition); 150 | HandleConditionalBlock(new[] { node.Statement }, validatedGuard: node.Condition); 151 | HandleConditionalBlock(new[] { node.Else }, invalidatedGuard: node.Condition); 152 | } 153 | 154 | public override void VisitConditionalExpression(ConditionalExpressionSyntax node) 155 | { 156 | Visit(node.Condition); 157 | HandleConditionalBlock(new[] { node.WhenTrue }, validatedGuard: node.Condition); 158 | HandleConditionalBlock(new[] { node.WhenFalse }, invalidatedGuard: node.Condition); 159 | } 160 | 161 | public override void VisitSwitchStatement(SwitchStatementSyntax node) 162 | { 163 | foreach (var switchSection in node.Sections) 164 | { 165 | //TODO: Can we do something with the labels?
166 | HandleConditionalBlock(switchSection.Statements); 167 | } 168 | } 169 | 170 | public override void VisitForStatement(ForStatementSyntax node) 171 | { 172 | //Declaration and condition are unguarded, but rest is guarded: 173 | Visit(node.Declaration); 174 | Visit(node.Condition); 175 | HandleConditionalBlock(new SyntaxNode[] { node.Statement }.Concat(node.Incrementors), validatedGuard: node.Condition); 176 | } 177 | 178 | public override void VisitForEachStatement(ForEachStatementSyntax node) 179 | { 180 | Visit(node.Expression); 181 | if (_graph.VariableUseNodes.Contains(node.Identifier)) 182 | { 183 | // This is a special case required because ForEachStatements require a raw identifier 184 | // token (but matches what HandleConditionalBlock does): 185 | _blockGuardInformation = new BlockGuardInformation(_graph, enclosingBlockInfo: _blockGuardInformation, 186 | validatedGuard: node.Expression, invalidatedGuard: null); 187 | VisitVariable(node.Identifier); 188 | _blockGuardInformation = _blockGuardInformation.EnclosingBlockInformation; 189 | } 190 | HandleConditionalBlock(new SyntaxNode[] { node.Statement }, validatedGuard: node.Expression); 191 | } 192 | 193 | public override void VisitForEachVariableStatement(ForEachVariableStatementSyntax node) 194 | { 195 | Visit(node.Expression); 196 | HandleConditionalBlock(new SyntaxNode[] { node.Variable, node.Statement }, validatedGuard: node.Expression); 197 | } 198 | 199 | public override void VisitWhileStatement(WhileStatementSyntax node) 200 | { 201 | Visit(node.Condition); 202 | HandleConditionalBlock(new SyntaxNode[] { node.Statement }, validatedGuard: node.Condition); 203 | } 204 | 205 | public override void VisitDoStatement(DoStatementSyntax node) 206 | { 207 | //TODO: Can we do something with the condition, even if it's not checked in the first iteration? 208 | HandleConditionalBlock(new SyntaxNode[] { node.Statement }); 209 | Visit(node.Condition); 210 | } 211 | #endregion 212 | 213 | #region Statements that modify control flow in conditional blocks (break, continue, return) 214 | private void HandleControlFlowBreak() 215 | { 216 | /* We know that if execution continues in this method after this return, 217 | * that's because the condition for this return did not hold. Thus, we can 218 | * consider the valid guards of the current block as invalidated for its 219 | * enclosing block in the subsequent analysis (and similarly for invalidated 220 | * guards).
221 | */ 222 | var enclosingBlockInformation = _blockGuardInformation?.EnclosingBlockInformation; 223 | if (enclosingBlockInformation != null) 224 | { 225 | foreach (var validatedGuard in _blockGuardInformation.ValidatedGuards) 226 | { 227 | enclosingBlockInformation.RecordInvalidatedGuard(validatedGuard); 228 | } 229 | foreach (var invalidatedGuard in _blockGuardInformation.InvalidatedGuards) 230 | { 231 | enclosingBlockInformation.RecordValidatedGuard(invalidatedGuard); 232 | } 233 | } 234 | } 235 | 236 | public override void VisitReturnStatement(ReturnStatementSyntax node) 237 | { 238 | Visit(node.Expression); 239 | HandleControlFlowBreak(); 240 | } 241 | 242 | public override void VisitBreakStatement(BreakStatementSyntax node) => HandleControlFlowBreak(); 243 | public override void VisitContinueStatement(ContinueStatementSyntax node) => HandleControlFlowBreak(); 244 | #endregion 245 | } 246 | } 247 | -------------------------------------------------------------------------------- /dotnet/CSharpSourceGraphExtraction/GraphBuilders/ReturnToGraphBuilder.cs: -------------------------------------------------------------------------------- 1 | using System.Collections.Generic; 2 | using System.Diagnostics; 3 | using Microsoft.CodeAnalysis; 4 | using Microsoft.CodeAnalysis.CSharp; 5 | using Microsoft.CodeAnalysis.CSharp.Syntax; 6 | 7 | namespace MSRC.DPU.CSharpSourceGraphExtraction.GraphBuilders 8 | { 9 | internal class ReturnToGraphBuilder : CSharpSyntaxWalker 10 | { 11 | private readonly SourceGraph _graph; 12 | private readonly Stack<SyntaxNodeOrToken> _returningPoint; 13 | 14 | private ReturnToGraphBuilder(SourceGraph graph) 15 | { 16 | _graph = graph; 17 | _returningPoint = new Stack<SyntaxNodeOrToken>(); 18 | } 19 | 20 | public static void AddReturnToGraph(SourceGraph graph) 21 | { 22 | new ReturnToGraphBuilder(graph).Visit(graph.SemanticModel.SyntaxTree.GetRoot()); 23 | } 24 | 25 | public override void VisitMethodDeclaration(MethodDeclarationSyntax node) 26 | { 27 | if (node.Body == null && node.ExpressionBody == null) 28 | { 29 | return; // Don't bother with abstract methods 30 | } 31 | 32 | _returningPoint.Push(node.Identifier); 33 | base.VisitMethodDeclaration(node); 34 | _returningPoint.Pop(); 35 | } 36 | 37 | public override void VisitConstructorDeclaration(ConstructorDeclarationSyntax node) 38 | { 39 | _returningPoint.Push(node.Identifier); 40 | base.VisitConstructorDeclaration(node); 41 | _returningPoint.Pop(); 42 | } 43 | 44 | public override void VisitDestructorDeclaration(DestructorDeclarationSyntax node) 45 | { 46 | _returningPoint.Push(node.Identifier); 47 | base.VisitDestructorDeclaration(node); 48 | _returningPoint.Pop(); 49 | } 50 | 51 | public override void VisitConversionOperatorDeclaration(ConversionOperatorDeclarationSyntax node) 52 | { 53 | _returningPoint.Push(node); 54 | base.VisitConversionOperatorDeclaration(node); 55 | _returningPoint.Pop(); 56 | } 57 | 58 | public override void VisitOperatorDeclaration(OperatorDeclarationSyntax node) 59 | { 60 | _returningPoint.Push(node.OperatorToken); 61 | base.VisitOperatorDeclaration(node); 62 | _returningPoint.Pop(); 63 | } 64 | 65 | public override void VisitPropertyDeclaration(PropertyDeclarationSyntax node) 66 | { 67 | _returningPoint.Push(node.Identifier); 68 | base.VisitPropertyDeclaration(node); 69 | _returningPoint.Pop(); 70 | } 71 | 72 | public override void VisitIndexerDeclaration(IndexerDeclarationSyntax node) 73 | { 74 | _returningPoint.Push(node); 75 | base.VisitIndexerDeclaration(node); 76 | _returningPoint.Pop(); 77 | } 78 | 79 | public override
void VisitEventDeclaration(EventDeclarationSyntax node) 80 | { 81 | _returningPoint.Push(node.Identifier); 82 | base.VisitEventDeclaration(node); 83 | _returningPoint.Pop(); 84 | } 85 | 86 | public override void VisitSimpleLambdaExpression(SimpleLambdaExpressionSyntax node) 87 | { 88 | _returningPoint.Push(node); 89 | base.VisitSimpleLambdaExpression(node); 90 | _returningPoint.Pop(); 91 | } 92 | 93 | public override void VisitParenthesizedLambdaExpression(ParenthesizedLambdaExpressionSyntax node) 94 | { 95 | _returningPoint.Push(node); 96 | base.VisitParenthesizedLambdaExpression(node); 97 | _returningPoint.Pop(); 98 | } 99 | 100 | public override void VisitLocalFunctionStatement(LocalFunctionStatementSyntax node) 101 | { 102 | _returningPoint.Push(node.Identifier); 103 | base.VisitLocalFunctionStatement(node); 104 | _returningPoint.Pop(); 105 | } 106 | 107 | public override void VisitAnonymousMethodExpression(AnonymousMethodExpressionSyntax node) 108 | { 109 | _returningPoint.Push(node); 110 | base.VisitAnonymousMethodExpression(node); 111 | _returningPoint.Pop(); 112 | } 113 | 114 | public override void VisitReturnStatement(ReturnStatementSyntax node) 115 | { 116 | Debug.Assert(_returningPoint.Count > 0); 117 | if (node.Expression != null) 118 | { 119 | _graph.AddEdge(node.Expression, SourceGraphEdge.ReturnsTo, _returningPoint.Peek()); 120 | } 121 | base.VisitReturnStatement(node); 122 | } 123 | 124 | public override void VisitYieldStatement(YieldStatementSyntax node) 125 | { 126 | bool isReturnStatement = node.IsKind(SyntaxKind.YieldReturnStatement); 127 | if (isReturnStatement && node.Expression != null) 128 | { 129 | _graph.AddEdge(node.Expression, SourceGraphEdge.ReturnsTo, _returningPoint.Peek()); 130 | } 131 | base.VisitYieldStatement(node); 132 | } 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /dotnet/CSharpSourceGraphExtraction/MSRC.DPU.CSharpSourceGraphExtraction.csproj: -------------------------------------------------------------------------------- 1 | <Project Sdk="Microsoft.NET.Sdk"> 2 | <PropertyGroup> 3 | <TargetFramework>netstandard2.0</TargetFramework> 4 | <LangVersion>7.2</LangVersion> 5 | <Platforms>AnyCPU;x64</Platforms> 6 | </PropertyGroup> 7 | <PropertyGroup> 8 | <PackageId>MSRC.DPU.CSharpSourceGraphExtraction</PackageId> 9 | <Description>Source Graph extraction from C# programs</Description> 10 | <PackageLicenseExpression>MIT</PackageLicenseExpression> 11 | <VersionSuffix>local</VersionSuffix> 12 | <Version>1.0.2-$(BUILD_BUILDNUMBER)</Version> 13 | <Authors>Miltos Allamanis, Marc Brockschmidt</Authors> 14 | <Company>Microsoft</Company> 15 | <PackageProjectUrl>https://github.com/Microsoft/dpu-utils/</PackageProjectUrl> 16 | <Copyright>Copyright (c) 2018- Microsoft Corporation</Copyright> 17 | </PropertyGroup> 18 | 19 | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> 20 | <OutputPath>bin\Debug</OutputPath> 21 | </PropertyGroup> 22 | 23 | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> 24 | <OutputPath>bin\Debug</OutputPath> 25 | </PropertyGroup> 26 | 27 | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> 28 | <OutputPath>bin\Release</OutputPath> 29 | </PropertyGroup> 30 | 31 | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> 32 | <OutputPath>bin\Release\</OutputPath> 33 | </PropertyGroup> 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | </Project> -------------------------------------------------------------------------------- /dotnet/CSharpSourceGraphExtraction/MethodUseInformationCollector.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Diagnostics; 4 | using Microsoft.CodeAnalysis; 5 | using Microsoft.CodeAnalysis.CSharp; 6 | using Microsoft.CodeAnalysis.CSharp.Syntax; 7 | using MSRC.DPU.CSharpSourceGraphExtraction.Utils; 8 | 9 | namespace MSRC.DPU.CSharpSourceGraphExtraction 10 | { 11 | internal class MethodUseInformationCollector : CSharpSyntaxWalker 12 | { 13 | private readonly SourceGraph _graph; 14 | 15 | private MethodUseInformationCollector(SourceGraph graph) 16 | { 17 | _graph = graph; 18 | } 19 | 20 | public static void AddMethodUseInformation(SourceGraph graph) 21 | { 22 | new MethodUseInformationCollector(graph).Visit(graph.SemanticModel.SyntaxTree.GetRoot()); 23 | } 24 |
25 | private void RecordDeclaration(SyntaxNode declarationNode, ParameterListSyntax parameterList) 26 | { 27 | var methodSymbol = _graph.SemanticModel.GetDeclaredSymbol(declarationNode) as IMethodSymbol; 28 | if (methodSymbol == null || methodSymbol.IsAbstract) 29 | { 30 | return; // Skip abstract methods. 31 | } 32 | var declarationInfo = new MethodDeclarationInformation( 33 | methodSymbol, 34 | declarationNode, 35 | ParamSymbolsToTokens(parameterList, methodSymbol)); 36 | _graph.MethodDeclarationSites.Add(declarationInfo); 37 | } 38 | 39 | private void RecordInvocation(IMethodSymbol invokedMethodSymbol, SyntaxNode methodInvocationNode, ArgumentListSyntax methodArguments) 40 | { 41 | var argumentNodeToMethodParameter = new Dictionary<ExpressionSyntax, (string name, ITypeSymbol type)>(); 42 | if (methodArguments != null) 43 | { 44 | foreach (var arg in methodArguments.Arguments) 45 | { 46 | var paramSymbol = MethodUtils.DetermineParameter(methodArguments, arg, invokedMethodSymbol); 47 | Debug.Assert(arg.Expression != null); 48 | Debug.Assert(paramSymbol != null); 49 | argumentNodeToMethodParameter.Add(arg.Expression, (paramSymbol.Name, paramSymbol.Type)); 50 | } 51 | } 52 | 53 | var invocationInfo = new MethodInvocationInformation(invokedMethodSymbol, methodInvocationNode, argumentNodeToMethodParameter); 54 | _graph.MethodInvocationSites.Add(invocationInfo); 55 | } 56 | 57 | private static Dictionary<IParameterSymbol, SyntaxToken> ParamSymbolsToTokens(ParameterListSyntax parameterList, IMethodSymbol methodSymbol) 58 | { 59 | Debug.Assert(methodSymbol != null); 60 | var paramSymbolsToIdentifiers = new Dictionary<IParameterSymbol, SyntaxToken>(); 61 | if (parameterList == null || parameterList.Parameters == null || methodSymbol.Parameters.Length != parameterList.Parameters.Count) 62 | { 63 | return paramSymbolsToIdentifiers; 64 | } 65 | for (int i = 0; i < methodSymbol.Parameters.Length; i++) 66 | { 67 | paramSymbolsToIdentifiers.Add(methodSymbol.Parameters[i], parameterList.Parameters[i].Identifier); 68 | } 69 | return paramSymbolsToIdentifiers; 70 | } 71 | 72 | public override void VisitInvocationExpression(InvocationExpressionSyntax node) 73 | { 74 | var invocationSymbol = _graph.SemanticModel.GetSymbolInfo(node).Symbol as IMethodSymbol; 75 | 76 | if (invocationSymbol == null) return; 77 | string methodName = invocationSymbol.Name; 78 | if (invocationSymbol.Name.IndexOf('<') != -1) 79 | { 80 | methodName = methodName.Substring(0, methodName.IndexOf('<')); 81 | } 82 | string invocationExpression; 83 | if (node.Expression is MemberAccessExpressionSyntax memberAccess) 84 | { 85 | invocationExpression = memberAccess.Name.ToString(); 86 | } else 87 | { 88 | invocationExpression = node.Expression.ToString(); 89 | } 90 | if (invocationExpression.IndexOf('<') != -1) 91 | { 92 | invocationExpression = invocationExpression.Substring(0, invocationExpression.IndexOf('<')); 93 | } 94 | 95 | if (!invocationExpression.EndsWith(methodName)) 96 | { 97 | // Heuristic: this may happen when an implicit conversion exists e.g.
"string x = SomeObjectRetType()" 98 | // or an invocation of an anonymous function, such as a lambda, when the method name is "Invoke" 99 | if (methodName != "Invoke") 100 | { 101 | Console.WriteLine($"Rejecting Invocation because expression name and symbol do not match: {methodName} -> {node.Expression}"); 102 | } 103 | return; 104 | } 105 | else if (node.ArgumentList.ToString().Contains("__arglist")) 106 | { 107 | Console.WriteLine($"Rejecting Invocation because it contains an __arglist: {node}"); 108 | return; 109 | } 110 | 111 | RecordInvocation(invocationSymbol, node, node.ArgumentList); 112 | base.VisitInvocationExpression(node); 113 | } 114 | 115 | public override void VisitObjectCreationExpression(ObjectCreationExpressionSyntax node) 116 | { 117 | var constructorSymbol = _graph.SemanticModel.GetSymbolInfo(node).Symbol as IMethodSymbol; 118 | if (constructorSymbol == null) return; 119 | 120 | RecordInvocation(constructorSymbol, node, node.ArgumentList); 121 | base.VisitObjectCreationExpression(node); 122 | } 123 | 124 | public override void VisitMethodDeclaration(MethodDeclarationSyntax node) 125 | { 126 | if (node.Body == null && node.ExpressionBody == null) 127 | { 128 | return; // Method must be abstract 129 | } 130 | 131 | RecordDeclaration(node, node.ParameterList); 132 | base.VisitMethodDeclaration(node); 133 | } 134 | 135 | public override void VisitLocalFunctionStatement(LocalFunctionStatementSyntax node) 136 | { 137 | RecordDeclaration(node, node.ParameterList); 138 | base.VisitLocalFunctionStatement(node); 139 | } 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /dotnet/CSharpSourceGraphExtraction/Utils/IntVocabulary.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using MSRC.DPU.Utils; 3 | 4 | namespace MSRC.DPU.CSharpSourceGraphExtraction.Utils 5 | { 6 | public class IntVocabulary where T : class 7 | { 8 | private readonly BidirectionalMap _dictionary = new BidirectionalMap(); 9 | private int _nextId = 0; 10 | 11 | public int Count => _dictionary.Count; 12 | 13 | public int Get(T obj, bool addIfNotPresent=false) 14 | { 15 | if (!_dictionary.TryGetKey(obj, out int key)) 16 | { 17 | if (!addIfNotPresent) 18 | { 19 | throw new Exception("Object not in vocabulary"); 20 | } 21 | key = _nextId; 22 | _dictionary.Add(key, obj); 23 | _nextId++; 24 | } 25 | return key; 26 | } 27 | 28 | public bool Contains(T obj) 29 | { 30 | return _dictionary.Contains(obj); 31 | } 32 | 33 | public T Get(int objId) 34 | { 35 | return _dictionary.GetValue(objId); 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /dotnet/CSharpSourceGraphExtraction/Utils/MethodUtils.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Linq; 4 | using System.Xml; 5 | using Microsoft.CodeAnalysis; 6 | using Microsoft.CodeAnalysis.CSharp.Syntax; 7 | 8 | namespace MSRC.DPU.CSharpSourceGraphExtraction.Utils 9 | { 10 | public static class MethodUtils 11 | { 12 | public static IEnumerable GetAllMethodDeclarations(SyntaxNode root) 13 | { 14 | return root.DescendantNodes().OfType(); 15 | } 16 | 17 | public static (string Summary, string Returns, Dictionary ParameterComments) 18 | GetDocumentationComment(IMethodSymbol methodSymbol, bool recurseToParents=true) 19 | { 20 | var comment = methodSymbol.GetDocumentationCommentXml().Trim(); 21 | 22 | if 
(string.IsNullOrWhiteSpace(comment) && recurseToParents) 23 | { 24 | foreach (var parentMethod in AllImplementedMethods(methodSymbol)) 25 | { 26 | comment = parentMethod.GetDocumentationCommentXml().Trim(); 27 | if (!string.IsNullOrWhiteSpace(comment)) 28 | { 29 | break; 30 | } 31 | } 32 | } 33 | 34 | if (string.IsNullOrWhiteSpace(comment)) 35 | { 36 | return ("", "", null); 37 | } 38 | 39 | if (!comment.StartsWith("<member")) 40 | { 41 | comment = "<member>" + comment + "</member>"; 42 | } 43 | var xmlDoc = new XmlDocument(); 44 | try 45 | { 46 | xmlDoc.LoadXml(comment); 47 | } catch (Exception) 48 | { 49 | return ("", "", null); 50 | } 51 | 52 | if (xmlDoc.SelectSingleNode("member") == null) 53 | { 54 | return ("", "", null); 55 | } 56 | 57 | var memberXmlNode = xmlDoc.SelectSingleNode("member"); 58 | 59 | string summary = ""; 60 | if (memberXmlNode.SelectSingleNode("summary") != null) { 61 | summary = xmlDoc.SelectSingleNode("member").SelectSingleNode("summary").InnerXml.Trim(); 62 | } 63 | 64 | var parameterComments = new Dictionary<IParameterSymbol, string>(); 65 | var paramNamesToSymbols = methodSymbol.Parameters.ToDictionary(s => s.Name, s => s); 66 | 67 | foreach(var paramXmlNode in memberXmlNode.SelectNodes("param")) 68 | { 69 | var paramName = ((XmlNode)paramXmlNode).Attributes["name"].InnerText; 70 | if (paramNamesToSymbols.ContainsKey(paramName)) 71 | { 72 | parameterComments.Add(paramNamesToSymbols[paramName], ((XmlNode)paramXmlNode).InnerXml.Trim()); 73 | } 74 | } 75 | 76 | string returnVal = ""; 77 | if (memberXmlNode.SelectSingleNode("returns") != null) 78 | { 79 | returnVal = xmlDoc.SelectSingleNode("member").SelectSingleNode("returns").InnerXml.Trim(); 80 | } 81 | 82 | return (summary, returnVal, parameterComments); 83 | } 84 | 85 | public static string MethodFullyQualifiedName(IMethodSymbol methodSymbol) 86 | => methodSymbol.OriginalDefinition.ToDisplayString(); // Using the OriginalDefinition avoids instantiated type variables. 87 | 88 | public static IEnumerable<IMethodSymbol> AllImplementedMethods(IMethodSymbol methodSymbol) 89 | { 90 | var seenMethods = new HashSet<IMethodSymbol>(); 91 | if (methodSymbol.OverriddenMethod != null) 92 | { 93 | yield return methodSymbol.OverriddenMethod; 94 | foreach (var m in AllImplementedMethods(methodSymbol.OverriddenMethod)) 95 | { 96 | yield return m; 97 | seenMethods.Add(m); 98 | } 99 | } 100 | 101 | foreach(var implementedMethod in methodSymbol.ContainingType.AllInterfaces 102 | .SelectMany(iface => iface.GetMembers().OfType<IMethodSymbol>()) 103 | .Where(m => methodSymbol.Equals(methodSymbol.ContainingType.FindImplementationForInterfaceMember(m)))) 104 | { 105 | if (seenMethods.Add(implementedMethod)) 106 | { 107 | yield return implementedMethod; 108 | foreach(var m in AllImplementedMethods(implementedMethod)) 109 | { 110 | yield return m; 111 | seenMethods.Add(m); 112 | } 113 | } 114 | } 115 | } 116 | 117 | /// <summary> 118 | /// Copied from Roslyn source code.
Determines the parameter for a given argument 119 | /// </summary> 120 | /// <param name="argumentList"></param> 121 | /// <param name="argument"></param> 122 | /// <param name="symbol"></param> 123 | /// <returns></returns> 124 | public static IParameterSymbol DetermineParameter(BaseArgumentListSyntax argumentList, ArgumentSyntax argument, IMethodSymbol symbol) 125 | { 126 | var parameters = symbol.Parameters; 127 | 128 | // Handle named argument 129 | if (argument.NameColon != null && !argument.NameColon.IsMissing) 130 | { 131 | var name = argument.NameColon.Name.Identifier.ValueText; 132 | return parameters.FirstOrDefault(p => p.Name == name); 133 | } 134 | 135 | // Handle positional argument 136 | var index = argumentList.Arguments.IndexOf(argument); 137 | if (index < 0) 138 | { 139 | return null; 140 | } 141 | 142 | if (index < parameters.Length) 143 | { 144 | return parameters[index]; 145 | } 146 | 147 | // Handle Params 148 | var lastParameter = parameters.LastOrDefault(); 149 | if (lastParameter == null) 150 | { 151 | return null; 152 | } 153 | 154 | if (lastParameter.IsParams) 155 | { 156 | return lastParameter; 157 | } 158 | 159 | return null; 160 | } 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /dotnet/CSharpSourceGraphExtraction/Utils/RoslynUtils.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.CodeAnalysis; 2 | using Microsoft.CodeAnalysis.CSharp; 3 | using Microsoft.CodeAnalysis.CSharp.Syntax; 4 | using System.Collections.Generic; 5 | using System.Linq; 6 | 7 | namespace MSRC.DPU.CSharpSourceGraphExtraction.Utils 8 | { 9 | public static class RoslynUtils 10 | { 11 | /// <summary> 12 | /// Get the symbol of a given node. 13 | /// </summary> 14 | /// <param name="node"></param> 15 | /// <param name="semanticModel"></param> 16 | /// <returns></returns> 17 | public static ISymbol GetReferenceSymbol(SyntaxNode node, SemanticModel semanticModel) 18 | { 19 | ISymbol identifierSymbol = semanticModel.GetSymbolInfo(node).Symbol; 20 | if (identifierSymbol == null) 21 | { 22 | identifierSymbol = semanticModel.GetDeclaredSymbol(node); 23 | } 24 | return identifierSymbol; 25 | } 26 | 27 | public static bool IsVariableLikeSymbol(ISymbol symbol) => symbol is ILocalSymbol || symbol is IParameterSymbol || symbol is IPropertySymbol || symbol is IFieldSymbol; 28 | 29 | public static bool IsVariableLike(IdentifierNameSyntax node, SemanticModel semanticModel, out ISymbol nodeSymbol) 30 | { 31 | nodeSymbol = semanticModel.GetSymbolInfo(node).Symbol; 32 | if (nodeSymbol is IFieldSymbol || nodeSymbol is IPropertySymbol) 33 | { 34 | // We need to check if the field/property is in an LHS of an ObjectInitializerExpression 35 | return !(node.Parent is AssignmentExpressionSyntax assignmentExpression 36 | && assignmentExpression.Parent is InitializerExpressionSyntax) 37 | || !assignmentExpression.Left.Equals(node); 38 | } 39 | return nodeSymbol is ILocalSymbol || nodeSymbol is IParameterSymbol || 40 | nodeSymbol is IEventSymbol; 41 | } 42 | 43 | public static IEnumerable<SyntaxToken> GetAllVariableSymbolsInSyntaxTree(SyntaxNode node, SemanticModel semanticModel) 44 | { 45 | foreach (var descendant in node.DescendantNodesAndSelf()) 46 | { 47 | var idNode = descendant as IdentifierNameSyntax; 48 | if (idNode == null) continue; 49 | if (IsVariableLike(idNode, semanticModel, out var _)) 50 | { 51 | yield return idNode.Identifier; 52 | } 53 | } 54 | } 55 | 56 | /// <summary> 57 | /// Find all the symbols used in a given tree.
58 | /// 59 | private class VariableSymbolFinder : CSharpSyntaxWalker 60 | { 61 | readonly HashSet relevantSymbols; 62 | readonly SemanticModel semanticModel; 63 | 64 | public VariableSymbolFinder(HashSet relevantSymbols, SemanticModel semanticModel) 65 | { 66 | this.relevantSymbols = relevantSymbols; 67 | this.semanticModel = semanticModel; 68 | } 69 | 70 | public override void Visit(SyntaxNode node) 71 | { 72 | ISymbol symbol = GetReferenceSymbol(node, semanticModel); 73 | if (symbol != null 74 | && !(symbol is IMethodSymbol) && !(symbol is INamespaceOrTypeSymbol) 75 | && !(symbol is IPreprocessingSymbol) && !(symbol is ITypeSymbol) 76 | && !(symbol is ILabelSymbol)) 77 | { 78 | if (symbol.OriginalDefinition != null && symbol.Locations.Length > 0 && symbol.Locations.First().SourceTree == node.SyntaxTree) 79 | { 80 | relevantSymbols.Add(symbol.OriginalDefinition); 81 | } 82 | } 83 | base.Visit(node); 84 | } 85 | } 86 | 87 | private class MethodSymbolFinder : CSharpSyntaxWalker 88 | { 89 | readonly HashSet relevantSymbols; 90 | readonly SemanticModel semanticModel; 91 | 92 | public MethodSymbolFinder(HashSet relevantSymbols, SemanticModel semanticModel) 93 | { 94 | this.relevantSymbols = relevantSymbols; 95 | this.semanticModel = semanticModel; 96 | } 97 | 98 | public override void Visit(SyntaxNode node) 99 | { 100 | ISymbol symbol = GetReferenceSymbol(node, semanticModel); 101 | if (symbol != null && (symbol is IMethodSymbol)) 102 | { 103 | var methodSymbol = symbol as IMethodSymbol; 104 | relevantSymbols.Add(methodSymbol.OriginalDefinition); 105 | } 106 | base.Visit(node); 107 | } 108 | } 109 | 110 | private class MethodDeclarationFinder: CSharpSyntaxWalker 111 | { 112 | public readonly List Methods = new List(); 113 | public override void VisitMethodDeclaration(MethodDeclarationSyntax node) 114 | { 115 | Methods.Add(node); 116 | } 117 | } 118 | 119 | public static IEnumerable GetMethodDeclarationsInNode(SyntaxNode node) 120 | { 121 | var df = new MethodDeclarationFinder(); 122 | df.Visit(node); 123 | return df.Methods; 124 | } 125 | 126 | public static InvocationExpressionSyntax GetInvocation(SyntaxToken token) 127 | { 128 | var node = token.Parent; 129 | while (node != null && !node.IsKind(SyntaxKind.InvocationExpression)) 130 | { 131 | node = node.Parent; 132 | } 133 | if (node == null) 134 | { 135 | return null; 136 | } 137 | return node as InvocationExpressionSyntax; 138 | } 139 | 140 | /// 141 | /// Get all the symbols in a tree. 
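/// A minimal usage sketch (assumes an existing SemanticModel and the root SyntaxNode of the same tree; the names semanticModel and root are illustrative): ISet<ISymbol> used = RoslynUtils.GetUsedVariableSymbols(semanticModel, root); foreach (var s in used) Console.WriteLine(s.ToDisplayString());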
142 | /// 143 | /// 144 | /// 145 | /// 146 | public static ISet GetUsedVariableSymbols(SemanticModel semanticModel, SyntaxNode currentNode) 147 | { 148 | HashSet usedSymbols = new HashSet(); 149 | VariableSymbolFinder sf = new VariableSymbolFinder(usedSymbols, semanticModel); 150 | sf.Visit(currentNode); 151 | return usedSymbols; 152 | } 153 | 154 | public static ISet GetUsedMethodSymbols(SemanticModel semanticModel, SyntaxNode currentNode) 155 | { 156 | var usedSymbols = new HashSet(); 157 | MethodSymbolFinder sf = new MethodSymbolFinder(usedSymbols, semanticModel); 158 | sf.Visit(currentNode); 159 | return usedSymbols; 160 | } 161 | 162 | public static ISymbol GetTokenSymbolReference(SyntaxToken token, SemanticModel semanticModel) 163 | { 164 | if (!token.IsKind(SyntaxKind.IdentifierToken) && !token.IsKind(SyntaxKind.ThisKeyword)) return null; 165 | 166 | SyntaxNode node = token.Parent; 167 | while (node.Parent != null) 168 | { 169 | ISymbol nodeSymbol = RoslynUtils.GetReferenceSymbol(node, semanticModel); 170 | if (nodeSymbol == null) 171 | { 172 | node = node.Parent; 173 | continue; 174 | } 175 | if (token.Text.StartsWith("@")) 176 | { 177 | if (token.Text.Substring(1) != nodeSymbol.Name) break; 178 | } 179 | else if (nodeSymbol.ToDisplayString() != token.Text && nodeSymbol.Name != token.Text) 180 | { 181 | break; 182 | } 183 | 184 | return nodeSymbol; 185 | } 186 | return null; 187 | } 188 | 189 | public static IEnumerable GetAvailableValueSymbols(SemanticModel semanticModel, SyntaxToken token) 190 | { 191 | var allSymbols = semanticModel.LookupSymbols(token.SpanStart); 192 | foreach (var symbol in allSymbols) 193 | { 194 | if (symbol.Kind == SymbolKind.Local) 195 | { 196 | ILocalSymbol localSymbol = (ILocalSymbol)symbol; 197 | var declarationSyntax = localSymbol.DeclaringSyntaxReferences[0].GetSyntax(); 198 | int declarationEnd; 199 | switch (declarationSyntax) 200 | { 201 | case ForEachStatementSyntax foreachSyntax: 202 | declarationEnd = foreachSyntax.CloseParenToken.SpanStart; 203 | break; 204 | default: 205 | declarationEnd = declarationSyntax.Span.End; 206 | break; 207 | } 208 | if (declarationEnd < token.SpanStart) 209 | { 210 | yield return localSymbol; 211 | } 212 | } 213 | else 214 | { 215 | if (symbol.Kind == SymbolKind.Field || symbol.Kind == SymbolKind.Property || symbol.Kind == SymbolKind.Parameter) 216 | { 217 | yield return symbol; 218 | } 219 | } 220 | } 221 | } 222 | 223 | public static List ComputeVariablesToConsider(ISet variableCandidates, 224 | ICollection knownTokens) 225 | { 226 | var variableSymbolsToUse = new List(); 227 | foreach (var symbol in variableCandidates) 228 | { 229 | if (symbol.IsImplicitlyDeclared) continue; 230 | if (symbol is IAliasSymbol || symbol is IRangeVariableSymbol) 231 | { 232 | continue; 233 | } 234 | if (!knownTokens.Contains(symbol)) continue; 235 | 236 | variableSymbolsToUse.Add(symbol); 237 | } 238 | 239 | return variableSymbolsToUse; 240 | } 241 | 242 | public static bool GetTypeSymbol(ISymbol symbol, out ITypeSymbol result) 243 | { 244 | if (symbol != null 245 | && !(symbol is IMethodSymbol) && !(symbol is INamespaceOrTypeSymbol) 246 | && !(symbol is IPreprocessingSymbol) && !(symbol is ITypeSymbol) 247 | && !(symbol is ILabelSymbol) 248 | && TypeHierarchy.ComputeTypeForSymbol(symbol, out var typeSym)) 249 | { 250 | result = typeSym; 251 | return true; 252 | } 253 | else if (symbol is IMethodSymbol methodSymbol) 254 | { 255 | result = methodSymbol.ReturnType; // methods' symbols are their return types, for now. 
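// Note: for a void method, methodSymbol.ReturnType is the System.Void type symbol, so the lookup still succeeds; callers that need to exclude void results can check methodSymbol.ReturnsVoid first.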
256 | return true; 257 | } 258 | result = null; 259 | return false; 260 | } 261 | } 262 | } 263 | -------------------------------------------------------------------------------- /dotnet/CSharpSourceGraphExtraction/Utils/TypeHierarchy.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.CodeAnalysis; 2 | using Newtonsoft.Json; 3 | using System; 4 | using System.Collections.Generic; 5 | using System.IO; 6 | using System.IO.Compression; 7 | using System.Runtime.CompilerServices; 8 | using MSRC.DPU.Utils; 9 | 10 | namespace MSRC.DPU.CSharpSourceGraphExtraction.Utils 11 | { 12 | /// 13 | /// A class that encapsulates the type hierarchy. 14 | /// 15 | public class TypeHierarchy 16 | { 17 | private readonly IntVocabulary _typeDictionary = new IntVocabulary(); 18 | private readonly Multimap _typeParents = new Multimap(); 19 | 20 | [MethodImpl(MethodImplOptions.Synchronized)] 21 | public void Add(ITypeSymbol type) 22 | { 23 | if (type.IsReferenceType) 24 | { 25 | // This is definitely an object 26 | AddAsObject(type); 27 | } 28 | 29 | if (type is IArrayTypeSymbol) { 30 | var arrayTypeSymbol = type as IArrayTypeSymbol; 31 | Add(arrayTypeSymbol.ElementType); 32 | } 33 | 34 | if (type is ITypeParameterSymbol) 35 | { 36 | var typeParam = type as ITypeParameterSymbol; 37 | if (typeParam.Variance == VarianceKind.In) 38 | { 39 | // this generic will accept all supertypes, not sure how to use. 40 | } 41 | else if (typeParam.Variance == VarianceKind.Out) 42 | { 43 | // Probably nothing should happen here. 44 | } 45 | 46 | if (typeParam.HasReferenceTypeConstraint) 47 | { 48 | AddAsObject(type); 49 | } 50 | else 51 | { 52 | // Add as a special "TypeParam" top type 53 | AddAsTypeParam(type); 54 | } 55 | } 56 | 57 | if (type.BaseType != null) 58 | { 59 | var addBaseType = Add(type, type.BaseType); 60 | if (addBaseType) Add(type.BaseType); 61 | } 62 | 63 | foreach (var implementedIface in type.Interfaces) 64 | { 65 | var addIface = Add(type, implementedIface); 66 | if (addIface) Add(implementedIface); 67 | } 68 | 69 | if (type is INamedTypeSymbol) 70 | { 71 | var namedType = type as INamedTypeSymbol; 72 | if (namedType.IsGenericType) 73 | { 74 | if (namedType.ConstructedFrom != type) 75 | { 76 | var addErasure = Add(type, namedType.ConstructedFrom); 77 | if (addErasure) Add(namedType.ConstructedFrom); 78 | } 79 | } 80 | } 81 | 82 | } 83 | 84 | private bool Add(ITypeSymbol subtype, ITypeSymbol type) 85 | { 86 | var baseTypeExisted = _typeDictionary.Contains(type.ToString()); 87 | int subtypeId = _typeDictionary.Get(subtype.ToString(), addIfNotPresent: true); 88 | int typeId = _typeDictionary.Get(type.ToString(), addIfNotPresent: true); 89 | 90 | if (subtype.ToString() == type.ToString()) return baseTypeExisted; 91 | 92 | _typeParents.Add(subtypeId, typeId); 93 | return !baseTypeExisted; 94 | } 95 | 96 | private void AddAsObject(ITypeSymbol subtype) 97 | { 98 | int subtypeId = _typeDictionary.Get(subtype.ToString(), addIfNotPresent: true); 99 | int typeId = _typeDictionary.Get("object", addIfNotPresent: true); 100 | _typeParents.Add(subtypeId, typeId); 101 | } 102 | 103 | private void AddAsTypeParam(ITypeSymbol subtype) 104 | { 105 | int subtypeId = _typeDictionary.Get(subtype.ToString(), addIfNotPresent: true); 106 | int typeId = _typeDictionary.Get("", addIfNotPresent: true); 107 | _typeParents.Add(subtypeId, typeId); 108 | } 109 | 110 | 111 | public void SaveTypeHierarchy(string outputFilename) 112 | { 113 | using (var fileStream = 
File.Create(outputFilename)) 114 | using (var gzipStream = new GZipStream(fileStream, CompressionMode.Compress, false)) 115 | using (var textStream = new StreamWriter(gzipStream)) 116 | { 117 | var serializer = new JsonSerializer { NullValueHandling = NullValueHandling.Ignore }; 118 | serializer.Serialize(textStream, new SerializableHierarchy(this)); 119 | } 120 | } 121 | 122 | public static bool ComputeTypeForSymbol(ISymbol symbol, out ITypeSymbol res) 123 | { 124 | switch (symbol) 125 | { 126 | case IParameterSymbol paramSym: 127 | res = paramSym.Type; 128 | return true; 129 | case ILocalSymbol localSym: 130 | res = localSym.Type; 131 | return true; 132 | case IFieldSymbol fieldSym: 133 | res = fieldSym.Type; 134 | return true; 135 | case IEventSymbol eventSym: 136 | res = eventSym.Type; 137 | return true; 138 | case IPropertySymbol propSym: 139 | res = propSym.Type; 140 | return true; 141 | default: 142 | res = null; 143 | return false; 144 | } 145 | } 146 | 147 | public List ComputeAndAddSymbolTypes(IEnumerable variableSymbols) 148 | { 149 | var typeNames = new List(); 150 | foreach (var symbol in variableSymbols) 151 | { 152 | if (ComputeTypeForSymbol(symbol, out var typeSymbol)) 153 | { 154 | Add(typeSymbol); 155 | typeNames.Add(typeSymbol.ToString()); 156 | } 157 | else 158 | { 159 | throw new Exception("Symbol not of recognized type: " + symbol); 160 | } 161 | } 162 | return typeNames; 163 | } 164 | 165 | private class SerializableHierarchy 166 | { 167 | public List types = new List(); 168 | public List> outgoingEdges = new List>(); 169 | 170 | public SerializableHierarchy(TypeHierarchy typeHierarchy) 171 | { 172 | for (int i = 0; i < typeHierarchy._typeDictionary.Count; i++) { 173 | types.Add(typeHierarchy._typeDictionary.Get(i)); 174 | outgoingEdges.Add(new HashSet(typeHierarchy._typeParents.Values(i))); 175 | } 176 | } 177 | } 178 | 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /dotnet/DPU.Utils.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio Version 16 4 | VisualStudioVersion = 16.0.28516.95 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Utils", "Utils\MSRC.DPU.Utils.csproj", "{BC10F80B-22D4-4C0B-A038-AD6A75FC42A5}" 7 | EndProject 8 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "CSharpSourceGraphExtraction", "CSharpSourceGraphExtraction\MSRC.DPU.CSharpSourceGraphExtraction.csproj", "{9929F15D-87C8-4D69-91C1-294B807C05C1}" 9 | EndProject 10 | Global 11 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 12 | Debug|Any CPU = Debug|Any CPU 13 | Debug|x64 = Debug|x64 14 | Release|Any CPU = Release|Any CPU 15 | Release|x64 = Release|x64 16 | EndGlobalSection 17 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 18 | {BC10F80B-22D4-4C0B-A038-AD6A75FC42A5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 19 | {BC10F80B-22D4-4C0B-A038-AD6A75FC42A5}.Debug|Any CPU.Build.0 = Debug|Any CPU 20 | {BC10F80B-22D4-4C0B-A038-AD6A75FC42A5}.Debug|x64.ActiveCfg = Debug|Any CPU 21 | {BC10F80B-22D4-4C0B-A038-AD6A75FC42A5}.Debug|x64.Build.0 = Debug|Any CPU 22 | {BC10F80B-22D4-4C0B-A038-AD6A75FC42A5}.Release|Any CPU.ActiveCfg = Release|Any CPU 23 | {BC10F80B-22D4-4C0B-A038-AD6A75FC42A5}.Release|Any CPU.Build.0 = Release|Any CPU 24 | {BC10F80B-22D4-4C0B-A038-AD6A75FC42A5}.Release|x64.ActiveCfg = Release|Any CPU 25 | 
{BC10F80B-22D4-4C0B-A038-AD6A75FC42A5}.Release|x64.Build.0 = Release|Any CPU 26 | {9929F15D-87C8-4D69-91C1-294B807C05C1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 27 | {9929F15D-87C8-4D69-91C1-294B807C05C1}.Debug|Any CPU.Build.0 = Debug|Any CPU 28 | {9929F15D-87C8-4D69-91C1-294B807C05C1}.Debug|x64.ActiveCfg = Debug|x64 29 | {9929F15D-87C8-4D69-91C1-294B807C05C1}.Debug|x64.Build.0 = Debug|x64 30 | {9929F15D-87C8-4D69-91C1-294B807C05C1}.Release|Any CPU.ActiveCfg = Release|Any CPU 31 | {9929F15D-87C8-4D69-91C1-294B807C05C1}.Release|Any CPU.Build.0 = Release|Any CPU 32 | {9929F15D-87C8-4D69-91C1-294B807C05C1}.Release|x64.ActiveCfg = Release|x64 33 | {9929F15D-87C8-4D69-91C1-294B807C05C1}.Release|x64.Build.0 = Release|x64 34 | EndGlobalSection 35 | GlobalSection(SolutionProperties) = preSolution 36 | HideSolutionNode = FALSE 37 | EndGlobalSection 38 | GlobalSection(ExtensibilityGlobals) = postSolution 39 | SolutionGuid = {9582362D-E81A-4511-B17A-0A05956DC8A1} 40 | EndGlobalSection 41 | EndGlobal 42 | -------------------------------------------------------------------------------- /dotnet/Utils/BidirectionalMap.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | 4 | namespace MSRC.DPU.Utils 5 | { 6 | [Serializable()] 7 | public class BidirectionalMap 8 | where V : class 9 | { 10 | private readonly Dictionary forwardMap; 11 | private readonly Dictionary backwardsMap; 12 | 13 | public int Count => forwardMap.Count; 14 | 15 | public BidirectionalMap() 16 | { 17 | forwardMap = new Dictionary(); 18 | backwardsMap = new Dictionary(); 19 | } 20 | 21 | public BidirectionalMap(IEqualityComparer valueComparer = null) 22 | { 23 | forwardMap = new Dictionary(); 24 | 25 | if (valueComparer != null) backwardsMap = new Dictionary(valueComparer); 26 | else backwardsMap = new Dictionary(); 27 | } 28 | 29 | public void Add(K key, V value) 30 | { 31 | forwardMap.Add(key, value); 32 | backwardsMap.Add(value, key); 33 | } 34 | 35 | public V GetValue(K key) 36 | { 37 | return forwardMap[key]; 38 | } 39 | 40 | public K GetKey(V value) 41 | { 42 | return backwardsMap[value]; 43 | } 44 | 45 | public bool Contains(K key) 46 | { 47 | return forwardMap.ContainsKey(key); 48 | } 49 | 50 | public bool TryGetValue(K key, out V value) 51 | { 52 | return forwardMap.TryGetValue(key, out value); 53 | } 54 | 55 | public bool TryGetKey(V value, out K key) 56 | { 57 | return backwardsMap.TryGetValue(value, out key); 58 | } 59 | 60 | public bool Contains(V value) 61 | { 62 | return backwardsMap.ContainsKey(value); 63 | } 64 | 65 | public void Delete(K key) 66 | { 67 | V value = forwardMap[key]; 68 | forwardMap.Remove(key); 69 | backwardsMap.Remove(value); 70 | } 71 | 72 | public void Delete(V value) 73 | { 74 | K key = backwardsMap[value]; 75 | forwardMap.Remove(key); 76 | backwardsMap.Remove(value); 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /dotnet/Utils/ChunkedJsonWriter.cs: -------------------------------------------------------------------------------- 1 | using Newtonsoft.Json; 2 | using System; 3 | using System.IO; 4 | using System.IO.Compression; 5 | using System.Text; 6 | 7 | namespace MSRC.DPU.Utils 8 | { 9 | /// 10 | /// Thread-safe .json[l].gz writer. The output is automatically split in chunks. 
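/// A minimal usage sketch (the file name and payloads are illustrative): using (var writer = new ChunkedJsonGzWriter("out.jsonl.gz", max_elements_per_chunk: 2)) { writer.WriteElement("{\"id\": 1}"); writer.WriteElement("{\"id\": 2}"); writer.WriteElement("{\"id\": 3}"); } /// This writes out.0.jsonl.gz with two elements and out.1.jsonl.gz with one.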
11 | /// 12 | public class ChunkedJsonGzWriter : IDisposable 13 | { 14 | private readonly object _lock = new object(); 15 | private TextWriter _textStream = null; 16 | private int _numElementsWrittenInCurrentChunk = 0; 17 | private readonly string _outputFilenameTemplate; 18 | 19 | private readonly int _max_elements_per_chunk; 20 | private readonly bool _useJsonlFormat; 21 | 22 | public ChunkedJsonGzWriter(string outputFilenameTemplate, 23 | int max_elements_per_chunk = 500, 24 | bool useJsonlFormat = true, 25 | bool resumeIfFilesExist = false) 26 | { 27 | _outputFilenameTemplate = outputFilenameTemplate; 28 | _max_elements_per_chunk = max_elements_per_chunk; 29 | _useJsonlFormat = useJsonlFormat; 30 | if (resumeIfFilesExist) 31 | { 32 | // Loop Until there is an unwritten file 33 | for (int i = 0; ; i++) 34 | { 35 | if (File.Exists(GetChunkedOutputFilename(_outputFilenameTemplate, NumChunksWrittenSoFar))) 36 | { 37 | NumChunksWrittenSoFar++; 38 | } 39 | else 40 | { 41 | break; 42 | } 43 | } 44 | } 45 | 46 | } 47 | 48 | public int NumChunksWrittenSoFar { get; private set; } = 0; 49 | 50 | /// 51 | /// Write JSON representation of a single datapoint to the output. The method handles details of chunking. 52 | /// 53 | /// String containing a JSON-encoded data point. 54 | public void WriteElement(string jsonElement) 55 | { 56 | lock (_lock) 57 | { 58 | if (_textStream == null) 59 | { 60 | var filename = GetChunkedOutputFilename(_outputFilenameTemplate, NumChunksWrittenSoFar); 61 | Console.WriteLine($"Opening output file {filename}."); 62 | var fileStream = File.Create(filename); 63 | var gzipStream = new GZipStream(fileStream, CompressionMode.Compress, false); 64 | _textStream = new StreamWriter(gzipStream); 65 | _numElementsWrittenInCurrentChunk = 0; 66 | if (!_useJsonlFormat) _textStream.Write('['); 67 | } 68 | 69 | if (_numElementsWrittenInCurrentChunk > 0) 70 | { 71 | if (_useJsonlFormat) 72 | { 73 | _textStream.Write('\n'); 74 | } 75 | else 76 | { 77 | _textStream.Write(','); 78 | } 79 | } 80 | _textStream.Write(jsonElement); 81 | 82 | ++_numElementsWrittenInCurrentChunk; 83 | if (_numElementsWrittenInCurrentChunk >= _max_elements_per_chunk) 84 | { 85 | CloseOutputFile(); 86 | } 87 | } 88 | } 89 | 90 | /// 91 | /// Write JSON representation of a single datapoint to the output. The method handles details of chunking. 92 | /// 93 | /// A callback that writes some data to the provided JsonWriter, for example your hand-rolled serialization code. 94 | public void WriteElement(Action writer) 95 | { 96 | string jsonElement; 97 | using (MemoryStream ms = new MemoryStream()) 98 | { 99 | TextWriter tw = new StreamWriter(ms); 100 | JsonWriter js = new JsonTextWriter(tw); 101 | 102 | writer(js); 103 | js.Flush(); 104 | ms.Seek(0, SeekOrigin.Begin); 105 | 106 | using (TextReader sr = new StreamReader(ms)) 107 | { 108 | jsonElement = sr.ReadToEnd(); 109 | } 110 | 111 | WriteElement(jsonElement); 112 | } 113 | } 114 | 115 | private string GetChunkedOutputFilename(string fileName, int chunkNum) 116 | { 117 | var outputFormat = (_useJsonlFormat ? ".jsonl" : ".json") + ".gz"; 118 | if (fileName.EndsWith(outputFormat)) 119 | { 120 | return fileName.Replace(outputFormat, "." + chunkNum + outputFormat); 121 | } 122 | else 123 | { 124 | return fileName + "." 
+ chunkNum + outputFormat; 125 | } 126 | } 127 | 128 | private void CloseOutputFile() 129 | { 130 | lock (_lock) 131 | { 132 | if (!_useJsonlFormat) _textStream.Write(']'); 133 | _textStream.Close(); 134 | _textStream = null; 135 | ++NumChunksWrittenSoFar; 136 | } 137 | } 138 | 139 | public void Dispose() 140 | { 141 | if (_textStream != null) 142 | { 143 | CloseOutputFile(); 144 | } 145 | } 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /dotnet/Utils/ExtensionUtils.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | 4 | namespace MSRC.DPU.Utils 5 | { 6 | public static class ExtensionUtils 7 | { 8 | public static void Deconstruct(this KeyValuePair kvp, out K key, out V value) 9 | { 10 | key = kvp.Key; 11 | value = kvp.Value; 12 | } 13 | 14 | public static V TryGetOrAddValue(this Dictionary dict, K key, out V value, Func computeDefault) 15 | { 16 | if (!dict.TryGetValue(key, out value)) 17 | { 18 | value = computeDefault(); 19 | dict.Add(key, value); 20 | } 21 | return value; 22 | } 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /dotnet/Utils/MSRC.DPU.Utils.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | netstandard2.0 4 | 7.2 5 | AnyCPU;x64 6 | 7 | 8 | MSRC.DPU.Utils 9 | Utilities used by the Deep Program Understanding team 10 | MIT 11 | local 12 | 1.0.3-$(BUILD_BUILDNUMBER) 13 | Miltos Allamanis, Marc Brockschmidt 14 | Microsoft 15 | https://github.com/Microsoft/dpu-utils/ 16 | Copyright (c) 2018- Microsoft Corporation 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | bin\Debug 27 | 28 | 29 | 30 | bin\Debug 31 | 32 | 33 | 34 | bin\Release 35 | 36 | 37 | 38 | bin\Release 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /dotnet/Utils/Multimap.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Linq; 4 | 5 | namespace MSRC.DPU.Utils 6 | { 7 | public class Multimap 8 | { 9 | private readonly Dictionary> _elements = new Dictionary>(); 10 | 11 | public Multimap() {} 12 | 13 | public Multimap(IEnumerable> input) 14 | { 15 | foreach (var grouping in input) 16 | { 17 | var key = grouping.Key; 18 | AddMany(key, grouping); 19 | } 20 | } 21 | 22 | public Multimap(Dictionary> input) 23 | { 24 | foreach (var grouping in input) 25 | { 26 | var key = grouping.Key; 27 | AddMany(key, grouping.Value); 28 | } 29 | } 30 | 31 | public void Add(K key, V value) 32 | { 33 | if (!_elements.TryGetValue(key, out HashSet keyElements)) 34 | { 35 | keyElements = new HashSet(); 36 | _elements.Add(key, keyElements); 37 | } 38 | keyElements.Add(value); 39 | } 40 | 41 | public Multimap Invese() 42 | { 43 | var inverse = new Multimap(); 44 | foreach (var pair in this.KeyValuePairs()) 45 | { 46 | inverse.Add(pair.Item2, pair.Item1); 47 | } 48 | return inverse; 49 | } 50 | 51 | public void AddMany(K key, IEnumerable values) 52 | { 53 | if (!_elements.TryGetValue(key, out HashSet keyElements)) 54 | { 55 | keyElements = new HashSet(); 56 | _elements.Add(key, keyElements); 57 | } 58 | keyElements.UnionWith(values); 59 | } 60 | 61 | public IEnumerable Values(K key) 62 | { 63 | if (!_elements.TryGetValue(key, out HashSet keyElements)) 64 | { 65 | return Enumerable.Empty(); 66 | } 67 | return keyElements.AsEnumerable(); 68 | 
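// Note: this exposes the live underlying HashSet rather than a snapshot; mutating the multimap while enumerating the result invalidates the enumerator, so materialize with ToList() first if you need to modify during iteration.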
} 69 | 70 | public bool ContainsEntry(K key, V value) 71 | { 72 | if (!_elements.TryGetValue(key, out HashSet keyElements)) 73 | { 74 | return false; 75 | } 76 | return keyElements.Contains(value); 77 | } 78 | 79 | public IEnumerable Keys() 80 | { 81 | return _elements.Keys; 82 | } 83 | 84 | public IEnumerable> KeyValuePairs() 85 | { 86 | foreach(var keyset in _elements) 87 | { 88 | foreach(var value in keyset.Value) 89 | { 90 | yield return Tuple.Create(keyset.Key, value); 91 | } 92 | } 93 | } 94 | 95 | public int CountFor(K key) 96 | { 97 | if (!_elements.TryGetValue(key, out HashSet keyElements)) 98 | { 99 | return 0; 100 | } 101 | return keyElements.Count; 102 | } 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /python/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.txt 2 | include LICENSE 3 | recursive-include dpu_utils *.txt 4 | recursive-include dpu_utils py.typed 5 | -------------------------------------------------------------------------------- /python/dpu_utils/__init__.py: -------------------------------------------------------------------------------- 1 | name = "dpu_utils" 2 | 3 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/__init__.py: -------------------------------------------------------------------------------- 1 | from .identifiersplitting import split_identifier_into_parts 2 | from .lattice import CSharpLattice, Lattice, LatticeVocabulary 3 | from .keywords.keywordlist import get_language_keywords 4 | from .filesuffix import language_candidates_from_suffix 5 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/deduplication/__init__.py: -------------------------------------------------------------------------------- 1 | from .deduplication import DuplicateDetector -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/deduplication/deduplication.py: -------------------------------------------------------------------------------- 1 | from collections import Counter, defaultdict 2 | import typing 3 | from typing import List, Dict, Set, TypeVar, Generic, Iterable, Tuple, Hashable, Optional, Callable 4 | import re 5 | import numpy as np 6 | 7 | from dpu_utils.codeutils import get_language_keywords 8 | 9 | from SetSimilaritySearch import all_pairs 10 | 11 | DocumentId = TypeVar('DocumentId', bound=Hashable) 12 | 13 | 14 | class DuplicateDetector(Generic[DocumentId]): 15 | """Detect near-duplicate code. 16 | 17 | This class accepts a list of tokens within the some code snippet. It then approximately finds all identifier 18 | tokens and creates a set T_1 with those tokens and a multiset T_2 with the same tokens. 19 | 20 | A file `i` is considered to be a duplicate with another one `j` if the Jaccard similarity of T_1^i and T_1^j 21 | is more than `set_similarity_threshold` and the Jaccard similarity of T_2^i and T_2^j is more than 22 | `multiset_similarity_threshold`. Documents with less than `min_num_tokens_per_document` are not considered. 23 | 24 | This follows the general principles in 25 | 26 | Sajnani H, Saini V, Svajlenko J, Roy CK, Lopes CV. 27 | SourcererCC: scaling code clone detection to big-code. 28 | In Software Engineering (ICSE), 2016 29 | IEEE/ACM 38th International Conference on 2016 May 14 (pp. 
1157-1168) 30 | 31 | Sample usage: 32 | * Add all files (and their tokens) via `add_files()` 33 | * Call `compute_duplicates()` 34 | * If the goal is to retrieve a list of files to be excluded, instead use `compute_ids_to_exclude()` 35 | 36 | 37 | See also: 38 | Allamanis, Miltiadis. "The adverse effects of code duplication in machine learning models of code." 39 | Proceedings of the 2019 ACM SIGPLAN International Symposium on New Ideas, New Paradigms, 40 | and Reflections on Programming and Software. ACM, 2019. 41 | 42 | """ 43 | 44 | IDENTIFIER_REGEX = re.compile('[_a-zA-Z][_a-zA-Z0-9]*') 45 | 46 | def __init__(self, set_similarity_threshold: float=0.8, multiset_similarity_threshold: float=0.7, 47 | min_num_tokens_per_document: int=20)-> None: 48 | self.__vocabulary = {} # type: Dict[str, int] 49 | self.__set_similarity_threshold = set_similarity_threshold 50 | self.__multiset_similarity_threshold = multiset_similarity_threshold 51 | self.__min_num_tokens_per_document = min_num_tokens_per_document 52 | self.__document_keys = [] # type: List[DocumentId] 53 | self.__document_key_set = set() # type: Set[DocumentId] 54 | self.__document_elements = [] # type: List[typing.Counter[int]] 55 | 56 | def __get_token_id(self, token: str) -> int: 57 | token_id = self.__vocabulary.get(token) 58 | if token_id is None: 59 | token_id = len(self.__vocabulary) 60 | self.__vocabulary[token] = token_id 61 | return token_id 62 | 63 | def add_file(self, id: DocumentId, tokens: List[str], language: Optional[str]=None) -> bool: 64 | """Add a file to be indexed by the duplicate detector.""" 65 | id_tokens = Counter(self.__get_token_id(t) for t in tokens if self.IDENTIFIER_REGEX.match(t) 66 | and (language is None or t not in get_language_keywords(language))) 67 | assert id not in self.__document_key_set 68 | if sum(id_tokens.values()) < self.__min_num_tokens_per_document: 69 | return False 70 | self.__document_keys.append(id) 71 | self.__document_key_set.add(id) 72 | self.__document_elements.append(id_tokens) 73 | return True 74 | 75 | def __duplicate_pairs(self)-> Iterable[Tuple[int, int]]: 76 | similar_pairs = all_pairs(self.__document_elements, 77 | similarity_func_name='jaccard', 78 | similarity_threshold=self.__set_similarity_threshold) 79 | for idx1, idx2, _ in similar_pairs: 80 | if self.__multiset_jaccard_similarity(idx1, idx2) >= self.__multiset_similarity_threshold: 81 | yield idx1, idx2 82 | 83 | def __multiset_jaccard_similarity(self, idx1: int, idx2: int)-> float: 84 | intersection_size = sum((self.__document_elements[idx1] & self.__document_elements[idx2]).values()) 85 | union_size = sum((self.__document_elements[idx1] | self.__document_elements[idx2]).values()) 86 | return float(intersection_size) / union_size 87 | 88 | def compute_duplicates(self) -> List[Set[DocumentId]]: 89 | """Compute the duplicates in the currently indexed documents. 90 | 91 | Make the incorrect but reasonable assumption that similarity is transitive. 
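(This assumption only over-merges: if A~B and B~C each pass both similarity thresholds, A and C land in the same cluster even when their direct similarity falls below the thresholds, but no pair that does pass the thresholds is ever split apart.)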
92 | Compute the pairwise similar elements and add them into clusters.""" 93 | 94 | clone_sets = [] # type: List[Set[DocumentId]] 95 | 96 | pairwise_relationships = defaultdict(list) # type: Dict[int, List[int]] 97 | for idx1, idx2 in self.__duplicate_pairs(): 98 | assert idx1 != idx2 99 | pairwise_relationships[idx1].append(idx2) 100 | pairwise_relationships[idx2].append(idx1) 101 | 102 | # Compute the transitive closure of this relationship 103 | documents_to_visit = set(pairwise_relationships.keys()) # type: Set[int] 104 | while len(documents_to_visit) > 0: 105 | current_idx = documents_to_visit.pop() 106 | 107 | current_idx_closure = {current_idx} 108 | visit_queue = list(pairwise_relationships[current_idx]) 109 | while len(visit_queue) > 0: 110 | other_idx = visit_queue.pop() 111 | current_idx_closure.add(other_idx) 112 | documents_to_visit.discard(other_idx) 113 | 114 | # Add to queue 115 | visit_queue.extend(next_idx for next_idx in pairwise_relationships[other_idx] 116 | if next_idx in documents_to_visit) 117 | 118 | clone_sets.append(set(self.__document_keys[i] for i in current_idx_closure)) 119 | return clone_sets 120 | 121 | def print_clone_set_stats(self, clone_sets: List[Set[DocumentId]]) -> None: 122 | total_num_files = len(self.__document_keys) 123 | num_cloned_files = sum(len(c) for c in clone_sets) 124 | print('Duplicated files: %.2f%%' % (num_cloned_files / total_num_files * 100.)) 125 | print('Avg num of files per clones %.2f' % np.mean([len(c) for c in clone_sets])) 126 | print('Median num of files per clones %s' % np.median([len(c) for c in clone_sets])) 127 | 128 | print('Duplication Ratio %.2f%%' % ((num_cloned_files - len(clone_sets)) / total_num_files * 100)) 129 | 130 | def compute_ids_to_exclude(self, keep_selector: Optional[Callable[[Set[DocumentId]], DocumentId]]=None) -> Set[DocumentId]: 131 | """Compute a set of document ids to discard in the currently indexed documents. 132 | 133 | :param keep_selector: a lambda that accepts a set of DocumentId's and returns the DocumentId to keep. 134 | If the DocumentId is not contained in input set, the whole cluster of duplicates is excluded. 135 | If keep_selector is None then it arbitrarily excludes one document id from each cluster of duplicates, and returns 136 | a set of the remaining document ids to exclude in order to de-duplicate your data. 137 | """ 138 | duplicate_clusters = self.compute_duplicates() 139 | # remove one document from each duplicate set to keep 140 | for cluster in duplicate_clusters: 141 | if keep_selector is None: 142 | cluster.pop() # Remove arbitrary element 143 | else: 144 | document_to_keep = keep_selector(cluster) 145 | cluster.discard(document_to_keep) 146 | 147 | # flatten out the lists of sets into one superset, each element being a document_id that you will discard 148 | return set.union(*duplicate_clusters) 149 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/deduplication/deduplicationcli: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Detect approximate duplicates across a set of files. 4 | 5 | Usage: 6 | deduplicationcli [options] DATA_PATH OUT_JSON 7 | 8 | Options: 9 | -h --help Show this message on screen. 10 | --azure-info= Azure authentication information file (JSON). Used to load data from Azure storage. 11 | --debug Enable debug routines. [default: False] 12 | --language= The programming language of the input data. 
[default: python] 13 | --entry-id-field= The name of the field in each JSON entry that uniquely identifies it. [default: filename] 14 | --tokens-field-name= The name of the field in each JSON entry that contains the code tokens. [default: tokens] 15 | """ 16 | from docopt import docopt 17 | from dpu_utils.utils import run_and_debug, RichPath 18 | from tqdm import tqdm 19 | 20 | from dpu_utils.codeutils.deduplication import DuplicateDetector 21 | 22 | 23 | def run(arguments): 24 | azure_info_path = arguments.get('--azure-info', None) 25 | data_dir = RichPath.create(arguments['DATA_PATH'], azure_info_path) 26 | assert data_dir.is_dir(), "%s is not a folder" % data_dir 27 | 28 | detector = DuplicateDetector() # type: DuplicateDetector[str] 29 | 30 | for file in tqdm(data_dir.get_filtered_files_in_dir('*.jsonl.gz'), desc='Loading files'): 31 | for idx, element in enumerate(file.read_as_jsonl()): 32 | detector.add_file(id=element[arguments['--entry-id-field']], 33 | tokens=element[arguments['--tokens-field-name']], 34 | language=arguments['--language']) 35 | 36 | print('Added files. Computing duplicates...') 37 | duplicates = detector.compute_duplicates() 38 | detector.print_clone_set_stats(duplicates) 39 | out_path = RichPath.create(arguments['OUT_JSON'], azure_info_path) 40 | out_path.save_as_compressed_file([list(l) for l in duplicates]) 41 | print('Done.') 42 | 43 | 44 | if __name__ == '__main__': 45 | args = docopt(__doc__) 46 | run_and_debug(lambda: run(args), args.get('--debug', False)) 47 | 48 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/filesuffix.py: -------------------------------------------------------------------------------- 1 | from typing import FrozenSet 2 | 3 | __all__ = ["language_candidates_from_suffix"] 4 | 5 | _SUFFIXES = { 6 | "c": {"c"}, 7 | "cc": {"cpp"}, 8 | "cpp": {"cpp"}, 9 | "cs": {"c_sharp"}, 10 | "go": {"go"}, 11 | "h": {"c", "cpp"}, 12 | "java": {"java"}, 13 | "js": {"javascript"}, 14 | "php": {"php"}, 15 | "py": {"python"}, 16 | "r": {"r"}, 17 | "rb": {"ruby"}, 18 | "rs": {"rust"}, 19 | "sh": {"bash"}, 20 | "ts": {"typescript"}, 21 | } 22 | _SUFFIXES = {s: frozenset(ls) for s, ls in _SUFFIXES.items()} 23 | 24 | 25 | def language_candidates_from_suffix(suffix: str) -> FrozenSet[str]: 26 | """ 27 | Get the set of potential programming languages for a given file suffix, 28 | based on common conventions. 
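For example (a leading dot is ignored and matching is case-insensitive): language_candidates_from_suffix(".h") returns frozenset({"c", "cpp"}), language_candidates_from_suffix("PY") returns frozenset({"python"}), and an unknown suffix yields an empty frozenset.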
29 | """ 30 | if suffix.startswith("."): 31 | suffix = suffix[1:] 32 | suffix = suffix.lower() 33 | return _SUFFIXES.get(suffix, frozenset()) 34 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/identifiersplitting.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from typing import List 3 | import sys 4 | 5 | REGEX_TEXT = ("(?<=[a-z0-9])(?=[A-Z])|" 6 | "(?<=[A-Z0-9])(?=[A-Z][a-z])|" 7 | "(?<=[0-9])(?=[a-zA-Z])|" 8 | "(?<=[A-Za-z])(?=[0-9])|" 9 | "(?<=[@$.'\"])(?=[a-zA-Z0-9])|" 10 | "(?<=[a-zA-Z0-9])(?=[@$.'\"])|" 11 | "_|\\s+") 12 | 13 | if sys.version_info >= (3, 7): 14 | import re 15 | SPLIT_REGEX = re.compile(REGEX_TEXT) 16 | else: 17 | import regex 18 | SPLIT_REGEX = regex.compile("(?V1)"+REGEX_TEXT) 19 | 20 | 21 | @lru_cache(maxsize=5000) 22 | def split_identifier_into_parts(identifier: str) -> List[str]: 23 | """ 24 | Split a single identifier into parts on snake_case and camelCase 25 | """ 26 | identifier_parts = list(s.lower() for s in SPLIT_REGEX.split(identifier) if len(s)>0) 27 | 28 | if len(identifier_parts) == 0: 29 | return [identifier] 30 | return identifier_parts 31 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/keywords/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/dpu-utils/bc4f5fa6c92c4b073fff4d93269512b0027f7f02/python/dpu_utils/codeutils/keywords/__init__.py -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/keywords/c.txt: -------------------------------------------------------------------------------- 1 | auto 2 | break 3 | case 4 | char 5 | const 6 | continue 7 | default 8 | do 9 | double 10 | else 11 | enum 12 | extern 13 | float 14 | for 15 | goto 16 | if 17 | int 18 | long 19 | register 20 | return 21 | short 22 | signed 23 | sizeof 24 | static 25 | struct 26 | switch 27 | typedef 28 | union 29 | unsigned 30 | void 31 | volatile 32 | while 33 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/keywords/cpp.txt: -------------------------------------------------------------------------------- 1 | alignas 2 | alignof 3 | and 4 | and_eq 5 | asm 6 | atomic_cancel 7 | atomic_commit 8 | atomic_noexcept 9 | auto 10 | bitand 11 | bitor 12 | bool 13 | break 14 | case 15 | catch 16 | char 17 | char8_t 18 | char16_t 19 | char32_t 20 | class 21 | compl 22 | concept 23 | const 24 | consteval 25 | constexpr 26 | const_cast 27 | continue 28 | co_await 29 | co_return 30 | co_yield 31 | decltype 32 | default 33 | delete 34 | do 35 | double 36 | dynamic_cast 37 | else 38 | enum 39 | explicit 40 | export 41 | extern 42 | false 43 | float 44 | for 45 | friend 46 | goto 47 | if 48 | inline 49 | int 50 | long 51 | mutable 52 | namespace 53 | new 54 | noexcept 55 | not 56 | not_eq 57 | nullptr 58 | operator 59 | or 60 | or_eq 61 | private 62 | protected 63 | public 64 | reflexpr 65 | register 66 | reinterpret_cast 67 | requires 68 | return 69 | short 70 | signed 71 | sizeof 72 | static 73 | static_assert 74 | static_cast 75 | struct 76 | switch 77 | synchronized 78 | template 79 | this 80 | thread_local 81 | throw 82 | true 83 | try 84 | typedef 85 | typeid 86 | typename 87 | union 88 | unsigned 89 | using 90 | virtual 91 | void 92 | volatile 93 | wchar_t 94 | while 95 | xor 96 | 
xor_eq -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/keywords/csharp.txt: -------------------------------------------------------------------------------- 1 | abstract 2 | as 3 | base 4 | bool 5 | break 6 | byte 7 | case 8 | catch 9 | char 10 | checked 11 | class 12 | const 13 | continue 14 | decimal 15 | default 16 | delegate 17 | do 18 | double 19 | else 20 | enum 21 | event 22 | explicit 23 | extern 24 | finally 25 | fixed 26 | float 27 | for 28 | foreach 29 | goto 30 | if 31 | implicit 32 | in 33 | int 34 | interface 35 | internal 36 | is 37 | lock 38 | long 39 | namespace 40 | new 41 | null 42 | object 43 | operator 44 | out 45 | override 46 | params 47 | private 48 | protected 49 | public 50 | readonly 51 | ref 52 | return 53 | sbyte 54 | sealed 55 | short 56 | sizeof 57 | stackalloc 58 | static 59 | string 60 | struct 61 | switch 62 | this 63 | throw 64 | try 65 | typeof 66 | uint 67 | ulong 68 | unchecked 69 | unsafe 70 | ushort 71 | using 72 | using 73 | static 74 | virtual 75 | void 76 | volatile 77 | while -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/keywords/go.txt: -------------------------------------------------------------------------------- 1 | break 2 | default 3 | func 4 | interface 5 | select 6 | case 7 | defer 8 | go 9 | map 10 | struct 11 | chan 12 | else 13 | goto 14 | package 15 | switch 16 | const 17 | fallthrough 18 | if 19 | range 20 | type 21 | continue 22 | for 23 | import 24 | return 25 | var 26 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/keywords/java.txt: -------------------------------------------------------------------------------- 1 | abstract 2 | assert 3 | boolean 4 | break 5 | byte 6 | case 7 | catch 8 | char 9 | class 10 | continue 11 | default 12 | do 13 | double 14 | else 15 | enum 16 | extends 17 | final 18 | finally 19 | float 20 | for 21 | if 22 | implements 23 | import 24 | instanceof 25 | int 26 | interface 27 | long 28 | native 29 | new 30 | package 31 | private 32 | protected 33 | public 34 | return 35 | short 36 | static 37 | strictfp 38 | super 39 | switch 40 | synchronized 41 | this 42 | throw 43 | throws 44 | transient 45 | try 46 | void 47 | volatile 48 | while 49 | var 50 | const 51 | goto 52 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/keywords/javascript.txt: -------------------------------------------------------------------------------- 1 | break 2 | case 3 | catch 4 | class 5 | const 6 | continue 7 | debugger 8 | default 9 | delete 10 | do 11 | else 12 | export 13 | extends 14 | finally 15 | for 16 | function 17 | if 18 | import 19 | in 20 | instanceof 21 | new 22 | return 23 | super 24 | switch 25 | this 26 | throw 27 | try 28 | typeof 29 | var 30 | void 31 | while 32 | with 33 | yield 34 | enum 35 | implements 36 | interface 37 | let 38 | package 39 | private 40 | protected 41 | public 42 | static 43 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/keywords/keywordlist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import keyword 3 | from functools import lru_cache 4 | from typing import FrozenSet 5 | 6 | __all__ = ['get_language_keywords'] 7 | 8 | _LANGUAGE_TO_FILENAME = { 9 | 'c': 'c.txt', 10 | 'cpp': 'cpp.txt', 11 | 'c++': 'cpp.txt', 12 | 'csharp': 
'csharp.txt', 13 | 'c_sharp': 'csharp.txt', 14 | 'c#': 'csharp.txt', 15 | 'go': 'go.txt', 16 | 'java': 'java.txt', 17 | 'javascript': 'javascript.txt', 18 | 'js': 'javascript.txt', 19 | 'php': 'php.txt', 20 | 'ruby': 'ruby.txt', 21 | 'typescript': 'typescript.txt', 22 | 'ts': 'typescript.txt', 23 | } 24 | 25 | @lru_cache() 26 | def get_language_keywords(language: str) -> FrozenSet[str]: 27 | """ 28 | Returns the keywords of a programming language. 29 | 30 | There are some inconsistencies across languages wrt to 31 | what is considered a keyword. For example, the true/false 32 | literals are considered keywords in many languages. However, 33 | we exclude them here for consistency. We also exclude special 34 | functions-like keywords, such as `die()` in PHP. 35 | """ 36 | language = language.lower() 37 | if language == 'python': 38 | return frozenset(k for k in keyword.kwlist if k != 'True' and k != 'False') 39 | elif language in _LANGUAGE_TO_FILENAME: 40 | name = _LANGUAGE_TO_FILENAME[language] 41 | with open(os.path.join(os.path.dirname(__file__), name)) as f: 42 | return frozenset(l.strip() for l in f if len(l.strip()) > 0) 43 | else: 44 | raise Exception('Language keywords `%s` not supported yet. Consider contributing it to dpu-utils.' % language) 45 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/keywords/php.txt: -------------------------------------------------------------------------------- 1 | abstract 2 | and 3 | array 4 | break 5 | callable 6 | case 7 | catch 8 | class 9 | clone 10 | const 11 | continue 12 | declare 13 | default 14 | do 15 | echo 16 | else 17 | elseif 18 | enddeclare 19 | endfor 20 | endforeach 21 | endif 22 | endswitch 23 | endwhile 24 | extends 25 | final 26 | finally 27 | for 28 | foreach 29 | function 30 | global 31 | goto 32 | if 33 | implements 34 | include 35 | include_once 36 | instanceof 37 | insteadof 38 | interface 39 | namespace 40 | new 41 | or 42 | print 43 | private 44 | protected 45 | public 46 | require 47 | require_once 48 | return 49 | static 50 | switch 51 | throw 52 | trait 53 | try 54 | unset 55 | use 56 | var 57 | while 58 | xor 59 | yield -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/keywords/ruby.txt: -------------------------------------------------------------------------------- 1 | alias 2 | and 3 | begin 4 | break 5 | case 6 | class 7 | def 8 | defined? 
9 | do 10 | else 11 | elsif 12 | end 13 | ensure 14 | for 15 | if 16 | in 17 | module 18 | next 19 | nil 20 | not 21 | or 22 | redo 23 | rescue 24 | retry 25 | return 26 | self 27 | super 28 | then 29 | undef 30 | unless 31 | until 32 | when 33 | while 34 | yield -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/keywords/typescript.txt: -------------------------------------------------------------------------------- 1 | abstract 2 | any 3 | as 4 | async 5 | await 6 | boolean 7 | break 8 | case 9 | catch 10 | class 11 | const 12 | constructor 13 | continue 14 | debugger 15 | declare 16 | default 17 | delete 18 | do 19 | else 20 | enum 21 | export 22 | extends 23 | finally 24 | for 25 | from 26 | function 27 | get 28 | if 29 | implements 30 | import 31 | in 32 | instanceof 33 | interface 34 | is 35 | let 36 | module 37 | namespace 38 | new 39 | null 40 | number 41 | of 42 | package 43 | private 44 | protected 45 | public 46 | require 47 | return 48 | set 49 | static 50 | string 51 | super 52 | switch 53 | symbol 54 | this 55 | throw 56 | try 57 | type 58 | typeof 59 | var 60 | void 61 | while 62 | with 63 | yield -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/lattice/__init__.py: -------------------------------------------------------------------------------- 1 | from .csharplattice import CSharpLattice 2 | from .lattice import LatticeVocabulary, Lattice 3 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/lattice/csharplattice.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from typing import List, Set 3 | 4 | from .lattice import Lattice 5 | from dpu_utils.utils.dataloading import load_json_gz 6 | 7 | 8 | class CSharpLattice(Lattice): 9 | """Represents a lattice structure of C# types.""" 10 | def __init__(self, elements: List[str], parent_relations: List[Set[int]]) -> None: 11 | super().__init__(elements, parent_relations) 12 | 13 | @lru_cache(maxsize=1024) 14 | def parents(self, element: str) -> List[str]: 15 | """Get the parents of a given element""" 16 | if element not in self._element_to_id: 17 | if element.endswith("[]"): 18 | inner_type = element[:-2] 19 | inner_type_parents = self.parents(inner_type) 20 | return list(sorted(set(inner_type_parent + "[]" for inner_type_parent in inner_type_parents))) 21 | 22 | return super().parents(element) 23 | 24 | @staticmethod 25 | def load(filename: str) -> 'CSharpLattice': 26 | types = load_json_gz(filename) 27 | return CSharpLattice(types['types'], types['outgoingEdges']) 28 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/lattice/lattice.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from functools import lru_cache 3 | from itertools import chain 4 | from typing import Iterable, List, Set, Optional, Union 5 | 6 | from dpu_utils.utils.dataloading import save_json_gz, load_json_gz 7 | from dpu_utils.mlutils import Vocabulary 8 | 9 | __all__ = ['Lattice', 'LatticeVocabulary'] 10 | 11 | 12 | class Lattice: 13 | """Represents a lattice structure.""" 14 | 15 | def __init__(self, elements: List[str], parent_relations: List[Set[int]]) -> None: 16 | self._elements = list(elements) 17 | self._element_to_id = {v: k for k, v in enumerate(self._elements)} 18 | 
self._parent_relations = [frozenset(parents) for parents in parent_relations] 19 | 20 | def __contains__(self, element: str) -> bool: 21 | return element in self._element_to_id 22 | 23 | @lru_cache(maxsize=1024) 24 | def parents(self, element: str) -> List[str]: 25 | """Get the parents of a given element""" 26 | if element not in self._element_to_id: 27 | return [] 28 | 29 | element_id = self._element_to_id[element] 30 | all_parents = set() 31 | to_visit = list(self._parent_relations[element_id]) 32 | while len(to_visit) > 0: 33 | next_element_id = to_visit.pop() 34 | all_parents.add(next_element_id) 35 | to_visit.extend(i for i in self._parent_relations[next_element_id] if i not in all_parents) 36 | 37 | return list(sorted(set(self._elements[i] for i in all_parents))) 38 | 39 | def to_dot(self, filename: str) -> None: 40 | with open(filename, 'w') as f: 41 | print('digraph G {', file=f) 42 | 43 | for i, element in enumerate(self._elements): 44 | print('n%s [label="%s"];' % (i, element), file=f) 45 | 46 | for i, parents in enumerate(self._parent_relations): 47 | for parent_id in parents: 48 | print('n%s->n%s;' % (i, parent_id), file=f) 49 | 50 | print('}', file=f) # digraph 51 | 52 | def save_as_json(self, filename: str) -> None: 53 | data = dict(types=self._elements, outgoingEdges=[list(p) for p in self._parent_relations]) 54 | save_json_gz(data, filename) 55 | 56 | def merge(self, other_lattice: 'Lattice') -> None: 57 | self._parent_relations = [set(parents) for parents in self._parent_relations] # Temporarily convert to sets 58 | for element in other_lattice._elements: 59 | if element not in self._element_to_id: 60 | self._element_to_id[element] = len(self._elements) 61 | self._elements.append(element) 62 | self._parent_relations.append(set()) 63 | 64 | # Add parent relations 65 | for other_lattice_idx, element in enumerate(other_lattice._elements): 66 | for other_lattice_parent_idx in other_lattice._parent_relations[other_lattice_idx]: 67 | parent_name = other_lattice._elements[other_lattice_parent_idx] 68 | 69 | this_lattice_element_idx = self._element_to_id[element] 70 | this_lattice_parent_idx = self._element_to_id[parent_name] 71 | self._parent_relations[this_lattice_element_idx].add(this_lattice_parent_idx) 72 | self._parent_relations = [frozenset(parents) for parents in self._parent_relations] # Reconvert to frozenset 73 | 74 | @staticmethod 75 | def load(filename: str) -> 'Lattice': 76 | types = load_json_gz(filename) 77 | return Lattice(types['types'], types['outgoingEdges']) 78 | 79 | 80 | class LatticeVocabulary(Vocabulary): 81 | """A feature dictionary that instead of returning UNKs, closest parent element in a 82 | lattice""" 83 | 84 | def __init__(self, lattice: Lattice) -> None: 85 | super().__init__(True) 86 | self.__lattice = lattice 87 | 88 | def is_unk(self, token: str) -> bool: 89 | return token not in self.token_to_id 90 | 91 | @lru_cache(maxsize=512) 92 | def __get_list_of_implemented_types(self, token, alternative_lattice: Optional[Lattice] = None) -> List[int]: 93 | if token.startswith('type:'): 94 | if alternative_lattice is None: 95 | type_parents = self.__lattice.parents(token[len('type:'):]) 96 | else: 97 | type_parents = alternative_lattice.parents(token[len('type:'):]) 98 | return [self.token_to_id[t] for t in chain([token], ['type:' + p for p in type_parents]) if 99 | t in self.token_to_id] 100 | if token in self.token_to_id: 101 | return [self.token_to_id[token]] 102 | return [] 103 | 104 | def get_id_or_unk(self, token: str, alternative_lattice: 
Optional[Lattice] = None) -> List[int]: 105 | type_list = self.__get_list_of_implemented_types(token, alternative_lattice) 106 | if len(type_list) == 0: 107 | return [self.token_to_id[self.get_unk()]] 108 | return type_list 109 | 110 | def get_id_or_none(self, token: str, alternative_lattice: Optional[Lattice] = None) -> List[Optional[int]]: 111 | type_list = self.__get_list_of_implemented_types(token, alternative_lattice) 112 | if len(type_list) == 0: 113 | return [None] 114 | return type_list 115 | 116 | def add_batch_tokens(self, tokens: Iterable[str], lattice: Lattice, count_threshold: int = 5) -> None: 117 | token_counter = Counter(tokens) 118 | for token, count in list(token_counter.items()): 119 | if token.startswith('type:'): 120 | type_name = token[len('type:'):] 121 | for t in lattice.parents(type_name): 122 | token_counter['type:' + t] += count 123 | for token, count in token_counter.items(): 124 | if count >= count_threshold: 125 | self.add_or_get_id(token) 126 | 127 | @staticmethod 128 | def get_feature_dictionary_for(tokens: Iterable[str], lattice: Lattice, 129 | count_threshold: int = 5) -> 'LatticeVocabulary': 130 | """Deprecated: Use `get_vocabulary_for` instead.""" 131 | return LatticeVocabulary.get_vocabulary_for(tokens, lattice, count_threshold) 132 | 133 | @staticmethod 134 | def get_vocabulary_for(tokens: Union[Iterable[str], Counter], lattice: Lattice, 135 | count_threshold: int = 5, max_size: int = 100000) -> 'LatticeVocabulary': 136 | if type(tokens) is Counter: 137 | token_counter = tokens 138 | else: 139 | token_counter = Counter(tokens) 140 | for token, count in list(token_counter.items()): 141 | if token.startswith('type:'): 142 | type_name = token[len('type:'):] 143 | for t in lattice.parents(type_name): 144 | token_counter['type:' + t] += count 145 | 146 | feature_dict = LatticeVocabulary(lattice) 147 | for token, count in token_counter.most_common(max_size): 148 | if count >= count_threshold: 149 | feature_dict.add_or_get_id(token) 150 | return feature_dict 151 | -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/text.py: -------------------------------------------------------------------------------- 1 | from itertools import count 2 | import io 3 | from typing import Tuple 4 | 5 | 6 | def get_code_in_range(code_text: str, start_pos: Tuple[int, int], end_pos: Tuple[int, int]) -> str: 7 | """ 8 | Get the code text given a range in the form (start_line, start_column), (end_line, end_column). 9 | 10 | Notes: 11 | * This follow common IDE convention where lines and columns are 1-based (i.e., the first line 12 | is line 1 not line 0 as it would be in array indexing) 13 | * The range is inclusive (i.e., both start and end are included) 14 | 15 | :param code_text: the string representation of the code. 16 | :param start_pos: The starting position of the range in `code_text` as a tuple the form (start_line, start_column). 17 | :param end_pos: The end position of the range in `code_text` as a tuple the form (end_line, end_column). 18 | 19 | :return: the string representation of the target range. 
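Example (1-based, inclusive): for code_text "ab\ncd", get_code_in_range(code_text, (1, 2), (2, 1)) returns "b\nc".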
20 | """ 21 | (start_line, start_column), (end_line, end_column) = start_pos, end_pos 22 | assert start_line < end_line or (start_line == end_line and start_column <= end_column), "Invalid range" 23 | 24 | # Adjust `target_range` columns to be 0-based and the end_column to be exclusive 25 | start_column -= 1 26 | 27 | with io.StringIO(code_text) as input_sb, io.StringIO() as output_sb: 28 | for line_no in count(start=1): 29 | next_input_line = input_sb.readline() 30 | if len(next_input_line) == 0: # No bytes read 31 | raise ValueError("EOF reached before target_range.") 32 | 33 | if start_line <= line_no <= end_line: 34 | if start_line == end_line: 35 | output_sb.write(next_input_line[start_column: end_column]) 36 | break 37 | elif line_no == start_line: 38 | output_sb.write(next_input_line[start_column:]) 39 | elif line_no == end_line: 40 | output_sb.write(next_input_line[:end_column]) 41 | break 42 | else: 43 | output_sb.write(next_input_line) 44 | elif line_no > end_line: 45 | raise Exception("Unreachable state.") 46 | return output_sb.getvalue() -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/treesitter/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser import parser_for -------------------------------------------------------------------------------- /python/dpu_utils/codeutils/treesitter/parser.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from tempfile import TemporaryDirectory 4 | 5 | from tree_sitter import Language, Parser 6 | 7 | __all__ = ["parser_for"] 8 | 9 | _LANGUAGE_REPOS = { 10 | "bash": ("https://github.com/tree-sitter/tree-sitter-bash", ""), 11 | "c": ("https://github.com/tree-sitter/tree-sitter-c", ""), 12 | "c_sharp": ("https://github.com/tree-sitter/tree-sitter-c-sharp", ""), 13 | "css": ("https://github.com/tree-sitter/tree-sitter-css", ""), 14 | "cpp": ("https://github.com/tree-sitter/tree-sitter-cpp", ""), 15 | "html": ("https://github.com/tree-sitter/tree-sitter-html", ""), 16 | "go": ("https://github.com/tree-sitter/tree-sitter-go", ""), 17 | "java": ("https://github.com/tree-sitter/tree-sitter-java", ""), 18 | "javascript": ("https://github.com/tree-sitter/tree-sitter-javascript", ""), 19 | "julia": ("https://github.com/tree-sitter/tree-sitter-julia", ""), 20 | "php": ("https://github.com/tree-sitter/tree-sitter-php", ""), 21 | "python": ("https://github.com/tree-sitter/tree-sitter-python", ""), 22 | "ruby": ("https://github.com/tree-sitter/tree-sitter-ruby", ""), 23 | "rust": ("https://github.com/tree-sitter/tree-sitter-rust", ""), 24 | "scala": ("https://github.com/tree-sitter/tree-sitter-scala", ""), 25 | "typescript": ("https://github.com/tree-sitter/tree-sitter-typescript", "typescript"), 26 | } 27 | 28 | LIBRARY_DIR = os.path.join(os.path.dirname(__file__), "build", "treesitter-lib.so") 29 | TREE_SITTER_LANG_VER = "v0.19.0" 30 | 31 | if not os.path.exists(LIBRARY_DIR): 32 | logging.warning("TreeSitter has not been compiled. 
Cloning languages and building...") 33 | from git import Repo 34 | with TemporaryDirectory() as dir: 35 | # Clone all repos above at the given tag 36 | repo_dirs = [] 37 | for lang, (url, suffix) in _LANGUAGE_REPOS.items(): 38 | logging.warning(f"Cloning `{lang}`...") 39 | repo_dir = os.path.join(dir, lang) 40 | repo = Repo.clone_from(url, repo_dir) 41 | repo.git.checkout(TREE_SITTER_LANG_VER) 42 | repo_dirs.append(os.path.join(repo_dir, suffix)) 43 | 44 | # Build library by pointing to each repo 45 | logging.warning(f"Building Tree-sitter Library...") 46 | Language.build_library(LIBRARY_DIR, repo_dirs) 47 | 48 | _LANGUAGES = {} 49 | for language in _LANGUAGE_REPOS: 50 | _LANGUAGES[language] = Language(LIBRARY_DIR, language) 51 | 52 | # Add aliases 53 | _ALIASES = { 54 | "c++": "cpp", 55 | "c#": "c_sharp", 56 | "csharp": "c_sharp", 57 | "js": "javascript", 58 | "ts": "typescript" 59 | } 60 | for alias, target in _ALIASES.items(): 61 | _LANGUAGES[alias] = _LANGUAGES[target] 62 | 63 | def parser_for(language: str) -> Parser: 64 | parser = Parser() 65 | parser.set_language(_LANGUAGES[language]) 66 | return parser 67 | -------------------------------------------------------------------------------- /python/dpu_utils/mlutils/__init__.py: -------------------------------------------------------------------------------- 1 | from .chartensorizer import CharTensorizer 2 | from .vocabulary import Vocabulary 3 | from .bpevocabulary import BpeVocabulary 4 | 5 | __all__ = ['CharTensorizer', 'Vocabulary', 'BpeVocabulary'] 6 | -------------------------------------------------------------------------------- /python/dpu_utils/mlutils/bpevocabulary.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from collections import Counter 4 | from tempfile import TemporaryDirectory 5 | from typing import List, Optional, Sized, Union, Iterable 6 | import sentencepiece as spm 7 | import typing 8 | 9 | __all__ = ['BpeVocabulary'] 10 | 11 | SPIECE_UNDERLINE = u'▁' 12 | 13 | 14 | class BpeVocabulary(Sized): 15 | """ 16 | A vocabulary that maps strings to unique ids (and back), and a tokenizer based on 17 | Byte-Pair Encoding using sentencepiece from https://github.com/google/sentencepiece. 18 | 19 | Sennrich, Rico, Barry Haddow, and Alexandra Birch. 20 | "Neural machine translation of rare words with subword units." 21 | arXiv preprint arXiv:1508.07909 (2015). 22 | 23 | To create a vocabulary use `BpeVocabulary.create_vocabulary()`. 24 | The control flow symbols have to be introduced 25 | manually, during preprocessing step. 26 | 27 | BpeVocabulary object usage: Assuming an initialized vocabulary `v`: 28 | 29 | * To get the the tokenized version of a string `v.tokenize("a single string here")`. 30 | * To get the ids of a string, use `v.get_id_or_unk_for_text("a single string here")`. 31 | * To get a string from a list of the ids of pieces, use `v.convert_ids_to_string([10, 2, 5, 3])`. 
 32 |     * To get the size of the vocabulary use `len(v)`.
 33 |     """
 34 |     LOGGER = logging.getLogger('BpeVocabulary')
 35 |     DEFAULT_CONTROL_SYMBOLS = ["<cls>", "<sep>"]  # Extra sentencepiece control tokens (beyond bos/eos/unk/pad).
 36 | 
 37 |     def __init__(self, max_size: int, sentencepiece_model_filepath: Optional[str]=None,
 38 |                  bos_token: str="<s>", eos_token: str="</s>", unk_token: str="<unk>", pad_token: str="<pad>",
 39 |                  user_defined_symbols: Optional[List[str]] = None,
 40 |                  control_symbols: Optional[List[str]]=None) -> None:
 41 | 
 42 |         self.__max_size = max_size
 43 |         self.__bos_token = bos_token
 44 |         self.__eos_token = eos_token
 45 |         self.__unk_token = unk_token
 46 |         self.__pad_token = pad_token
 47 | 
 48 |         self.vocab_file = sentencepiece_model_filepath
 49 |         if user_defined_symbols is None:
 50 |             user_defined_symbols = []
 51 |         self.user_defined_symbols = ",".join(user_defined_symbols)
 52 | 
 53 |         if control_symbols is None:
 54 |             control_symbols = self.DEFAULT_CONTROL_SYMBOLS
 55 |         self.control_symbols = ",".join(control_symbols)
 56 | 
 57 |         self.__sp_model_data = None  # type: Optional[bytes]
 58 |         self.__sp_model = spm.SentencePieceProcessor()
 59 |         if sentencepiece_model_filepath is not None:
 60 |             self.__load_model_from_filepath(sentencepiece_model_filepath)
 61 | 
 62 |     #region Custom Pickling
 63 |     def __load_model_from_filepath(self, sentencepiece_model_filepath) -> bool:
 64 |         loaded = self.__sp_model.Load(sentencepiece_model_filepath)
 65 |         # We want to encapsulate all vocabulary-related elements in a single location (this object!) and avoid
 66 |         # dangling files. We store all model data in this object as a set of bytes.
 67 |         with open(sentencepiece_model_filepath, 'rb') as f:
 68 |             self.__sp_model_data = f.read()
 69 | 
 70 |         return loaded
 71 | 
 72 |     def __getstate__(self):
 73 |         """The __sp_model cannot be serialized. Remove it when pickling."""
 74 |         state = self.__dict__.copy()
 75 |         del state['_BpeVocabulary__sp_model']
 76 |         return state
 77 | 
 78 |     def __setstate__(self, state):
 79 |         """Restore the __sp_model that could not be serialized."""
 80 |         self.__dict__.update(state)
 81 |         self.__sp_model = spm.SentencePieceProcessor()  # Recreate the processor; load it below only if model data exists.
 82 |         if self.__sp_model_data is None:
 83 |             return
 84 |         with TemporaryDirectory() as tmp_dir:
 85 |             model_file = os.path.join(tmp_dir, 'tmp.model')
 86 |             with open(model_file, 'wb') as f:
 87 |                 f.write(self.__sp_model_data)
 88 |             self.__sp_model.Load(model_file)
 89 |     #endregion
 90 | 
 91 |     def get_pad(self) -> str:
 92 |         """ Get padding token. """
 93 |         if self.__pad_token is None:
 94 |             self.LOGGER.error("Using pad_token, but it is not set yet.")
 95 |         return self.__pad_token
 96 | 
 97 |     def get_unk(self) -> str:
 98 |         """ Get unknown token. """
 99 |         if self.__unk_token is None:
100 |             self.LOGGER.error("Using unk_token, but it is not set yet.")
101 |         return self.__unk_token
102 | 
103 |     def __len__(self) -> int:
104 |         return len(self.__sp_model)
105 | 
106 |     def tokenize(self, text: str) -> List[str]:
107 |         """ Tokenize a string. """
108 |         pieces = self.__sp_model.EncodeAsPieces(text)
109 | 
110 |         new_pieces = []  # type: List[str]
111 |         for piece in pieces:
112 |             # Split subtokens composed of a digit and comma
113 |             #
114 |             # E.g., given an input sentence:
115 |             #     text = 'for i in range(100, 2):'
116 |             # Default output of tokenizer may be:
117 |             #     ['▁for', '▁i', '▁in', '▁range', '(1', '00,', '▁2', '):']
118 |             # The following will change this to:
119 |             #     ['▁for', '▁i', '▁in', '▁range', '(1', '0', '0', ',', '▁2', '):']
120 |             if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
121 |                 cur_pieces = self.__sp_model.EncodeAsPieces(
122 |                     piece[:-1].replace(SPIECE_UNDERLINE, ''))
123 |                 if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
124 |                     if len(cur_pieces[0]) == 1:
125 |                         cur_pieces = cur_pieces[1:]
126 |                     else:
127 |                         cur_pieces[0] = cur_pieces[0][1:]
128 |                 cur_pieces.append(piece[-1])
129 |                 new_pieces.extend(cur_pieces)
130 |             else:
131 |                 new_pieces.append(piece)
132 | 
133 |         return new_pieces
134 | 
135 |     def get_id_or_unk_for_text(self, text: str, pad_to_size: Optional[int] = None,
136 |                                padding_element: int = 0) -> List[int]:
137 |         """
138 |         Tokenize (using BPE) a given string and return a list of the int ids of the wordpieces of the string.
139 |         """
140 |         tokens = self.tokenize(text)
141 | 
142 |         if pad_to_size is not None:
143 |             tokens = tokens[:pad_to_size]
144 | 
145 |         ids = [self.__sp_model.PieceToId(t) for t in tokens]
146 |         if pad_to_size is not None and len(ids) != pad_to_size:
147 |             ids += [padding_element] * (pad_to_size - len(ids))
148 | 
149 |         return ids
150 | 
151 |     def convert_ids_to_string(self, piece_ids: List[int]) -> str:
152 |         """Converts a sequence of piece ids (sub-word strings) into a single string."""
153 |         out_string = ''.join((self.__sp_model.IdToPiece(i) for i in piece_ids)).replace(SPIECE_UNDERLINE, ' ').strip()
154 |         return out_string
155 | 
156 |     def create_vocabulary_from_file(self, sp_text_file: str, num_threads: Optional[int]=os.cpu_count(),
157 |                                     max_sentence_length: int=16384, character_coverage: float=0.99995) -> None:
158 |         """
159 |         Train sentencepiece tokenizer using BPE model and build a vocabulary.
160 | 
161 |         sp_text_file: path to a plain text file containing the training dataset.
162 |         """
163 |         if num_threads is None:
164 |             num_threads = 1
165 |         with TemporaryDirectory() as tmpdir:
166 |             model_filename = os.path.join(tmpdir, f'bpe_{self.__max_size}')
167 |             command = [
168 |                 f"--input={sp_text_file}",
169 |                 f"--num_threads={num_threads}",
170 |                 f"--model_prefix={model_filename}",
171 |                 f"--vocab_size={self.__max_size}",
172 |                 "--model_type=bpe",
173 |                 f"--max_sentence_length={max_sentence_length}",
174 |                 f"--bos_piece={self.__bos_token}",
175 |                 f"--eos_piece={self.__eos_token}",
176 |                 f"--pad_piece={self.__pad_token}",
177 |                 "--pad_id=3",
178 |                 f"--unk_piece={self.__unk_token}",
179 |                 f"--user_defined_symbols={self.user_defined_symbols}",
180 |                 f"--control_symbols={self.control_symbols}",
181 |                 f"--character_coverage={character_coverage}",
182 |                 "--minloglevel=1",
183 |                 "--hard_vocab_limit=false",
184 |             ]
185 | 
186 |             spm.SentencePieceTrainer.train(
187 |                 " ".join(command)
188 |             )
189 | 
190 |             loaded = self.__load_model_from_filepath(model_filename+'.model')
191 |             assert loaded, 'Sentencepiece failed to load model.'
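    # Usage sketch (illustrative, not part of the original file): given an
    # iterable of text lines `corpus_lines` (a stand-in name), a BPE vocabulary
    # can be trained and used to round-trip a string:
    #
    #     v = BpeVocabulary(max_size=5000)
    #     v.create_vocabulary(corpus_lines)
    #     ids = v.get_id_or_unk_for_text("for i in range(100, 2):")
    #     assert v.convert_ids_to_string(ids) == "for i in range(100, 2):"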
192 | 
193 |     def create_vocabulary(self, tokens: Union[Iterable[str], Iterable[List[str]], typing.Counter[str]]) -> None:
194 |         with TemporaryDirectory() as tmp_dir:
195 |             data_path = os.path.join(tmp_dir, 'tmpvocab.txt')  # Plain-text training corpus for sentencepiece.
196 |             with open(data_path, 'w') as f:
197 |                 if isinstance(tokens, Counter):
198 |                     for token, count in tokens.items():
199 |                         for _ in range(count):
200 |                             f.write(token + '\n')
201 |                 else:
202 |                     for element in tokens:
203 |                         if isinstance(element, str):
204 |                             f.write(element + '\n')
205 |                         else:
206 |                             f.write(' '.join(element))
207 |                             f.write('\n')
208 |             return self.create_vocabulary_from_file(data_path)
--------------------------------------------------------------------------------
/python/dpu_utils/mlutils/chartensorizer.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | from typing import Optional
 4 | 
 5 | __all__ = ['CharTensorizer']
 6 | 
 7 | 
 8 | class CharTensorizer:
 9 |     """Tensorize strings into characters"""
10 | 
11 |     def __init__(self, max_num_chars: Optional[int], lower_case_all: bool, include_space: bool):
12 |         self.__max_num_chars = max_num_chars
13 |         self.__lower_case_all = lower_case_all
14 | 
15 |         self.__ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}"
16 |         if not self.__lower_case_all:
17 |             self.__ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + self.__ALPHABET
18 |         if include_space:
19 |             self.__ALPHABET += ' '
20 | 
21 |         self.__ALPHABET_DICT = {char: idx + 2 for (idx, char) in enumerate(self.__ALPHABET)}  # "0" is PAD, "1" is UNK
22 |         self.__ALPHABET_DICT['PAD'] = 0
23 |         self.__ALPHABET_DICT['UNK'] = 1
24 | 
25 |     @property
26 |     def max_char_length(self) -> Optional[int]:
27 |         return self.__max_num_chars
28 | 
29 |     def num_chars_in_vocabulary(self) -> int:
30 |         return len(self.__ALPHABET_DICT)
31 | 
32 |     def __get_char_idx(self, character: str) -> int:
33 |         idx = self.__ALPHABET_DICT.get(character)
34 |         if idx is not None:
35 |             return idx
36 |         return self.__ALPHABET_DICT['UNK']
37 | 
38 |     def get_word_from_ids(self, ids: np.ndarray) -> str:
39 |         return ''.join(self.__ALPHABET[i - 2] if i >= 2 else '' for i in ids)  # Ids 0 (PAD) and 1 (UNK) have no character; characters start at id 2.
40 | 
41 |     def tensorize_str(self, input: str) -> np.ndarray:
42 |         if self.__lower_case_all:
43 |             input = input.lower()
44 | 
45 |         def char_iterator():
46 |             for i, c in enumerate(input):
47 |                 if self.__max_num_chars is not None and i >= self.__max_num_chars:
48 |                     break
49 |                 yield self.__get_char_idx(c)
50 |             if self.__max_num_chars is not None and len(input) < self.__max_num_chars:
51 |                 pad_id = self.__get_char_idx('PAD')
52 |                 yield from (pad_id for _ in range(self.__max_num_chars - len(input)))
53 |         return np.fromiter(char_iterator(), dtype=np.uint8)
54 | 
--------------------------------------------------------------------------------
/python/dpu_utils/mlutils/vocabulary.py:
--------------------------------------------------------------------------------
  1 | from collections import Counter
  2 | import typing
  3 | from typing import Iterable, Dict, Sized, List, FrozenSet, Union, Optional
  4 | 
  5 | import numpy as np
  6 | 
  7 | __all__ = ['Vocabulary']
  8 | 
  9 | 
 10 | class Vocabulary(Sized):
 11 |     """
 12 |     A simple vocabulary that maps strings to unique ids (and back).
 13 | 
 14 |     To create a vocabulary use `Vocabulary.create_vocabulary()` and pass
 15 |     a counter or an iterator of the elements that the vocabulary will contain.
 16 | 
 17 |     Vocabulary object usage: Assuming an initialized vocabulary `v`:
 18 | 
 19 |     * To get the id of an element use `v.get_id_or_unk("element")`.
20 | * To get the ids of a sequence, use `v.get_id_or_unk_multiple(..)`. 21 | * To get the size of the vocabulary use `len(v)` 22 | * To get the string representation for a given id use `v.get_name_for_id(the_id)`. 23 | """ 24 | 25 | def __init__(self, add_unk: bool=True, add_pad: bool=False) -> None: 26 | self.token_to_id = {} # type: Dict[str, int] 27 | self.id_to_token = [] # type: List[str] 28 | if add_pad: 29 | self.add_or_get_id(self.get_pad()) 30 | if add_unk: 31 | self.add_or_get_id(self.get_unk()) 32 | 33 | def is_equal_to(self, other) -> bool: 34 | """ 35 | This would be __eq__, except that Python 3 insists that __hash__ is defined if __eq__ is 36 | defined, and we can't define __hash__ because the object is mutable. 37 | """ 38 | if not isinstance(other, Vocabulary): 39 | return False 40 | return self.id_to_token == other.id_to_token 41 | 42 | def add_or_get_id(self, token: str) -> int: 43 | """ 44 | Get the token's id. If the token does not exist in the 45 | dictionary, add it. 46 | """ 47 | this_id = self.token_to_id.get(token) 48 | if this_id is not None: 49 | return this_id 50 | 51 | this_id = len(self.id_to_token) 52 | self.token_to_id[token] = this_id 53 | self.id_to_token.append(token) 54 | return this_id 55 | 56 | def is_unk(self, token: str) -> bool: 57 | return token not in self.token_to_id 58 | 59 | def get_id_or_unk(self, token: str) -> int: 60 | this_id = self.token_to_id.get(token) 61 | if this_id is not None: 62 | return this_id 63 | else: 64 | return self.token_to_id[self.get_unk()] 65 | 66 | def get_id_or_unk_multiple(self, tokens: List[str], pad_to_size: Optional[int] = None, 67 | padding_element: int = 0) -> List[int]: 68 | if pad_to_size is not None: 69 | tokens = tokens[:pad_to_size] 70 | 71 | ids = [self.get_id_or_unk(t) for t in tokens] 72 | 73 | if pad_to_size is not None and len(ids) != pad_to_size: 74 | ids += [padding_element] * (pad_to_size - len(ids)) 75 | 76 | return ids 77 | 78 | def get_name_for_id(self, token_id: int) -> str: 79 | return self.id_to_token[token_id] 80 | 81 | def __len__(self) -> int: 82 | return len(self.token_to_id) 83 | 84 | def __str__(self): 85 | return str(self.token_to_id) 86 | 87 | def get_all_names(self) -> FrozenSet[str]: 88 | return frozenset(self.token_to_id.keys()) 89 | 90 | def __batch_add_from_counter(self, token_counter: typing.Counter[str], count_threshold: int, max_size: int) -> None: 91 | """Update dictionary with elements of the token_counter""" 92 | for token, count in token_counter.most_common(max_size): 93 | if count >= count_threshold: 94 | self.add_or_get_id(token) 95 | else: 96 | break 97 | 98 | @staticmethod 99 | def get_unk() -> str: 100 | return '%UNK%' 101 | 102 | @staticmethod 103 | def get_pad() -> str: 104 | return '%PAD%' 105 | 106 | @staticmethod 107 | def create_vocabulary(tokens: Union[Iterable[str], typing.Counter[str]], max_size: int, 108 | count_threshold: int = 5, add_unk: bool = True, add_pad: bool = False) -> 'Vocabulary': 109 | if isinstance(tokens, Counter): 110 | token_counter = tokens 111 | else: 112 | token_counter = Counter(tokens) 113 | vocab = Vocabulary(add_unk=add_unk, add_pad=add_pad) 114 | num_base_tokens = (1 if add_unk else 0) + (1 if add_pad else 0) 115 | vocab.__batch_add_from_counter(token_counter, count_threshold, max_size - num_base_tokens) 116 | return vocab 117 | 118 | def update(self, token_counter: typing.Counter[str], max_size: int, count_threshold: int=5): 119 | assert len(self) < max_size, 'Dictionary is already larger than max_size.' 
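        # Grow the vocabulary in place: existing tokens keep their ids, and new
        # tokens from `token_counter` are added in order of decreasing count,
        # stopping once a count falls below `count_threshold`; at most `max_size`
        # candidate tokens are considered.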
120 |         self.__batch_add_from_counter(token_counter, count_threshold=count_threshold, max_size=max_size)
121 | 
122 |     def get_empirical_distribution(self, elements: Iterable[str], dirichlet_alpha: float = 10.) -> np.ndarray:
123 |         """Retrieve empirical distribution of elements."""
124 |         targets = np.fromiter((self.get_id_or_unk(t) for t in elements), dtype=np.int64)
125 |         empirical_distribution = np.bincount(targets, minlength=len(self)).astype(float)
126 |         empirical_distribution += dirichlet_alpha / len(empirical_distribution)
127 |         return empirical_distribution / (np.sum(empirical_distribution) + dirichlet_alpha)
128 | 
--------------------------------------------------------------------------------
/python/dpu_utils/ptutils/__init__.py:
--------------------------------------------------------------------------------
1 | from .basecomponent import BaseComponent
2 | from .componenttrainer import ComponentTrainer, AbstractScheduler
--------------------------------------------------------------------------------
/python/dpu_utils/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/dpu-utils/bc4f5fa6c92c4b073fff4d93269512b0027f7f02/python/dpu_utils/py.typed
--------------------------------------------------------------------------------
/python/dpu_utils/tf2utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .activation import gelu, get_activation_function_by_name
2 | from .mlp import MLP
3 | from .unsorted_segment_ops import unsorted_segment_log_softmax, unsorted_segment_logsumexp, unsorted_segment_softmax
--------------------------------------------------------------------------------
/python/dpu_utils/tf2utils/activation.py:
--------------------------------------------------------------------------------
 1 | """Custom activation functions."""
 2 | from typing import Optional, Callable
 3 | import math as m
 4 | 
 5 | import tensorflow as tf
 6 | 
 7 | 
 8 | def gelu(input_tensor: tf.Tensor) -> tf.Tensor:
 9 |     """An approximation to the GELU activation function as used in the paper
10 |     https://arxiv.org/pdf/1810.04805.pdf
11 |     """
12 |     cdf = 0.5 * (
13 |         1.0
14 |         + tf.tanh(
15 |             (tf.sqrt(2 / m.pi) * (input_tensor + 0.044715 * tf.pow(input_tensor, 3)))
16 |         )
17 |     )
18 |     return input_tensor * cdf
19 | 
20 | 
21 | def get_activation_function_by_name(
22 |     activation_fn_name: Optional[str],
23 | ) -> Optional[Callable[[tf.Tensor], tf.Tensor]]:
24 |     """Convert from an activation function name to the function itself."""
25 |     if activation_fn_name is None:
26 |         return None
27 |     activation_fn_name = activation_fn_name.lower()
28 | 
29 |     string_to_activation_fn = {
30 |         "linear": None,
31 |         "tanh": tf.nn.tanh,
32 |         "relu": tf.nn.relu,
33 |         "leaky_relu": tf.nn.leaky_relu,
34 |         "elu": tf.nn.elu,
35 |         "selu": tf.nn.selu,
36 |         "gelu": gelu,
37 |     }
38 |     if activation_fn_name not in string_to_activation_fn:
39 |         raise ValueError(f"Unknown activation function: {activation_fn_name}")
40 |     # Note: "linear" intentionally maps to None, i.e. no activation is applied.
41 |     return string_to_activation_fn[activation_fn_name]
42 | 
--------------------------------------------------------------------------------
/python/dpu_utils/tf2utils/constants.py:
--------------------------------------------------------------------------------
1 | SMALL_NUMBER = 1e-7
--------------------------------------------------------------------------------
/python/dpu_utils/tf2utils/mlp.py:
--------------------------------------------------------------------------------
 1 | """MLP
layer.""" 2 | import sys 3 | from typing import Callable, List, Optional, Union 4 | 5 | import tensorflow as tf 6 | 7 | 8 | class MLP(tf.keras.layers.Layer): 9 | def __init__( 10 | self, 11 | out_size: int, 12 | hidden_layers: Union[List[int], int] = 1, 13 | use_biases: bool = False, 14 | activation_fun: Optional[Callable[[tf.Tensor], tf.Tensor]] = tf.nn.relu, 15 | dropout_rate: float = 0.0, 16 | name: str = "MLP", 17 | ): 18 | """ 19 | Create new MLP with given number of hidden layers. 20 | 21 | Arguments: 22 | out_size: Dimensionality of output. 23 | hidden_layers: Either an integer determining number of hidden layers, which will have 24 | out_size units each; or list of integers whose lengths determines the number of 25 | hidden layers and whose contents the number of units in each layer. 26 | use_biases: Flag indicating use of bias in fully connected layers. 27 | activation_fun: Activation function applied between hidden layers (NB: the output of the 28 | MLP is always the direct result of a linear transformation) 29 | dropout_rate: Dropout applied to inputs of each MLP layer. 30 | name: Name of the MLP, used in names of created variables. 31 | """ 32 | super().__init__() 33 | if isinstance(hidden_layers, int): 34 | if out_size == 1: 35 | print( 36 | f"W: In {name}, was asked to use {hidden_layers} layers of size 1, which is most likely wrong." 37 | f" Switching to {hidden_layers} layers of size 32; to get hidden layers of size 1," 38 | f" use hidden_layers=[1,...,1] explicitly.", 39 | file=sys.stderr, 40 | ) 41 | self._hidden_layer_sizes = [32] * hidden_layers 42 | else: 43 | self._hidden_layer_sizes = [out_size] * hidden_layers 44 | else: 45 | self._hidden_layer_sizes = hidden_layers 46 | 47 | if len(self._hidden_layer_sizes) > 1: 48 | assert ( 49 | activation_fun is not None 50 | ), "Multiple linear layers without an activation" 51 | 52 | self._out_size = out_size 53 | self._use_biases = use_biases 54 | self._activation_fun = activation_fun 55 | self._dropout_rate = dropout_rate 56 | self._layers = [] # type: List[tf.keras.layers.Dense] 57 | self._name = name 58 | 59 | def build(self, input_shape): 60 | last_shape_dim = input_shape[-1] 61 | for hidden_layer_idx, hidden_layer_size in enumerate(self._hidden_layer_sizes): 62 | with tf.name_scope(f"{self._name}_dense_layer_{hidden_layer_idx}"): 63 | self._layers.append( 64 | tf.keras.layers.Dense( 65 | units=hidden_layer_size, 66 | use_bias=self._use_biases, 67 | activation=self._activation_fun, 68 | name=f"{self._name}_dense_layer_{hidden_layer_idx}", 69 | ) 70 | ) 71 | self._layers[-1].build(tf.TensorShape(input_shape[:-1] + [last_shape_dim])) 72 | last_shape_dim = hidden_layer_size 73 | 74 | # Output layer: 75 | with tf.name_scope(f"{self._name}_final_layer"): 76 | self._layers.append( 77 | tf.keras.layers.Dense( 78 | units=self._out_size, 79 | use_bias=self._use_biases, 80 | name=f"{self._name}_final_layer", 81 | ) 82 | ) 83 | self._layers[-1].build(tf.TensorShape(input_shape[:-1] + [last_shape_dim])) 84 | 85 | super().build(input_shape) 86 | 87 | @tf.function(experimental_relax_shapes=True) 88 | def call(self, input: tf.Tensor, training: bool) -> tf.Tensor: 89 | activations = input 90 | for layer in self._layers[:-1]: 91 | if training: 92 | activations = tf.nn.dropout(activations, rate=self._dropout_rate) 93 | activations = layer(activations) 94 | return self._layers[-1](activations) 95 | -------------------------------------------------------------------------------- /python/dpu_utils/tf2utils/unsorted_segment_ops.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from .constants import SMALL_NUMBER 4 | 5 | 6 | @tf.function 7 | def unsorted_segment_logsumexp(scores, segment_ids, num_segments): 8 | """Perform an unsorted segment safe logsumexp.""" 9 | # Note: if a segment is empty, the smallest value for the score will be returned, 10 | # which yields the correct behavior 11 | max_per_segment = tf.math.unsorted_segment_max( 12 | data=scores, segment_ids=segment_ids, num_segments=num_segments 13 | ) 14 | scattered_log_maxes = tf.gather(params=max_per_segment, indices=segment_ids) 15 | recentered_scores = scores - scattered_log_maxes 16 | exped_recentered_scores = tf.math.exp(recentered_scores) 17 | 18 | per_segment_sums = tf.math.unsorted_segment_sum( 19 | exped_recentered_scores, segment_ids, num_segments 20 | ) 21 | per_segment_logs = tf.math.log(per_segment_sums) 22 | return per_segment_logs + max_per_segment 23 | 24 | 25 | @tf.function 26 | def unsorted_segment_log_softmax(logits, segment_ids, num_segments): 27 | """Perform an unsorted segment safe log_softmax.""" 28 | # Note: if a segment is empty, the smallest value for the score will be returned, 29 | # which yields the correct behavior 30 | max_per_segment = tf.math.unsorted_segment_max( 31 | data=logits, segment_ids=segment_ids, num_segments=num_segments 32 | ) 33 | scattered_maxes = tf.gather(params=max_per_segment, indices=segment_ids) 34 | recentered_scores = logits - scattered_maxes 35 | exped_recentered_scores = tf.math.exp(recentered_scores) 36 | 37 | per_segment_sums = tf.math.unsorted_segment_sum( 38 | exped_recentered_scores, segment_ids, num_segments 39 | ) 40 | per_segment_normalization_consts = tf.math.log(per_segment_sums) 41 | 42 | log_probs = recentered_scores - tf.gather( 43 | params=per_segment_normalization_consts, indices=segment_ids 44 | ) 45 | return log_probs 46 | 47 | 48 | @tf.function 49 | def unsorted_segment_softmax(logits, segment_ids, num_segments): 50 | """Perform a safe unsorted segment softmax.""" 51 | max_per_segment = tf.math.unsorted_segment_max( 52 | data=logits, segment_ids=segment_ids, num_segments=num_segments 53 | ) 54 | scattered_maxes = tf.gather(params=max_per_segment, indices=segment_ids) 55 | recentered_scores = logits - scattered_maxes 56 | exped_recentered_scores = tf.math.exp(recentered_scores) 57 | 58 | per_segment_sums = tf.math.unsorted_segment_sum( 59 | exped_recentered_scores, segment_ids, num_segments 60 | ) 61 | 62 | probs = exped_recentered_scores / ( 63 | tf.gather(params=per_segment_sums, indices=segment_ids) + SMALL_NUMBER 64 | ) 65 | return probs 66 | -------------------------------------------------------------------------------- /python/dpu_utils/tfmodels/__init__.py: -------------------------------------------------------------------------------- 1 | from .sparsegnn import SparseGGNN 2 | from .asyncgnn import AsyncGGNN 3 | -------------------------------------------------------------------------------- /python/dpu_utils/tfutils/__init__.py: -------------------------------------------------------------------------------- 1 | from .gradratiologgingoptimizer import GradRatioLoggingOptimizer 2 | from .unsortedsegmentops import unsorted_segment_log_softmax, unsorted_segment_logsumexp, unsorted_segment_softmax 3 | from .tfvariablesaver import TFVariableSaver 4 | from .pick_indices import pick_indices_from_probs 5 | from .activation import get_activation 6 | 
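# Usage sketch (illustrative, not part of the original file): the segment ops
# exported above, like their eager-friendly TF2 twins in `dpu_utils.tf2utils`,
# normalize scores within variable-sized groups, e.g. per-graph attention logits.
# With TF2 eager execution:
#
#     import tensorflow as tf
#     from dpu_utils.tf2utils import unsorted_segment_softmax
#
#     logits = tf.constant([1.0, 2.0, 0.5, 3.0])
#     segment_ids = tf.constant([0, 0, 1, 1])  # two segments, two logits each
#     probs = unsorted_segment_softmax(logits, segment_ids, num_segments=2)
#     # Probabilities sum to (approximately) 1 within each segment.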
-------------------------------------------------------------------------------- /python/dpu_utils/tfutils/activation.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable 2 | 3 | import tensorflow as tf 4 | 5 | __all__ = [ 'get_activation' ] 6 | 7 | def get_activation(activation_fun: Optional[str]) -> Optional[Callable]: 8 | if activation_fun is None: 9 | return None 10 | activation_fun = activation_fun.lower() 11 | if activation_fun == 'linear': 12 | return None 13 | if activation_fun == 'tanh': 14 | return tf.tanh 15 | if activation_fun == 'relu': 16 | return tf.nn.relu 17 | if activation_fun == 'leaky_relu': 18 | return tf.nn.leaky_relu 19 | if activation_fun == 'elu': 20 | return tf.nn.elu 21 | if activation_fun == 'selu': 22 | return tf.nn.selu 23 | if activation_fun == 'gelu': 24 | def gelu(input_tensor): 25 | cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0))) 26 | return input_tensor * cdf 27 | return gelu 28 | else: 29 | raise ValueError("Unknown activation function '%s'!" % activation_fun) -------------------------------------------------------------------------------- /python/dpu_utils/tfutils/gradratiologgingoptimizer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import tensorflow as tf 4 | from tensorflow.python.ops import control_flow_ops 5 | 6 | 7 | class GradRatioLoggingOptimizer: 8 | """Wraps optimizers that compute the ratio of the update to the parameter values.""" 9 | def __init__(self, optimizer, name='training-optimizer'): 10 | self.__optimizer = optimizer 11 | self.__name = name 12 | self.__acc_count = tf.Variable(0, dtype=tf.int32, trainable=False) 13 | self.__grad_ratio_acc_vars = OrderedDict() # type: OrderedDict[str, tf.Variable] 14 | 15 | @property 16 | def optimizer(self): 17 | return self.__optimizer 18 | 19 | def print_ratios(self, session: tf.Session): 20 | count = self.__acc_count.eval(session) + 1e-10 21 | print('======================') 22 | print('Gradient Ratios') 23 | print('======================') 24 | for name, acc in self.__grad_ratio_acc_vars.items(): 25 | print('%s: %.2e' % (name, acc.eval(session) / count)) 26 | 27 | reset_ops = [tf.assign(self.__acc_count, 0)] + [tf.assign(v, 0) for v in self.__grad_ratio_acc_vars.values()] 28 | session.run(reset_ops) 29 | 30 | def minimize(self, loss): 31 | update_ops = [tf.assign_add(self.__acc_count, 1)] 32 | gradients_and_vars = self.__optimizer.compute_gradients(loss) 33 | for grad, var in gradients_and_vars: 34 | if grad is None: 35 | continue 36 | grad_ratio = tf.sqrt(tf.reduce_sum(tf.pow(grad, 2)) / tf.reduce_sum(tf.pow(var, 2))) 37 | ratio_acc_var = tf.Variable(0, trainable=False, dtype=tf.float32) 38 | self.__grad_ratio_acc_vars[var.name] = ratio_acc_var 39 | update_ops.append(tf.assign_add(ratio_acc_var, grad_ratio)) 40 | grad_apply_op = self.__optimizer.apply_gradients(gradients_and_vars) 41 | update_ops.append(grad_apply_op) 42 | return control_flow_ops.group(*update_ops) -------------------------------------------------------------------------------- /python/dpu_utils/tfutils/pick_indices.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Iterable 3 | 4 | import numpy as np 5 | 6 | BIG_NUMBER = 1e7 7 | SMALL_NUMBER = 1e-7 8 | 9 | 10 | def pick_indices_from_probs(probs: np.ndarray, num_picks: int, use_sampling: bool=False, 11 | temperature: float=0.5) -> 
Iterable[int]:
12 |     """Given an array of probabilities, pick up to num_picks unique indices from it."""
13 |     if use_sampling:
14 |         # First, consider the temperature for sampling:
15 |         probs = probs ** (1.0 / temperature)
16 |         normaliser = np.sum(probs)
17 |         probs = probs / normaliser
18 | 
19 |         probs_cum = np.cumsum(probs)
20 |         probs_cum[-1] = 1.0  # To protect against floating point oddness
21 |         picked_indices = set()
22 |         remaining_picks = num_picks * 10
23 |         while len(picked_indices) < num_picks and remaining_picks > 0:
24 |             remaining_picks -= 1
25 |             picked_val = random.random()
26 |             picked_index = np.argmax(probs_cum >= picked_val)  # type: int
27 |             if picked_index not in picked_indices and probs[picked_index] > SMALL_NUMBER:
28 |                 picked_indices.add(picked_index)
29 |         return picked_indices
30 |     else:
31 |         num_samples = min(num_picks, len(probs))
32 |         top_k_indices = np.argpartition(probs, -num_samples)[-num_samples:]
33 |         top_k_indices = [index for index in top_k_indices if probs[index] > SMALL_NUMBER]
34 |         return top_k_indices
35 | 
--------------------------------------------------------------------------------
/python/dpu_utils/tfutils/tfvariablesaver.py:
--------------------------------------------------------------------------------
 1 | import tensorflow as tf
 2 | from typing import Callable, Dict, Optional
 3 | import numpy as np
 4 | 
 5 | 
 6 | class TFVariableSaver:
 7 |     """
 8 |     Save all variables in the graph and restore them, in a way that the values are serializable by pickle.
 9 |     """
10 |     def __init__(self):
11 |         self.__saved_variables = {}  # type: Dict[str, np.ndarray]
12 | 
13 |     def save_all(self, session: tf.Session, exclude_variable: Optional[Callable[[str], bool]]=None) -> None:
14 |         self.__saved_variables = {}
15 |         for variable in session.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
16 |             assert variable.name not in self.__saved_variables
17 |             if exclude_variable is not None and exclude_variable(variable.name):
18 |                 continue
19 |             self.__saved_variables[variable.name] = variable.value().eval()
20 | 
21 |     def has_saved_variables(self) -> bool:
22 |         return len(self.__saved_variables) > 0
23 | 
24 |     def restore_saved_values(self, session: tf.Session) -> None:
25 |         assert len(self.__saved_variables) > 0
26 |         save_ops = []
27 |         with tf.name_scope("restore"):
28 |             for variable in session.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
29 |                 if variable.name in self.__saved_variables:
30 |                     saved_value = self.__saved_variables[variable.name]
31 |                     if len(variable.shape) == 0 or variable.shape[0]._value == saved_value.shape[0]:  # Scalars or the size hasn't changed.
32 |                         save_ops.append(variable.assign(saved_value))
33 |                     else:
34 |                         # Allow expanding saved variables
35 |                         print('Stored value for %s has shape %s but the variable has shape %s. Padding with zeros.'
36 |                               % (variable.name, saved_value.shape, variable.shape))
37 | 
38 |                         initial_value = np.zeros([variable.shape[i]._value for i in range(len(variable.shape))],
39 |                                                  dtype=variable.dtype.as_numpy_dtype)
40 |                         initial_value[:saved_value.shape[0]] = saved_value
41 |                         save_ops.append(variable.assign(initial_value))
42 |                 else:
43 |                     print('Initializing %s randomly since no saved value was found.'
% variable.name) 44 | save_ops.append(tf.variables_initializer([variable])) 45 | session.run(save_ops) 46 | -------------------------------------------------------------------------------- /python/dpu_utils/tfutils/unsortedsegmentops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | SMALL_NUMBER = 1e-7 4 | 5 | def unsorted_segment_logsumexp(scores, segment_ids, num_segments): 6 | """Perform an unsorted segment safe logsumexp.""" 7 | # Note: if a segment is empty, the smallest value for the score will be returned, 8 | # which yields the correct behavior 9 | max_per_segment = tf.unsorted_segment_max(data=scores, 10 | segment_ids=segment_ids, 11 | num_segments=num_segments) 12 | scattered_log_maxes = tf.gather(params=max_per_segment, 13 | indices=segment_ids) 14 | recentered_scores = scores - scattered_log_maxes 15 | exped_recentered_scores = tf.exp(recentered_scores) 16 | 17 | per_segment_sums = tf.unsorted_segment_sum(exped_recentered_scores, segment_ids, num_segments) 18 | per_segment_logs = tf.log(per_segment_sums) 19 | return per_segment_logs + max_per_segment 20 | 21 | 22 | def unsorted_segment_log_softmax(logits, segment_ids, num_segments): 23 | """Perform an unsorted segment safe log_softmax.""" 24 | # Note: if a segment is empty, the smallest value for the score will be returned, 25 | # which yields the correct behavior 26 | max_per_segment = tf.unsorted_segment_max(data=logits, 27 | segment_ids=segment_ids, 28 | num_segments=num_segments) 29 | scattered_maxes = tf.gather(params=max_per_segment, 30 | indices=segment_ids) 31 | recentered_scores = logits - scattered_maxes 32 | exped_recentered_scores = tf.exp(recentered_scores) 33 | 34 | per_segment_sums = tf.unsorted_segment_sum(exped_recentered_scores, segment_ids, num_segments) 35 | per_segment_normalization_consts = tf.log(per_segment_sums) 36 | 37 | log_probs = recentered_scores - tf.gather(params=per_segment_normalization_consts, indices=segment_ids) 38 | return log_probs 39 | 40 | 41 | def unsorted_segment_softmax(logits, segment_ids, num_segments): 42 | """Perform a safe unsorted segment softmax.""" 43 | max_per_segment = tf.unsorted_segment_max(data=logits, 44 | segment_ids=segment_ids, 45 | num_segments=num_segments) 46 | scattered_maxes = tf.gather(params=max_per_segment, 47 | indices=segment_ids) 48 | recentered_scores = logits - scattered_maxes 49 | exped_recentered_scores = tf.exp(recentered_scores) 50 | 51 | per_segment_sums = tf.unsorted_segment_sum(exped_recentered_scores, segment_ids, num_segments) 52 | 53 | probs = exped_recentered_scores / (tf.gather(params=per_segment_sums, indices=segment_ids) + SMALL_NUMBER) 54 | return probs 55 | -------------------------------------------------------------------------------- /python/dpu_utils/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .richpath import RichPath, AzurePath, LocalPath 2 | from .iterators import ThreadedIterator, BufferedIterator, DoubleBufferedIterator, MultiWorkerCallableIterator, \ 3 | shuffled_iterator, uniform_sample_iterator, subsample_iterator 4 | from .dataloading import load_json_gz, save_json_gz, load_jsonl_gz, save_jsonl_gz 5 | from .gitlog import git_tag_run 6 | from .debughelper import run_and_debug 7 | from .chunkwriter import ChunkWriter 8 | 9 | __all__ = ['RichPath', 'AzurePath', 'LocalPath', 10 | 'ThreadedIterator', 'BufferedIterator', 'DoubleBufferedIterator', 'MultiWorkerCallableIterator', 'shuffled_iterator', 
11 |            'subsample_iterator', 'uniform_sample_iterator',
12 |            'load_json_gz', 'save_json_gz', 'load_jsonl_gz', 'save_jsonl_gz',
13 |            'git_tag_run',
14 |            'run_and_debug',
15 |            'ChunkWriter']
16 | 
--------------------------------------------------------------------------------
/python/dpu_utils/utils/chunkwriter.py:
--------------------------------------------------------------------------------
  1 | import re
  2 | from concurrent.futures import ThreadPoolExecutor
  3 | from typing import List, Iterable, TypeVar, Generic, Union
  4 | 
  5 | from dpu_utils.utils import RichPath
  6 | 
  7 | T = TypeVar('T')
  8 | 
  9 | __all__ = ['ChunkWriter']
 10 | 
 11 | 
 12 | class ChunkWriter(Generic[T]):
 13 |     """Encapsulates writing output into chunks (multiple consecutive files).
 14 | 
 15 |     By setting the file_suffix to either .pkl.gz, .json.gz, .msgpack.l.gz, .msgpack.gz, or .jsonl.gz
 16 |     the appropriate format will be used for the chunks.
 17 | 
 18 |     ChunkWriter can be used either as a context manager, i.e.
 19 |     ```
 20 |     with ChunkWriter(...) as writer:
 21 |         writer.add(...)
 22 |     ```
 23 | 
 24 |     or by explicitly invoking `close()`, i.e.
 25 |     ```
 26 |     writer = ChunkWriter(...)
 27 |     # Code that uses add() or add_many()
 28 |     writer.close()
 29 |     ```
 30 |     """
 31 |     def __init__(self, out_folder: Union[RichPath, str], file_prefix: str, max_chunk_size: int, file_suffix: str,
 32 |                  parallel_writers: int = 0, mode: str = 'w'):
 33 |         self.__current_chunk = []  # type: List[T]
 34 |         if isinstance(out_folder, str):
 35 |             self.__out_folder = RichPath.create(out_folder)  # type: RichPath
 36 |         else:
 37 |             self.__out_folder = out_folder
 38 |         self.__out_folder.make_as_dir()
 39 |         self.__file_prefix = file_prefix
 40 |         self.__max_chunk_size = max_chunk_size
 41 |         self.__file_suffix = file_suffix
 42 | 
 43 |         self.__mode = mode.lower()
 44 |         assert self.__mode in ('a', 'w'), 'Mode must be either append (a) or write (w). 
Given: {0}'.format(mode) 45 | 46 | if self.__mode == 'w': 47 | self.__num_files_written = 0 # 'w' mode will begin writing from scratch 48 | else: 49 | self.__num_files_written = self.__get_max_existing_index() + 1 # 'a' mode starts after the last-written file 50 | 51 | self.__parallel_writers = parallel_writers 52 | if self.__parallel_writers > 0: 53 | self.__writer_executors = ThreadPoolExecutor(max_workers=self.__parallel_writers) 54 | 55 | def __write_if_needed(self) -> None: 56 | if len(self.__current_chunk) < self.__max_chunk_size: 57 | return 58 | self.__flush() 59 | 60 | def add(self, element: T) -> None: 61 | self.__current_chunk.append(element) 62 | self.__write_if_needed() 63 | 64 | def add_many(self, elements: Iterable[T]) -> None: 65 | for element in elements: 66 | self.add(element) 67 | 68 | def __flush(self) -> None: 69 | if len(self.__current_chunk) == 0: 70 | return 71 | outfile = self.__out_folder.join( 72 | '%s%03d%s' % (self.__file_prefix, self.__num_files_written, self.__file_suffix) 73 | ) 74 | if self.__parallel_writers > 0: 75 | to_write = self.__current_chunk 76 | self.__writer_executors.submit(lambda: outfile.save_as_compressed_file(to_write)) 77 | else: 78 | outfile.save_as_compressed_file(self.__current_chunk) 79 | self.__current_chunk = [] 80 | self.__num_files_written += 1 81 | 82 | def __enter__(self) -> 'ChunkWriter': 83 | return self 84 | 85 | def __exit__(self, exc_type, exc_val, exc_tb) -> None: 86 | self.close() 87 | 88 | def close(self) -> None: 89 | self.__flush() 90 | if self.__parallel_writers > 0: 91 | self.__writer_executors.shutdown(wait=True) 92 | 93 | def __get_max_existing_index(self) -> int: 94 | """ 95 | Returns the largest file index within the current output folder. 96 | """ 97 | file_pattern = '{0}*{1}'.format(self.__file_prefix, self.__file_suffix) 98 | file_regex = re.compile('.*{0}([0-9]+){1}'.format(self.__file_prefix, self.__file_suffix)) 99 | 100 | max_index = 0 101 | for path in self.__out_folder.iterate_filtered_files_in_dir(file_pattern): 102 | match = file_regex.match(path.path) 103 | if match is None: 104 | continue 105 | file_index = int(match.group(1)) 106 | max_index = max(file_index, max_index) 107 | return max_index 108 | -------------------------------------------------------------------------------- /python/dpu_utils/utils/dataloading.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import codecs 4 | from collections import OrderedDict 5 | from typing import Any, Iterator, Iterable 6 | 7 | __all__ = ['load_json_gz', 'save_json_gz', 'load_jsonl_gz', 'save_jsonl_gz'] 8 | 9 | 10 | def load_json_gz(filename: str) -> Any: 11 | reader = codecs.getreader('utf-8') 12 | with gzip.open(filename) as f: 13 | return json.load(reader(f), object_pairs_hook=OrderedDict) 14 | 15 | 16 | def save_json_gz(data: Any, filename: str) -> None: 17 | writer = codecs.getwriter('utf-8') 18 | with gzip.GzipFile(filename, 'wb') as outfile: 19 | json.dump(data, writer(outfile)) 20 | 21 | 22 | def load_jsonl_gz(filename: str) -> Iterator[Any]: 23 | """ 24 | Iterate through gzipped JSONL files. See http://jsonlines.org/ for more. 
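    Example (an illustrative sketch; 'data.jsonl.gz' is a placeholder path):

        for record in load_jsonl_gz('data.jsonl.gz'):
            print(record)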
25 | """ 26 | reader = codecs.getreader('utf-8') 27 | with gzip.open(filename) as f: 28 | for line in reader(f): 29 | yield json.loads(line, object_pairs_hook=OrderedDict) 30 | 31 | 32 | def save_jsonl_gz(data: Iterable[Any], filename: str) -> None: 33 | with gzip.GzipFile(filename, 'wb') as out_file: 34 | writer = codecs.getwriter('utf-8') 35 | for element in data: 36 | writer(out_file).write(json.dumps(element)) 37 | writer(out_file).write('\n') 38 | -------------------------------------------------------------------------------- /python/dpu_utils/utils/debughelper.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import traceback 3 | import pdb 4 | from typing import Callable 5 | 6 | __all__= ['run_and_debug'] 7 | 8 | 9 | def run_and_debug(func: Callable[[], None], enable_debugging: bool) -> None: 10 | """ 11 | A wrapper around a running script that triggers the debugger in case of an uncaught exception. 12 | 13 | For example, this can be used as: 14 | ``` 15 | if __name__ == '__main__': 16 | args = docopt(__doc__) 17 | run_and_debug(lambda: run(args), args['--debug']) 18 | ``` 19 | """ 20 | try: 21 | func() 22 | except: 23 | if enable_debugging: 24 | _, value, tb = sys.exc_info() 25 | traceback.print_exc() 26 | pdb.post_mortem(tb) 27 | else: 28 | raise 29 | -------------------------------------------------------------------------------- /python/dpu_utils/utils/gitlog.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | __all__ = ['git_tag_run'] 4 | 5 | 6 | def git_tag_run(train_run_id: str)-> str: 7 | """Tag current version of code in git""" 8 | cur_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']) 9 | 10 | # This call may fail if there are no changes, and we're fine with it: 11 | subprocess.call(['git', 'commit', '-a', '-m', 'Automatic commit of state for run %s' % train_run_id]) 12 | new_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']) 13 | subprocess.check_call(['git', 'tag', 'runs/%s' % train_run_id]) 14 | 15 | # If the hash changed, we created a new commit (otherwise, no changes existed), so back out of that again 16 | if cur_hash != new_hash: 17 | subprocess.check_call(['git', 'reset', '--mixed', 'HEAD^']) 18 | 19 | return new_hash.strip().decode("utf-8") 20 | -------------------------------------------------------------------------------- /python/dpu_utils/utils/msgpackloading.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import random 3 | from collections import OrderedDict 4 | from os import PathLike 5 | from typing import Any, Iterable, Iterator, Optional, Union 6 | 7 | import msgpack 8 | 9 | 10 | from dpu_utils.utils import RichPath 11 | 12 | 13 | def load_msgpack_l_gz(filename: Union[PathLike, str]) -> Iterator[Any]: 14 | with gzip.open(filename) as f: 15 | unpacker = msgpack.Unpacker(f, raw=False, object_pairs_hook=OrderedDict) 16 | yield from unpacker 17 | 18 | 19 | def save_msgpack_l_gz(data: Iterable[Any], filename: Union[PathLike, str]) -> None: 20 | with gzip.GzipFile(filename, "wb") as out_file: 21 | packer = msgpack.Packer(use_bin_type=True) 22 | for element in data: 23 | out_file.write(packer.pack(element)) 24 | 25 | 26 | def load_all_msgpack_l_gz( 27 | path: RichPath, 28 | shuffle: bool = False, 29 | take_only_first_n_files: Optional[int] = None, 30 | limit_num_yielded_elements: Optional[int] = None, 31 | rng: Optional[random.Random] = None, 32 | ) -> Iterator: 33 | """ 
34 |     Iterate through all the elements of all the `.msgpack.l.gz` in a given directory.
35 | 
36 |     :param path: the directory containing the `.msgpack.l.gz` files.
37 |     :param shuffle: if True, iterate over the files in a random order.
38 |     :param take_only_first_n_files: if given, restrict iteration to the first n files.
39 |     :param limit_num_yielded_elements: if given, stop after yielding this many elements.
40 |     :param rng: the random number generator used for shuffling; the `random` module is used if not given.
41 |     :return: an iterator over the deserialized elements of all files.
42 |     """
43 |     all_files = sorted(path.iterate_filtered_files_in_dir("*.msgpack.l.gz"))
44 |     if take_only_first_n_files is not None:
45 |         all_files = all_files[:take_only_first_n_files]
46 |     if shuffle and rng is None:
47 |         random.shuffle(all_files)
48 |     elif shuffle:
49 |         rng.shuffle(all_files)
50 | 
51 |     sample_idx = 0
52 |     for msgpack_file in all_files:
53 |         try:
54 |             for element in load_msgpack_l_gz(msgpack_file.to_local_path().path):
55 |                 if element is not None:
56 |                     sample_idx += 1
57 |                     yield element
58 |                 if limit_num_yielded_elements is not None and sample_idx >= limit_num_yielded_elements:
59 |                     return
60 |         except Exception as e:
61 |             print(f"Error loading {msgpack_file}: {e}.")
62 | 
63 | 
64 | if __name__ == "__main__":
65 |     # A json.tool-like CLI to look into msgpack.l.gz files.
66 |     import sys
67 |     import json
68 | 
69 |     for datapoint in load_msgpack_l_gz(sys.argv[1]):
70 |         print(json.dumps(datapoint, indent=2))
71 |         print("---------------------------------------")
--------------------------------------------------------------------------------
/python/requirements.txt:
--------------------------------------------------------------------------------
 1 | SetSimilaritySearch
 2 | tqdm
 3 | docopt
 4 | azure-storage-blob
 5 | azure-identity
 6 | msgpack
 7 | numpy
 8 | typing_extensions
 9 | sentencepiece
10 | cffi
11 | regex
12 | 
--------------------------------------------------------------------------------
/python/setup.py:
--------------------------------------------------------------------------------
 1 | import setuptools, os
 2 | 
 3 | if os.path.exists('../README.md'):
 4 |     with open('../README.md') as f:
 5 |         long_description = f.read()
 6 | else:
 7 |     long_description = ""
 8 | 
 9 | setuptools.setup(
10 |     name='dpu_utils',
11 |     version='0.6.1',
12 |     license='MIT',
13 |     description='Python utilities used by Deep Procedural Intelligence',
14 |     long_description=long_description,
15 |     long_description_content_type='text/markdown',
16 |     url='https://github.com/microsoft/dpu-utils',
17 |     author='Deep Procedural Intelligence',
18 |     packages=setuptools.find_packages(),
19 |     python_requires=">=3.6.1",
20 |     include_package_data=True,
21 |     install_requires=[
22 |         'azure-storage-blob', 'azure-identity', 'numpy', 'docopt', 'tqdm', 'SetSimilaritySearch', 'sentencepiece', 'cffi', 'regex'
23 |     ],
24 |     scripts=['dpu_utils/codeutils/deduplication/deduplicationcli'],
25 |     test_suite="tests",
26 |     zip_safe=False)
--------------------------------------------------------------------------------
/python/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/dpu-utils/bc4f5fa6c92c4b073fff4d93269512b0027f7f02/python/tests/__init__.py
--------------------------------------------------------------------------------
/python/tests/codeutils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/dpu-utils/bc4f5fa6c92c4b073fff4d93269512b0027f7f02/python/tests/codeutils/__init__.py
--------------------------------------------------------------------------------
/python/tests/codeutils/test_code_range.py:
--------------------------------------------------------------------------------
 1 | import unittest
 2 | 
 3 | from 
dpu_utils.codeutils.text import get_code_in_range 4 | 5 | TEST_CODE = """1234 6 | 7 | 567 8 | 890 9 | 10 | abcdefghijklmnop 11 | qrs 12 | """ 13 | 14 | class TestCodeRange(unittest.TestCase): 15 | def test_get_range(self): 16 | self.assertEqual(get_code_in_range( 17 | TEST_CODE, 18 | (1, 1), (1, 4) 19 | ), "1234") 20 | 21 | self.assertEqual(get_code_in_range( 22 | TEST_CODE, 23 | (1, 2), (1, 2) 24 | ), "2") 25 | 26 | self.assertEqual(get_code_in_range( 27 | TEST_CODE, 28 | (1, 1), (1, 10) 29 | ), "1234\n") 30 | 31 | self.assertEqual(get_code_in_range( 32 | TEST_CODE, 33 | (1, 1), (3, 0) 34 | ), "1234\n\n") 35 | 36 | self.assertEqual(get_code_in_range( 37 | TEST_CODE, 38 | (3, 1), (4, 10) 39 | ), "567\n 890\n") 40 | 41 | self.assertEqual(get_code_in_range( 42 | TEST_CODE, 43 | (3, 2), (4, 2) 44 | ), "67\n 8") 45 | 46 | self.assertEqual(get_code_in_range( 47 | TEST_CODE, 48 | (3, 1), (6, 0) 49 | ), "567\n 890\n\n") 50 | 51 | with self.assertRaises(ValueError): 52 | get_code_in_range( 53 | TEST_CODE, 54 | (7, 0), (10, 0) 55 | ) 56 | 57 | with self.assertRaises(ValueError): 58 | get_code_in_range( 59 | TEST_CODE, 60 | (10, 0), (11, 0) 61 | ) -------------------------------------------------------------------------------- /python/tests/codeutils/test_identifiersplitting.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from typing import List 3 | 4 | import dpu_utils.codeutils.identifiersplitting as split 5 | 6 | class TestSplitCamelCase(unittest.TestCase): 7 | def run_test(self, identifier: str, expected: List[str]) -> None: 8 | actual = split.split_identifier_into_parts(identifier) 9 | self.assertEqual(expected, actual) 10 | 11 | def test_empty_string_returns_empty_list(self): 12 | self.run_test("", [""]) 13 | self.run_test("_", ["_"]) 14 | self.run_test("$", ["$"]) 15 | 16 | def test_single_word_is_not_split(self): 17 | self.run_test("variable", ["variable"]) 18 | self.run_test("i", ["i"]) 19 | 20 | def test_two_words_are_split(self): 21 | self.run_test("camelCase", ["camel", "case"]) 22 | self.run_test("camelCase2", ["camel", "case", "2"]) 23 | self.run_test("camelCase23", ["camel", "case", "23"]) 24 | 25 | def test_three_words_are_split(self): 26 | self.run_test("camelCaseIdentifier", ["camel", "case", "identifier"]) 27 | 28 | def test_upper_camelcase_is_split(self): 29 | self.run_test("CamelCase", ["camel", "case"]) 30 | self.run_test("CamelCaseId", ["camel", "case", "id"]) 31 | self.run_test("CamelCaseID", ["camel", "case", "id"]) 32 | 33 | def test_abbreviations_are_split_correctly(self): 34 | self.run_test("HTMLParser", ["html", "parser"]) 35 | self.run_test("HTML25", ["html", "25"]) 36 | 37 | def test_digits_are_split(self): 38 | self.run_test("var12var3", ["var", "12", "var", "3"]) 39 | 40 | def test_special_characters_are_split(self): 41 | self.run_test("@var$var", ["@", "var", "$", "var"]) 42 | self.run_test("@var", ["@", "var"]) # C# style 43 | self.run_test("$var", ["$", "var"]) # PHP style 44 | self.run_test("$var2", ["$", "var", "2"]) # PHP style 45 | -------------------------------------------------------------------------------- /python/tests/mlutils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/dpu-utils/bc4f5fa6c92c4b073fff4d93269512b0027f7f02/python/tests/mlutils/__init__.py -------------------------------------------------------------------------------- /python/tests/mlutils/test_bpevocabulary.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | import pickle 3 | 4 | from glob import iglob 5 | import os 6 | from tempfile import TemporaryDirectory 7 | 8 | from dpu_utils.mlutils import BpeVocabulary 9 | 10 | 11 | class TestBpeVocab(unittest.TestCase): 12 | def test(self): 13 | def pseudotoken_iter(): 14 | """Create a dummy corpus by using this project.""" 15 | for filename in iglob(os.path.join(os.path.dirname(__file__), '..', '..', "**/*.py"), recursive=True): 16 | with open(filename) as f: 17 | yield from f.read().split('\n') 18 | 19 | v = BpeVocabulary(5000, user_defined_symbols=['']) 20 | v.create_vocabulary(pseudotoken_iter()) 21 | self.__test_roundtrip('for i in range(100, 2):', v) 22 | self.__test_roundtrip('from glob import iglob', v) 23 | self.__test_roundtrip('with open(_blahblah_blah, "w") as _f:_f.write("Something with space.")', v) 24 | 25 | def __test_roundtrip(self, text, v): 26 | idxs = v.get_id_or_unk_for_text(text) 27 | self.assertEqual(v.convert_ids_to_string(idxs), text) 28 | with TemporaryDirectory() as tmp: 29 | # Test serialization 30 | tmp_filename = os.path.join(tmp, 'tmp.pkl') 31 | with open(tmp_filename, 'wb') as f: 32 | pickle.dump(v, f) 33 | 34 | with open(tmp_filename, 'rb') as f: 35 | v2 = pickle.load(f) 36 | self.assertEqual(v2.get_id_or_unk_for_text(text), idxs) 37 | self.assertEqual(v2.convert_ids_to_string(idxs), text) 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /python/tests/ptutils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/dpu-utils/bc4f5fa6c92c4b073fff4d93269512b0027f7f02/python/tests/ptutils/__init__.py -------------------------------------------------------------------------------- /python/tests/ptutils/test_component.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tempfile 3 | from typing import Tuple, Iterator, Iterable 4 | 5 | import torch 6 | 7 | from dpu_utils.utils import RichPath 8 | from dpu_utils.ptutils import ComponentTrainer 9 | 10 | from tests.ptutils.testdata import SyntheticData 11 | from tests.ptutils.testmodel import SimpleRegression, SampleDatapoint 12 | 13 | 14 | class TestPytorchComponent(unittest.TestCase): 15 | def test_train_model(self): 16 | num_features = 100 17 | training_data, validation_data = self.__get_data(num_features) 18 | 19 | with tempfile.TemporaryDirectory() as dir: 20 | model_file = RichPath.create(dir).join('tmp.pkl.gz') 21 | 22 | model = SimpleRegression('SimpleRegressionTest', num_features) 23 | trainer = ComponentTrainer(model, model_file, max_num_epochs=50) 24 | trainer.train(training_data, validation_data, parallel_minibatch_creation=True) 25 | model_acc_1 = self.__compute_accuracy(model, validation_data) 26 | 27 | trained_model = SimpleRegression.restore_model(model_file) # type: SimpleRegression 28 | trained_model_acc = self.__compute_accuracy(trained_model, validation_data) 29 | self.assertGreater(trained_model_acc, .95, f'Model achieves too low accuracy, {trained_model_acc:%}') 30 | 31 | self.assertAlmostEqual(trained_model_acc, model_acc_1, places=3, msg=f'Accuracy before and after loading does not match: {trained_model_acc} vs {model_acc_1}') 32 | 33 | def test_freeze_params(self): 34 | num_features = 100 35 | training_data, validation_data = self.__get_data(num_features) 36 | 37 | with tempfile.TemporaryDirectory() as dir: 
38 | model_file = RichPath.create(dir).join('tmp.pkl.gz') 39 | 40 | model = SimpleRegression('SimpleRegressionTest', num_features) 41 | trainer = ComponentTrainer(model, model_file, max_num_epochs=50) 42 | 43 | def get_freeze_weights(): 44 | for p in model.parameters(): 45 | if len(p.shape) == 2: # Just the weights 46 | yield p 47 | 48 | trainer.train(training_data, validation_data, get_parameters_to_freeze=lambda: set(get_freeze_weights())) 49 | trained_model_acc = self.__compute_accuracy(model, validation_data) 50 | 51 | self.assertLess(trained_model_acc, .7, f'Model achieves too high accuracy but the weights were frozen, {trained_model_acc:%}') 52 | 53 | 54 | def __get_data(self, num_features): 55 | data = SyntheticData(num_features) 56 | all_data = list(data.generate(10000)) 57 | training_data, validation_data = all_data[:9000], all_data[9000:] 58 | return training_data, validation_data 59 | 60 | def __compute_accuracy(self, model: SimpleRegression, dataset: Iterable[SampleDatapoint]) -> float: 61 | num_samples = 0 62 | num_correct = 0 63 | for point, prediction in self.__get_model_prediction(model, dataset): 64 | num_samples += 1 65 | if point.target_class == prediction: 66 | num_correct += 1 67 | return num_correct / num_samples 68 | 69 | def __get_model_prediction(self, model: SimpleRegression, data: Iterable[SampleDatapoint]) -> Iterator[Tuple[SampleDatapoint, bool]]: 70 | for datapoint in data: 71 | tensorized = model.load_data_from_sample(datapoint) 72 | mb_data = model.initialize_minibatch() 73 | model.extend_minibatch_by_sample(tensorized, mb_data) 74 | mb_data = model.finalize_minibatch(mb_data) 75 | 76 | with torch.no_grad(): 77 | predictions = model.predict(mb_data['inputs']).cpu().numpy() 78 | yield datapoint, predictions[0] 79 | 80 | 81 | 82 | if __name__ == '__main__': 83 | unittest.main() 84 | -------------------------------------------------------------------------------- /python/tests/ptutils/testdata.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | import numpy as np 4 | 5 | from tests.ptutils.testmodel import SampleDatapoint 6 | 7 | 8 | class SyntheticData: 9 | def __init__(self, num_features: int): 10 | self.__num_features = num_features 11 | self.__weights = np.random.randn(num_features) * 10 12 | 13 | def generate(self, num_points: int) -> Iterator[SampleDatapoint]: 14 | for _ in range(num_points): 15 | # Avoid bias so that there is no obvious class imbalance. 
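            # The label below is the sign of a fixed random linear function of the
            # features, i.e. target_class = (w . x >= 0), so the data is linearly
            # separable and a linear model can in principle fit it perfectly.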
--------------------------------------------------------------------------------
/python/tests/ptutils/testdata.py:
--------------------------------------------------------------------------------
from typing import Iterator

import numpy as np

from tests.ptutils.testmodel import SampleDatapoint


class SyntheticData:
    def __init__(self, num_features: int):
        self.__num_features = num_features
        self.__weights = np.random.randn(num_features) * 10

    def generate(self, num_points: int) -> Iterator[SampleDatapoint]:
        for _ in range(num_points):
            # No bias term: labels come from the sign of a zero-mean linear
            # projection, so there is no obvious class imbalance.
            input_features = np.random.randn(self.__num_features) * 5
            yield SampleDatapoint(
                input_features=list(input_features),
                target_class=sum(input_features * self.__weights) >= 0
            )
--------------------------------------------------------------------------------
/python/tests/ptutils/testmodel.py:
--------------------------------------------------------------------------------
from typing import NamedTuple, List, Optional, Dict, Any

import numpy as np
import torch
from torch import nn as nn

from dpu_utils.ptutils import BaseComponent


class SampleDatapoint(NamedTuple):
    input_features: List[float]
    target_class: bool


class TensorizedDatapoint(NamedTuple):
    input_features: np.ndarray
    target_class: np.ndarray


class SimpleRegression(BaseComponent[SampleDatapoint, TensorizedDatapoint]):
    """A simple logistic-regression-style binary classifier used for testing."""
    def __init__(self, name, num_features: int, hyperparameters: Optional[Dict[str, Any]] = None):
        super().__init__(name, hyperparameters)
        self.__num_features = num_features

    @classmethod
    def default_hyperparameters(cls) -> Dict[str, Any]:
        return {}

    def _load_metadata_from_sample(self, data_to_load: SampleDatapoint) -> None:
        pass  # No metadata in this simple model.

    def _finalize_component_metadata_and_model(self) -> None:
        self.__layer = nn.Linear(self.__num_features, 1, bias=False)
        # Use a separate bias parameter so that the weights can be frozen independently.
        self.__bias = nn.Parameter(torch.tensor(0, dtype=torch.float32))
        self.__loss = nn.BCEWithLogitsLoss()

    def load_data_from_sample(self, data_to_load: SampleDatapoint) -> Optional[TensorizedDatapoint]:
        return TensorizedDatapoint(
            input_features=np.array(data_to_load.input_features, dtype=np.float32),
            target_class=np.array(1 if data_to_load.target_class else 0, dtype=np.float32)
        )

    def initialize_minibatch(self) -> Dict[str, Any]:
        return {
            'inputs': [],
            'targets': []
        }

    def extend_minibatch_by_sample(self, datapoint: TensorizedDatapoint, accumulated_minibatch_data: Dict[str, Any]) -> bool:
        accumulated_minibatch_data['inputs'].append(datapoint.input_features)
        accumulated_minibatch_data['targets'].append(datapoint.target_class)
        return True

    def finalize_minibatch(self, accumulated_minibatch_data: Dict[str, Any]) -> Dict[str, Any]:
        return {
            'inputs': torch.tensor(np.stack(accumulated_minibatch_data['inputs'], axis=0), device=self.device),
            'targets': torch.tensor(np.stack(accumulated_minibatch_data['targets'], axis=0), device=self.device)
        }

    def predict(self, inputs: torch.Tensor) -> torch.Tensor:
        predicted = self.__layer(inputs)[:, 0] + self.__bias  # Shape: [B]
        return predicted >= 0

    def forward(self, inputs, targets):
        predicted = self.__layer(inputs)[:, 0] + self.__bias  # Shape: [B]
        loss = self.__loss(input=predicted, target=targets)
        return loss
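The BaseComponent minibatching protocol that SimpleRegression implements can be walked through by hand, as the test's __get_model_prediction helper does. A sketch for a single datapoint, assuming the component has already been built and trained by a ComponentTrainer:

# Sketch: push one datapoint through the tensorize -> minibatch -> predict pipeline.
sample = SampleDatapoint(input_features=[0.5, -1.0, 2.0], target_class=True)

tensorized = model.load_data_from_sample(sample)   # raw sample -> numpy arrays
mb = model.initialize_minibatch()                  # fresh accumulator dict
model.extend_minibatch_by_sample(tensorized, mb)   # returns True while the batch has room
mb = model.finalize_minibatch(mb)                  # stack into torch tensors on model.device
with torch.no_grad():
    prediction = model.predict(mb['inputs'])       # boolean tensor of shape [1]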
--------------------------------------------------------------------------------
/python/tests/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/dpu-utils/bc4f5fa6c92c4b073fff4d93269512b0027f7f02/python/tests/utils/__init__.py
--------------------------------------------------------------------------------
/python/tests/utils/test_chunkwriter.py:
--------------------------------------------------------------------------------
import tempfile
import unittest
from itertools import permutations
from typing import Set, Callable

from dpu_utils.utils import ChunkWriter, RichPath


class TestChunkWriter(unittest.TestCase):

    def test_write_read_jsonl_sequential(self):
        self.__test_write_read(lambda p: ChunkWriter(p, file_prefix='test', max_chunk_size=123, file_suffix='-test.jsonl.gz'))

    def test_write_read_jsonl_parallel(self):
        self.__test_write_read(lambda p: ChunkWriter(p, file_prefix='test', max_chunk_size=123, file_suffix='-test.jsonl.gz', parallel_writers=5))

    def test_write_read_msgpack_sequential(self):
        self.__test_write_read(lambda p: ChunkWriter(p, file_prefix='test', max_chunk_size=123, file_suffix='-test.msgpack.l.gz'), suffix='.msgpack.l.gz')

    def test_write_read_msgpack_parallel(self):
        self.__test_write_read(
            lambda p: ChunkWriter(p, file_prefix='test', max_chunk_size=123, file_suffix='-test.msgpack.l.gz', parallel_writers=5),
            suffix='.msgpack.l.gz')

    def __test_write_read(self, chunk_writer_creator: Callable[[RichPath], ChunkWriter], suffix='.jsonl.gz'):
        all_chars = [chr(65 + i) for i in range(26)]
        ground_elements = set(''.join(t) for t in permutations(all_chars, 3))  # 26*25*24 = 15600 elements

        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = RichPath.create(tmp)
            with chunk_writer_creator(tmp_path) as w:
                w.add_many(ground_elements)

            # Assert that all elements have been stored.
            stored_elements = set()  # type: Set[str]
            for file in tmp_path.get_filtered_files_in_dir('test*' + suffix):
                stored_elements.update(file.read_by_file_suffix())

            self.assertSetEqual(stored_elements, ground_elements,
                                f'Stored elements differ: len(stored)={len(stored_elements)},'
                                f' len(ground)={len(ground_elements)}.'
                                f' Diff {ground_elements - stored_elements}.')
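A sketch of the ChunkWriter round-trip these tests verify; the directory, prefix, and chunk size are illustrative:

# Sketch only: writes 5000 strings split across chunk files, then reads them back.
out_dir = RichPath.create('/tmp/chunks')
with ChunkWriter(out_dir, file_prefix='data', max_chunk_size=1000, file_suffix='.jsonl.gz') as writer:
    writer.add_many(str(i) for i in range(5000))

read_back = []
for chunk_file in out_dir.get_filtered_files_in_dir('data*.jsonl.gz'):
    read_back.extend(chunk_file.read_by_file_suffix())  # each chunk is read per its suffix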
--------------------------------------------------------------------------------
/python/tests/utils/test_iterators.py:
--------------------------------------------------------------------------------
import unittest
from functools import partial
from itertools import islice

from dpu_utils.utils import shuffled_iterator, ThreadedIterator, BufferedIterator, DoubleBufferedIterator, MultiWorkerCallableIterator, uniform_sample_iterator


class TestShuffleIterator(unittest.TestCase):

    def test_return_all_elements(self):
        for size in [100, 10000, 100000]:
            shuffled_dataset = list(shuffled_iterator(range(size)))
            self.assertNotEqual(shuffled_dataset[:10], list(range(10)),
                                'It is highly unlikely that the original order is preserved.')
            self.assertSetEqual(set(shuffled_dataset), set(range(size)), 'Some returned elements are missing.')


def identity(x):
    return x


class TestMultiWorkerIterator(unittest.TestCase):
    def test_return_all_elements(self):
        for use_threads in (True, False):
            with self.subTest('use_threads=%s' % use_threads):
                for size in [100, 10000, 100000]:
                    dataset = list(MultiWorkerCallableIterator(((i,) for i in range(size)), identity, use_threads=use_threads))
                    self.assertSetEqual(set(dataset), set(range(size)), 'Some returned elements are missing.')


def generator(size):
    for i in range(size):
        yield i


class IterWrapper:
    def __init__(self, iter_fn):
        self._iter_fn = iter_fn

    def __iter__(self):
        yield from self._iter_fn()


class TestParallelIterators(unittest.TestCase):

    ALL_ITERATOR_TYPES = [ThreadedIterator, BufferedIterator, DoubleBufferedIterator]

    def test_return_all_elements_in_order(self):
        for iterator_type in self.ALL_ITERATOR_TYPES:
            for enabled in (True, False):
                for size in [100, 10000]:
                    for iter_kind in (range(size), generator(size), IterWrapper(partial(generator, size))):
                        with self.subTest("%s-%s-%s-enabled=%s" % (iterator_type, size, iter_kind, enabled)):
                            returned = list(iterator_type(iter_kind, enabled=enabled))
                            self.assertListEqual(returned, list(range(size)), f'Iterator {iterator_type.__name__} did not return all elements.')

    def test_finish_on_partial_iteration(self):
        """Parallel iterators may leak resources (threads, processes) on partial iteration. Ensure that this is not the case."""
        for iterator_type in self.ALL_ITERATOR_TYPES:
            for iter_kind in (range(100), generator(100), IterWrapper(partial(generator, 100))):
                with self.subTest("%s=%s" % (iterator_type, iter_kind)):
                    returned = list(islice(iterator_type(iter_kind), 10))
                    # There is no assertion here: the test passes as long as the pytest
                    # process does not hang due to unfinished worker threads/processes.


class TestSampleIterator(unittest.TestCase):
    def test_sample_iterator(self):
        self.assertSetEqual(set(uniform_sample_iterator(range(10), sample_size=100)), set(range(10)))
        self.assertSetEqual(set(uniform_sample_iterator(range(100), sample_size=100)), set(range(100)))

        all_elements = set(range(1000))
        sampled = list(uniform_sample_iterator(all_elements, sample_size=100))
        self.assertEqual(len(sampled), 100)
        self.assertEqual(len(set(sampled)), 100, msg="Sampled elements are not unique.")
        self.assertTrue(all(s in all_elements for s in sampled))
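The iterator utilities above share one usage pattern: wrap any iterable and consume it as usual. A sketch, where the producer is a hypothetical stand-in for slow I/O or preprocessing:

def produce():
    for i in range(1000):
        yield i  # stand-in for an expensive producer

for item in ThreadedIterator(produce()):   # producer runs ahead in a background thread
    pass
for item in shuffled_iterator(produce()):  # yields every element, in randomized order
    pass
for item in uniform_sample_iterator(produce(), sample_size=10):  # uniform sample of 10
    pass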
--------------------------------------------------------------------------------
/python/tests/utils/test_richpath.py:
--------------------------------------------------------------------------------
import json
import os
import unittest
from contextlib import contextmanager
from enum import Enum
from tempfile import TemporaryDirectory

from azure.core.exceptions import ResourceExistsError
from azure.storage.blob import ContainerClient

from dpu_utils.utils import RichPath
from dpu_utils.utils import save_jsonl_gz
from dpu_utils.utils.msgpackloading import save_msgpack_l_gz


class AuthType(Enum):
    CONNECTION_STRING = 0
    ACCOUNT_KEY = 1
    SAS_TOKEN = 2


class TestRichPath(unittest.TestCase):
    # Note! The following are the default secrets used in Azurite, not real secrets.
    # See https://github.com/Azure/Azurite for more.
    AZURITE_DEVELOPMENT_CONNECTION_STRING = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"

    # Valid until 2050
    AZURITE_DEVELOPMENT_QUERY_STRING = "?sv=2018-03-28&st=2021-02-05T15%3A03%3A54Z&se=2050-02-06T15%3A03%3A00Z&sr=c&sp=racwdl&sig=fQDYpycIa3D7XZFBMIp0%2BzrukJb3Lq80gGLs9CArSHg%3D"
    AZURITE_DEVELOPMENT_ACCOUNT_KEY = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="

    def _create_test_container(self):
        client: ContainerClient = ContainerClient.from_connection_string(
            self.AZURITE_DEVELOPMENT_CONNECTION_STRING,
            container_name="test1"
        )
        try:
            client.create_container()
        except ResourceExistsError:
            pass

    @contextmanager
    def _setup_test(self, auth_type: AuthType = AuthType.CONNECTION_STRING):
        self._create_test_container()
        with TemporaryDirectory() as tmp_config, TemporaryDirectory() as cache_dir:
            test_config = {
                "devstoreaccount1": {
                    "cache_location": cache_dir,
                }
            }
            if auth_type == AuthType.ACCOUNT_KEY:
                test_config["devstoreaccount1"].update({
                    "account_key": self.AZURITE_DEVELOPMENT_ACCOUNT_KEY,
                    "endpoint": "http://localhost:10000/devstoreaccount1"
                })
            elif auth_type == AuthType.CONNECTION_STRING:
                test_config["devstoreaccount1"].update({
                    "connection_string": self.AZURITE_DEVELOPMENT_CONNECTION_STRING,
                })
            elif auth_type == AuthType.SAS_TOKEN:
                test_config["devstoreaccount1"].update({
                    "sas_token": self.AZURITE_DEVELOPMENT_QUERY_STRING,
                    "endpoint": "http://localhost:10000/devstoreaccount1"
                })
            else:
                raise Exception(f"Unknown `auth_type`: {auth_type}")

            config_path = os.path.join(tmp_config, 'config.json')
            with open(config_path, 'w') as f:
                json.dump(test_config, f)

            yield config_path

    def test_connection_types(self):
        for auth_type in AuthType:
            with self.subTest(f"Test {auth_type}"), self._setup_test(auth_type=auth_type) as az_info, TemporaryDirectory() as tmp_dir:
                data_f = os.path.join(tmp_dir, 'testtext.txt')
                with open(data_f, 'w') as f:
                    f.write("hello!")
                local_path = RichPath.create(data_f)

                remote_path = RichPath.create("azure://devstoreaccount1/test1/test_text.txt", az_info)
                remote_path.copy_from(local_path)
                local_path.delete()

                self.assertEqual(remote_path.read_as_text(), "hello!")
                remote_path.delete()

    def test_simple_read_write(self):
        with self._setup_test() as az_info:
            remote_path = RichPath.create("azure://devstoreaccount1/test1/remote_path.txt", az_info)
            with TemporaryDirectory() as tmp_dir:
                data_f = os.path.join(tmp_dir, 'testdata.txt')
                with open(data_f, 'w') as f:
                    f.write("hello!")
                local_path = RichPath.create(data_f)
                self.assertEqual(local_path.read_as_text(), "hello!")
                local_size = local_path.get_size()

                remote_path.copy_from(local_path)
                self.assertTrue(local_path.exists())
                local_path.delete()
                self.assertFalse(local_path.exists())
                local_path.delete()  # Deleting a missing file is a no-op by default...
                with self.assertRaises(Exception):
                    local_path.delete(missing_ok=False)  # ...but raises when missing_ok=False.

                self.assertEqual(remote_path.read_as_text(), "hello!")

                # Read once again (should hit the local cache).
                self.assertEqual(remote_path.read_as_text(), "hello!")

                self.assertTrue(remote_path.exists())
                self.assertTrue(remote_path.is_file())
                self.assertFalse(remote_path.is_dir())
                self.assertEqual(local_size, remote_path.get_size())

                local_path = remote_path.to_local_path()
                self.assertTrue(local_path.exists())
                self.assertTrue(os.path.exists(local_path.path))
                with open(local_path.path, 'r') as f:
                    self.assertEqual(f.read(), "hello!")

                # Delete file
                remote_path.delete()
                self.assertFalse(remote_path.exists())
                remote_path.delete()  # Should not raise an exception.
                with self.assertRaises(FileNotFoundError):
                    remote_path.delete(missing_ok=False)

                # Another remote_path that does not exist
                remote_path = RichPath.create("azure://devstoreaccount1/test1/remote_path2.txt", az_info)
                self.assertFalse(remote_path.exists())
                self.assertFalse(remote_path.is_dir())
                self.assertFalse(remote_path.is_file())

                with self.assertRaises(Exception):
                    remote_path.read_as_text()

                with self.assertRaises(Exception):
                    remote_path.get_size()

    def test_read_write_compressed_files(self):
        with self._setup_test() as az_info:
            random_elements = list(range(100))
            for suffix in ('.json.gz', '.jsonl.gz', '.pkl.gz', '.msgpack.gz', '.msgpack.l.gz'):
                with self.subTest(f'Read/write {suffix}'):
                    remote_path = RichPath.create(f"azure://devstoreaccount1/test1/compressed/data{suffix}", az_info)
                    remote_path.save_as_compressed_file(random_elements)

                    # Read once
                    read_nums = list(remote_path.read_by_file_suffix())
                    self.assertListEqual(read_nums, random_elements)

                    # Hit cache
                    read_nums = list(remote_path.read_by_file_suffix())
                    self.assertListEqual(read_nums, random_elements)
                    self.assertTrue(remote_path.exists())
                    self.assertTrue(remote_path.is_file())

            remote_dir = RichPath.create("azure://devstoreaccount1/test1/compressed/", az_info)
            self.assertTrue(remote_dir.is_dir())
            self.assertFalse(remote_dir.is_file())
            self.assertTrue(remote_dir.exists())
            remote_files = list(remote_dir.iterate_filtered_files_in_dir('*.gz'))
            self.assertEqual(len(remote_files), 5)

            for suffix in ('.json.gz', '.jsonl.gz', '.pkl.gz', '.msgpack.gz', '.msgpack.l.gz'):
                joined_remote = remote_dir.join(f"data{suffix}")
                self.assertTrue(joined_remote.exists())
                read_nums = list(joined_remote.read_by_file_suffix())
                self.assertListEqual(read_nums, random_elements)

            for file in remote_files:
                read_nums = list(file.read_by_file_suffix())
                self.assertListEqual(read_nums, random_elements)
                file.delete()
                self.assertFalse(file.exists())

            self.assertFalse(remote_dir.exists())
            # The directory should now be empty.
            remote_files = list(remote_dir.iterate_filtered_files_in_dir('*.gz'))
            self.assertEqual(len(remote_files), 0)

    def test_copy_from(self):
        with self._setup_test() as az_info, TemporaryDirectory() as tmp_dir:
            elements = [[i, i // 2] for i in range(10000)]
            tmp_local_path = RichPath.create(tmp_dir).join("sample.json.gz")
            tmp_local_path.save_as_compressed_file(elements)

            remote_path1 = RichPath.create("azure://devstoreaccount1/test1/sample1.json.gz", az_info)
            self.assertFalse(remote_path1.exists())

            remote_path1.copy_from(tmp_local_path)
            tmp_local_path.delete()

            self.assertFalse(tmp_local_path.exists())
            self.assertTrue(remote_path1.exists())

            read_elements = remote_path1.read_by_file_suffix()
            self.assertListEqual(elements, read_elements)

            remote_path2 = RichPath.create("azure://devstoreaccount1/test1/sample2.json.gz", az_info)
            remote_path2.copy_from(remote_path1)
            remote_path1.delete()

            read_elements = remote_path2.read_by_file_suffix()
            self.assertListEqual(elements, read_elements)

            read_elements = remote_path2.to_local_path().read_by_file_suffix()
            self.assertListEqual(elements, read_elements)
            remote_path2.delete()

    def test_cache_correctness(self):
        with self._setup_test() as az_info:
            for suffix in ('.jsonl.gz', '.msgpack.l.gz'):
                random_elements = list(range(100))
                remote_path = RichPath.create("azure://devstoreaccount1/test1/compressed/data" + suffix, az_info)
                remote_path.save_as_compressed_file(random_elements)

                # Read once
                read_nums = list(remote_path.read_by_file_suffix())
                self.assertListEqual(read_nums, random_elements)

                # Hit cache
                read_nums = list(remote_path.read_by_file_suffix())
                self.assertListEqual(read_nums, random_elements)
                self.assertTrue(remote_path.exists())
                self.assertTrue(remote_path.is_file())

                # Update the file through other means, and ensure that the cache is appropriately invalidated.
                new_elements = list(range(500))
                with TemporaryDirectory() as tmp:
                    path = os.path.join(tmp, 'tst' + suffix)
                    if suffix == '.jsonl.gz':
                        save_jsonl_gz(new_elements, path)
                    else:
                        save_msgpack_l_gz(new_elements, path)
                    container_client = ContainerClient.from_connection_string(
                        self.AZURITE_DEVELOPMENT_CONNECTION_STRING, "test1")
                    blob_client = container_client.get_blob_client("compressed/data" + suffix)
                    with open(path, 'rb') as f:
                        blob_client.upload_blob(f, overwrite=True)

                read_nums = list(remote_path.read_by_file_suffix())
                self.assertListEqual(read_nums, new_elements)
                self.assertTrue(remote_path.exists())
                self.assertTrue(remote_path.is_file())
--------------------------------------------------------------------------------
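Outside of these Azurite-backed tests, the same RichPath pattern applies against a real storage account. A minimal sketch, with the account, container, and file names made up, and the configuration file mirroring the layout written by _setup_test above:

from dpu_utils.utils import RichPath

# azure_info.json maps a storage-account name to credentials plus a local cache, e.g.:
#   {"myaccount": {"connection_string": "DefaultEndpointsProtocol=...", "cache_location": "/tmp/azure_cache"}}
remote = RichPath.create('azure://myaccount/mycontainer/data.jsonl.gz', 'azure_info.json')
remote.save_as_compressed_file([{'x': 1}, {'x': 2}])  # on-disk format chosen from the suffix
rows = list(remote.read_by_file_suffix())             # reads back through the local cache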