├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── deploy.yaml │ └── test.yaml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── index_files └── figure-commonmark │ └── mermaid-figure-1.png ├── nbs ├── 00_core.ipynb ├── 01_filter.ipynb ├── 02_clean.ipynb ├── 03_helpers.ipynb ├── 04_tutorials.ipynb ├── _quarto.yml ├── index.ipynb ├── nbdev.yml ├── sidebar.yml └── styles.css ├── settings.ini ├── setup.py └── squeakily ├── __init__.py ├── _modidx.py ├── clean.py ├── core.py ├── filter.py └── helpers.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/workflows/deploy.yaml: -------------------------------------------------------------------------------- 1 | name: Deploy to GitHub Pages 2 | on: 3 | push: 4 | branches: [ "main", "master" ] 5 | workflow_dispatch: 6 | jobs: 7 | deploy: 8 | runs-on: ubuntu-latest 9 | steps: [uses: fastai/workflows/quarto-ghp@master] 10 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: [workflow_dispatch, pull_request, push] 3 | 4 | jobs: 5 | test: 6 | runs-on: ubuntu-latest 7 | steps: [uses: fastai/workflows/nbdev-ci@master] 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.bin 2 | pilev2/ 3 | _docs/ 4 | _proc/ 5 | 6 | *.bak 7 | .gitattributes 8 | .last_checked 9 | .gitconfig 10 | *.bak 11 | *.log 12 | *~ 13 | ~* 14 | _tmp* 15 | tmp* 16 | tags 17 | *.pkg 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | env/ 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | .hypothesis/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # dotenv 101 | .env 102 | 103 | # virtualenv 104 | .venv 105 | venv/ 106 | ENV/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | 121 | .vscode 122 | *.swp 123 | 124 | # osx generated files 125 | .DS_Store 126 | .DS_Store? 
127 | .Trashes 128 | ehthumbs.db 129 | Thumbs.db 130 | .idea 131 | 132 | # pytest 133 | .pytest_cache 134 | 135 | # tools/trust-doc-nbs 136 | docs_src/.last_checked 137 | 138 | # symlinks to fastai 139 | docs_src/fastai 140 | tools/fastai 141 | 142 | # link checker 143 | checklink/cookies.txt 144 | 145 | # .gitconfig is now autogenerated 146 | .gitconfig 147 | 148 | # Quarto installer 149 | .deb 150 | .pkg 151 | 152 | # Quarto 153 | .quarto 154 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | louis@stability.ai. 
64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022, CarperAI, EleutherAI, Chenghao Mou, BigCode, BigScience, and Eduardo Gonzalez Ponferrada 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include settings.ini 2 | include LICENSE 3 | include CONTRIBUTING.md 4 | include README.md 5 | recursive-exclude * __pycache__ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # squeakily 2 | 3 | 4 | 5 | This repository is heavily inspired by BigScience’s [ROOTs 6 | project](https://github.com/bigscience-workshop/data-preparation) and 7 | EleutherAI’s [The Pile](https://github.com/EleutherAI/the-pile). 8 | 9 | The overall pipeline is as follows: 10 | 11 | ``` mermaid 12 | flowchart LR 13 | A(Defining <br/>Datasources) --> B(Defining Filters <br/>per Datasource) 14 | B --> C(Defining Cleaners <br/>per Datasource) 15 | ``` 16 | 17 | In this library, we define filtering as data instances being removed 18 | from the dataset based on some criteria and cleaning as data instances 19 | being modified in some way. 20 | 21 | ## Install 22 | 23 | ``` sh 24 | pip install squeakily 25 | ``` 26 | 27 | ## How to use 28 | 29 | ### Using the API 30 | 31 | First, we need to define a datasource. `squeakily` accepts any `Dataset` 32 | object from the [HuggingFace 33 | Datasets](https://huggingface.co/docs/datasets/index) library. For 34 | example, we can use the 35 | [wikitext](https://huggingface.co/datasets/wikitext) dataset: 36 | 37 | ``` python 38 | from datasets import load_dataset 39 | 40 | ds = load_dataset("wikitext", "wikitext-103-v1", split="train[:1%]") 41 | ``` 42 | 43 | We simply need to wrap the `Dataset` object in a dictionary, with the 44 | key being the name of the datasource and the value being the `Dataset` 45 | object, the filter and cleaners. For example: 46 | 47 | ``` python 48 | from squeakily.filter import check_char_repetition, check_flagged_words 49 | from squeakily.clean import remove_empty_lines, normalize_whitespace 50 | 51 | datasources = [ 52 | { 53 | "dataset": ds, 54 | "name": "wikitext", 55 | "columns": ["text"], 56 | "filters": [check_char_repetition, check_flagged_words], 57 | "cleaners": [remove_empty_lines, normalize_whitespace], 58 | }, 59 | # ... 60 | ] 61 | ``` 62 | 63 |
64 | 65 | > **Warning** 66 | > 67 | > Note: The order of the filters and cleaning functions matters. Filters 68 | > and cleaners are applied in the order they are defined. 69 | 70 |
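One practical consequence is that cheaper checks can be placed ahead of more expensive ones, since a document removed by an earlier filter never reaches the later ones. Below is a minimal sketch reusing the imports from the example above; the ordering shown is illustrative, not required, and the default behaviour (filters first, then cleaners) can be flipped with `cleaning_first=True` as described further down.

``` python
# Filters are applied left to right, then cleaners are applied left to right.
datasources = [
    {
        "dataset": ds,
        "name": "wikitext",
        "columns": ["text"],
        # check_char_repetition runs before check_flagged_words
        "filters": [check_char_repetition, check_flagged_words],
        # remove_empty_lines runs before normalize_whitespace
        "cleaners": [remove_empty_lines, normalize_whitespace],
    },
]
```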
71 | 72 |
73 | 74 | > **Important** 75 | > 76 | > Note: As of now, we only use the first column of the given column 77 | > names. This is because the `squeakily` library is designed to work 78 | > with language datasets, which usually have a single column of text. 79 | > Future versions will support multiple columns. 80 | 81 |
82 | 83 | Finally, we can apply the filters and cleaners to the datasources using a 84 | [`Pipeline`](https://CarperAI.github.io/squeakily/core.html#pipeline) 85 | object: 86 | 87 | ``` python 88 | from squeakily.core import Pipeline 89 | 90 | pipeline = Pipeline(datasources) 91 | pipeline.run() 92 | ``` 93 | 94 |
[11/16/22 04:32:57] INFO     Running datasource: wikitext                                                core.py:41
 95 | 
96 |
                    INFO     Running filter: check_char_repetition on text                               core.py:54
 97 | 
98 |
                    INFO     Running filter: check_flagged_words on text                                 core.py:54
 99 | 
100 |
                    INFO     Running cleaner: remove_empty_lines on text                                 core.py:57
101 | 
102 |
[11/16/22 04:32:59] INFO     Running cleaner: normalize_whitespace on text                               core.py:57
103 | 
104 | 105 |
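Once the run has finished, the processed `Dataset` objects remain attached to the pipeline through its `datasources` list, so they can be inspected or handed to the next stage directly. A small sketch, assuming the `wikitext` datasource defined above sits at index 0:

``` python
# Retrieve the filtered and cleaned wikitext split back from the pipeline.
processed = pipeline.datasources[0]["dataset"]
print(processed.num_rows)           # row count after filtering
print(processed[0]["text"][:200])   # start of the first cleaned document
```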
106 | 107 | > **Note** 108 | > 109 | > Note: If you want to export the processed data source to a desired 110 | > path, you can specify an export path and the output type (csv or json) 111 | > in the `export_to_path` function. 112 | > 113 | > ``` python 114 | > export_path = "/path/to/desired/path" 115 | > output_types = ['csv', 'json'] # Optional, default is "csv" 116 | > json_indication = "records" # Optional, default is "records" 117 | > pipeline.export_to_path(export_path, output_types[1], json_indication=json_indication) 118 | > ``` 119 | 120 |
121 | 122 |
123 | 124 | > **Note** 125 | > 126 | > Note: If you want to run cleaners first, you can pass 127 | > `cleaning_first=True` to the `run` function. 128 | > 129 | > ``` python 130 | > pipeline.run(cleaning_first=True) 131 | > ``` 132 | 133 |
134 | 135 | If you need to run a filter or cleaner at the dataset level rather than 136 | the example level, you can pass `global_filters` or `global_cleaners` to 137 | the 138 | [`Pipeline.run`](https://CarperAI.github.io/squeakily/core.html#pipeline.run) 139 | function. For example: 140 | 141 | ``` python 142 | from squeakily.filter import minhash_dedup 143 | 144 | pipeline.run(global_filters=[minhash_dedup]) 145 | ``` 146 | 147 |
148 | 149 | > **Note** 150 | > 151 | > Note: If you use global filters or cleaners, all datasets must have a 152 | > common column name in order to properly concatenate them. 153 | 154 |
155 | 156 |
157 | 158 | > **Note** 159 | > 160 | > Note: You can also specify whether you want a specific dataset to be 161 | > skipped by setting the `skip_global` parameter to `True` when defining 162 | > the datasource. 163 | > 164 | > ``` python 165 | > datasources = [ 166 | > { 167 | > "dataset": ds, 168 | > "columns": ["text"], 169 | > "filters": [check_char_repetition, check_flagged_words], 170 | > "cleaners": [remove_empty_lines, normalize_whitespace], 171 | > "skip_global": True, 172 | > }, 173 | > # ... 174 | > ] 175 | > ``` 176 | 177 |
178 | 179 | Additionally, you can run the pipeline in a dry run mode by passing 180 | `dry_run=True` to the `run` function. This will make no modifications to 181 | the datasets’ documents, but will add additional columns to the datasets 182 | with the results of the filters and cleaners. For example, if you 183 | ran the pipeline with the 184 | [`check_char_repetition`](https://CarperAI.github.io/squeakily/filter.html#check_char_repetition) 185 | filter, you would get a new column called 186 | [`check_char_repetition`](https://CarperAI.github.io/squeakily/filter.html#check_char_repetition) 187 | with a float value between 0 and 1 indicating the fraction of 188 | characters that are repeated in the document. 189 | 190 | ``` python 191 | pipeline = Pipeline(datasources) 192 | pipeline.run(dry_run=True) 193 | pipeline.datasources[0]["dataset"].features 194 | ``` 195 | 196 | 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /index_files/figure-commonmark/mermaid-figure-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CarperAI/squeakily/1ffd03d02e6385e6435f50674366c934733ddcbf/index_files/figure-commonmark/mermaid-figure-1.png -------------------------------------------------------------------------------- /nbs/01_filter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# filter\n", 9 | "\n", 10 | "> This module contains all the various filtering options supported." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "# | default_exp filter" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# | export\n", 29 | "import datasets\n", 30 | "import gc\n", 31 | "import logging\n", 32 | "import multiprocessing\n", 33 | "import os\n", 34 | "import random\n", 35 | "import re\n", 36 | "\n", 37 | "import networkit as nk\n", 38 | "import numpy as np\n", 39 | "\n", 40 | "from collections import Counter\n", 41 | "from datasets import Dataset, Features, Value, Sequence\n", 42 | "from datasketch import LeanMinHash, MinHash, MinHashLSH\n", 43 | "from rich.logging import RichHandler\n", 44 | "from squeakily.helpers import flagged_words, get_words\n", 45 | "from squeakily.helpers import stopwords, stopword_ratios\n", 46 | "from tqdm.auto import tqdm\n", 47 | "from typing import Set" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# | export\n", 57 | "logger = logging.getLogger(__name__)\n", 58 | "logger.setLevel(logging.INFO)\n", 59 | "logger.addHandler(RichHandler(rich_tracebacks=True))\n", 60 | "logger.propagate = False\n", 61 | "datasets.logging.set_verbosity_error()\n", 62 | "# Turn off logging for datasets\n", 63 | "logging.getLogger(\"datasets\").setLevel(logging.ERROR)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "# | hide\n", 73 | "from datasets import load_dataset\n", 74 | "from nbdev.showdoc import *" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | 
"outputs": [], 82 | "source": [ 83 | "# | export\n", 84 | "multiprocessing.set_start_method(\"fork\", force=True)\n", 85 | "\n", 86 | "zstd_cntxt = None" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "# | export\n", 96 | "def _compress_ratio(\n", 97 | " doc: str, # document to be analyzed\n", 98 | " compression_level: int = 3, # compression level to use\n", 99 | ") -> float:\n", 100 | " \"\"\"\n", 101 | " Returns the ratio of the compressed document to the original document.\n", 102 | " \"\"\"\n", 103 | " global zstd_cntxt\n", 104 | " if zstd_cntxt is None:\n", 105 | " import zstandard as zstd\n", 106 | "\n", 107 | " zstd_cntxt = zstd.ZstdCompressor(level=compression_level)\n", 108 | " bts = doc.encode(\"utf-8\")\n", 109 | " compressed_bts = zstd_cntxt.compress(bts)\n", 110 | " try:\n", 111 | " ratio = len(compressed_bts) / len(bts)\n", 112 | " except ZeroDivisionError:\n", 113 | " ratio = 0\n", 114 | " return ratio" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "# | export\n", 124 | "def check_compression_ratio(\n", 125 | " document, # document to be analyzed\n", 126 | " compression_threshold: float = 0.5, # threshold for compression ratio\n", 127 | " compression_level: int = 3, # compression level to use\n", 128 | " dry_run=False, # if True, returns the ratio of character repetition\n", 129 | ") -> bool: # returns True if document is below threshold\n", 130 | " \"\"\"\n", 131 | " Checks if the document is below the character repetition threshold.\n", 132 | " \"\"\"\n", 133 | " compress_ratio = _compress_ratio(document, compression_level=compression_level)\n", 134 | " if dry_run:\n", 135 | " return compress_ratio\n", 136 | " else:\n", 137 | " return compress_ratio > compression_threshold" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "# |eval: false\n", 147 | "test_str0 = \"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\"\n", 148 | "test_str1 = \"This is a test string.\"\n", 149 | "assert check_compression_ratio(test_str0, dry_run=True) < check_compression_ratio(\n", 150 | " test_str1, dry_run=True\n", 151 | ")" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "# | export\n", 161 | "def _char_rep_ratio(\n", 162 | " doc: str, # document to be analyzed\n", 163 | " char_rep_len: int, # length of character repetition\n", 164 | ") -> float:\n", 165 | " \"\"\"\n", 166 | " Returns the ratio of character repetitions in a document.\n", 167 | " \"\"\"\n", 168 | "\n", 169 | " def calc_ngrams(doc, n):\n", 170 | " char_ngrams = [doc[i : i + n] for i in range(len(doc) - n + 1)]\n", 171 | " freq_char_ngrams = Counter(char_ngrams)\n", 172 | " return freq_char_ngrams\n", 173 | "\n", 174 | " freq_char_ngrams = calc_ngrams(doc, char_rep_len)\n", 175 | " if len(freq_char_ngrams) == 0:\n", 176 | " return 0\n", 177 | " freq_char_ngrams = list(freq_char_ngrams.values())\n", 178 | " freq_char_ngrams = sorted(freq_char_ngrams, reverse=True)\n", 179 | " val_one = len([el for el in freq_char_ngrams if el == 1])\n", 180 | " num_rep_char_ngrams = min(\n", 181 | " int(np.sqrt(len(freq_char_ngrams))),\n", 182 | " len(freq_char_ngrams) - val_one,\n", 183 | " )\n", 184 | " char_rep_ratio = sum(freq_char_ngrams[:num_rep_char_ngrams]) / 
sum(freq_char_ngrams)\n", 185 | " return char_rep_ratio" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "# | export\n", 195 | "def check_char_repetition(\n", 196 | " document, # document to be analyzed\n", 197 | " char_repetition_len=10, # length of character repetition\n", 198 | " char_repetition_threshold=0.2, # threshold for character repetition\n", 199 | " dry_run=False, # if True, returns the ratio of character repetition\n", 200 | ") -> bool: # returns True if document is below threshold\n", 201 | " \"\"\"\n", 202 | " Checks if the document is below the character repetition threshold.\n", 203 | " \"\"\"\n", 204 | " char_rep_ratio = _char_rep_ratio(document, char_repetition_len)\n", 205 | " if dry_run:\n", 206 | " return char_rep_ratio\n", 207 | " else:\n", 208 | " return char_rep_ratio <= char_repetition_threshold" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "test_str = \"aaabbbcccddd\"\n", 218 | "assert (\n", 219 | " check_char_repetition(\n", 220 | " test_str, char_repetition_len=3, char_repetition_threshold=0.2\n", 221 | " )\n", 222 | " == True\n", 223 | ")\n", 224 | "\n", 225 | "test_str = \"aaaaaaabbbcccddd\"\n", 226 | "assert (\n", 227 | " check_char_repetition(\n", 228 | " test_str, char_repetition_len=3, char_repetition_threshold=0.2\n", 229 | " )\n", 230 | " == False\n", 231 | ")" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "# | export\n", 241 | "def _flag_word_ratio(\n", 242 | " doc: str, # document to be analyzed\n", 243 | " flagged_words: list, # list of flagged words\n", 244 | " get_words_func: callable, # function to get words from document\n", 245 | ") -> float: # returns ratio of flagged words in document\n", 246 | " \"\"\"\n", 247 | " Returns the ratio of flagged words in a document.\n", 248 | " \"\"\"\n", 249 | " words = get_words_func(doc)\n", 250 | " if not words:\n", 251 | " return 0.0\n", 252 | " flagged_words_ratio = len([word for word in words if word in flagged_words]) / len(\n", 253 | " words\n", 254 | " )\n", 255 | " if flagged_words_ratio > 1.0:\n", 256 | " flagged_words_ratio = 1.0\n", 257 | " return flagged_words_ratio" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "# | export\n", 267 | "def check_flagged_words(\n", 268 | " document: str, # document to be analyzed\n", 269 | " flagged_words: list = flagged_words[\"en\"], # list of flagged words\n", 270 | " flagged_words_threshold: float = 0.1, # threshold for flagged words\n", 271 | " get_words_func: callable = get_words, # function to get words from document\n", 272 | " dry_run: bool = False, # if True, returns the ratio of flagged words\n", 273 | ") -> bool: # returns True if document is below threshold unless dry_run is True\n", 274 | " \"\"\"\n", 275 | " Checks if a document contains a high percentage of flagged words.\n", 276 | " \"\"\"\n", 277 | " cond = True\n", 278 | " if flagged_words:\n", 279 | " flagged_words_ratio = _flag_word_ratio(\n", 280 | " document,\n", 281 | " flagged_words,\n", 282 | " get_words_func,\n", 283 | " )\n", 284 | " if dry_run:\n", 285 | " return flagged_words_ratio\n", 286 | "\n", 287 | " cond = flagged_words_ratio <= flagged_words_threshold\n", 288 | " return 
cond" 289 | ] 290 | }, 291 | { 292 | "attachments": {}, 293 | "cell_type": "markdown", 294 | "metadata": {}, 295 | "source": [ 296 | "The `check_flagged_words` filter function is purposefully hidden in this documentation as it would show the flagged words directly in the documentation, which might shock some people." 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "assert check_flagged_words(\"test\") == True\n", 306 | "assert check_flagged_words(\"bdsm\") == False" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "# | export\n", 316 | "def check_perplexity(\n", 317 | " document, # document to be analyzed\n", 318 | " perplexity_threshold=10_000, # threshold for perplexity\n", 319 | " model=None, # model to calculate perplexity\n", 320 | " dry_run=False, # if True, returns the perplexity of the document\n", 321 | ") -> bool: # returns True if document is below threshold\n", 322 | " \"\"\"\n", 323 | " Checks if the document is below the perplexity threshold.\n", 324 | " \"\"\"\n", 325 | " perplexity = model.get_perplexity(document)\n", 326 | " if dry_run:\n", 327 | " return perplexity\n", 328 | " else:\n", 329 | " return perplexity <= perplexity_threshold" 330 | ] 331 | }, 332 | { 333 | "attachments": {}, 334 | "cell_type": "markdown", 335 | "metadata": {}, 336 | "source": [ 337 | "To run this test, you need to have kenlm and sentencepiece installed:\n", 338 | "`pip install https://github.com/kpu/kenlm/archive/master.zip sentencepiece`" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "metadata": {}, 345 | "outputs": [ 346 | { 347 | "name": "stderr", 348 | "output_type": "stream", 349 | "text": [ 350 | "/home/nathan/miniconda3/envs/squeakily/lib/python3.10/site-packages/huggingface_hub/file_download.py:592: FutureWarning: `cached_download` is the legacy way to download files from the HF hub, please consider upgrading to `hf_hub_download`\n", 351 | " warnings.warn(\n" 352 | ] 353 | } 354 | ], 355 | "source": [ 356 | "# |eval: false\n", 357 | "from squeakily.helpers import KenlmModel\n", 358 | "\n", 359 | "model = KenlmModel.from_pretrained(\n", 360 | " model_dataset=\"wikipedia\",\n", 361 | " language=\"en\",\n", 362 | " lower_case=True,\n", 363 | " remove_accents=True,\n", 364 | " normalize_numbers=True,\n", 365 | " punctuation=1,\n", 366 | ")\n", 367 | "\n", 368 | "low_test_str = \"I am very perplexed\"\n", 369 | "high_test_str = \"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed ...\"\n", 370 | "\n", 371 | "assert check_perplexity(low_test_str, perplexity_threshold=1_000, model=model) == True\n", 372 | "assert check_perplexity(high_test_str, perplexity_threshold=1_000, model=model) == False" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "metadata": {}, 379 | "outputs": [], 380 | "source": [ 381 | "# | export\n", 382 | "def check_language(\n", 383 | " document, # document to be analyzed\n", 384 | " language=\"en\", # language to check\n", 385 | " language_threshold=0.9, # threshold for language\n", 386 | " model=None, # model to check language\n", 387 | " dry_run=False, # if True, returns the language of the document\n", 388 | ") -> bool: # returns True if document is below threshold\n", 389 | " \"\"\"\n", 390 | " Checks if the document is below the language threshold.\n", 391 | " \"\"\"\n", 392 | 
" lang, prob = model.get_language(document)\n", 393 | " if dry_run:\n", 394 | " if lang == language:\n", 395 | " return prob\n", 396 | " else:\n", 397 | " return -1.0\n", 398 | " else:\n", 399 | " return language == lang and prob > language_threshold" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "metadata": {}, 406 | "outputs": [ 407 | { 408 | "name": "stderr", 409 | "output_type": "stream", 410 | "text": [ 411 | "Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n" 412 | ] 413 | } 414 | ], 415 | "source": [ 416 | "# |eval: false\n", 417 | "from squeakily.helpers import FastTextLanguageDetector\n", 418 | "\n", 419 | "fasttext_model = FastTextLanguageDetector.from_pretrained()\n", 420 | "\n", 421 | "english_text = \"Hi, my name is John.\"\n", 422 | "spanish_text = \"Hola, me llamo Juan.\"\n", 423 | "chinese_text = \"你好,我叫张三。\"\n", 424 | "\n", 425 | "assert (\n", 426 | " check_language(\n", 427 | " english_text, language=\"en\", language_threshold=0.85, model=fasttext_model\n", 428 | " )\n", 429 | " == True\n", 430 | ")\n", 431 | "assert (\n", 432 | " check_language(\n", 433 | " spanish_text, language=\"en\", language_threshold=0.85, model=fasttext_model\n", 434 | " )\n", 435 | " == False\n", 436 | ")\n", 437 | "assert (\n", 438 | " check_language(\n", 439 | " chinese_text, language=\"en\", language_threshold=0.85, model=fasttext_model\n", 440 | " )\n", 441 | " == False\n", 442 | ")\n", 443 | "\n", 444 | "# test dry run\n", 445 | "assert (\n", 446 | " check_language(\n", 447 | " english_text,\n", 448 | " language=\"en\",\n", 449 | " language_threshold=0.85,\n", 450 | " model=fasttext_model,\n", 451 | " dry_run=True,\n", 452 | " )\n", 453 | " > 0.0\n", 454 | ")\n", 455 | "assert (\n", 456 | " check_language(\n", 457 | " spanish_text,\n", 458 | " language=\"en\",\n", 459 | " language_threshold=0.85,\n", 460 | " model=fasttext_model,\n", 461 | " dry_run=True,\n", 462 | " )\n", 463 | " == -1.0\n", 464 | ")\n", 465 | "assert (\n", 466 | " check_language(\n", 467 | " chinese_text,\n", 468 | " language=\"es\",\n", 469 | " language_threshold=0.85,\n", 470 | " model=fasttext_model,\n", 471 | " dry_run=True,\n", 472 | " )\n", 473 | " == -1.0\n", 474 | ")" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": null, 480 | "metadata": {}, 481 | "outputs": [], 482 | "source": [ 483 | "# | export\n", 484 | "def check_word_number(\n", 485 | " document, # document to be analyzed\n", 486 | " min_word_threshold=5, # minimum number of words\n", 487 | " max_word_threshold=100, # maximum number of words\n", 488 | " get_words_func=get_words, # function to get words from document\n", 489 | " dry_run=False, # if True, returns the number of words in the document\n", 490 | ") -> bool: # returns True if document is between the minimum and maximum thresholds\n", 491 | " \"\"\"\n", 492 | " Checks if the document is between the minimum and maximum word thresholds.\n", 493 | " \"\"\"\n", 494 | " words = get_words_func(document)\n", 495 | " if dry_run:\n", 496 | " return len(words)\n", 497 | " else:\n", 498 | " return len(words) >= min_word_threshold and len(words) <= max_word_threshold" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [ 507 | "test_str = \"This is a test string.\"\n", 508 | "\n", 509 | "assert check_word_number(test_str, min_word_threshold=5, max_word_threshold=10) == 
True\n", 510 | "assert check_word_number(test_str, min_word_threshold=1, max_word_threshold=4) == False" 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": null, 516 | "metadata": {}, 517 | "outputs": [], 518 | "source": [ 519 | "# | export\n", 520 | "def check_stop_word_ratio(\n", 521 | " document, # document to be analyzed\n", 522 | " stop_word_threshold=stopword_ratios[\"en\"], # threshold for stop words\n", 523 | " stop_words=stopwords[\"en\"], # list of stop words\n", 524 | " get_words_func=get_words, # function to get words from document\n", 525 | " dry_run=False, # if True, returns the ratio of stop words in the document\n", 526 | ") -> bool: # returns True if document is below the threshold\n", 527 | " \"\"\"\n", 528 | " Checks if the document contains a high percentage of stop words.\n", 529 | " \"\"\"\n", 530 | " cond = True\n", 531 | " if stop_words:\n", 532 | " stop_word_ratio = _flag_word_ratio(\n", 533 | " document,\n", 534 | " stop_words,\n", 535 | " get_words_func,\n", 536 | " )\n", 537 | " if dry_run:\n", 538 | " return stop_word_ratio\n", 539 | " else:\n", 540 | " cond = stop_word_ratio <= stop_word_threshold\n", 541 | " return cond" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": null, 547 | "metadata": {}, 548 | "outputs": [], 549 | "source": [ 550 | "assert check_stop_word_ratio(\"test\") == True\n", 551 | "assert check_stop_word_ratio(\"the\") == False\n", 552 | "assert check_stop_word_ratio(\"the funny llama\", stop_word_threshold=0.3) == False\n", 553 | "assert check_stop_word_ratio(\"the funny llama\", stop_word_threshold=0.5) == True\n", 554 | "# Test french stop words\n", 555 | "assert check_stop_word_ratio(\"le\", stop_words=stopwords[\"fr\"]) == False\n", 556 | "assert (\n", 557 | " check_stop_word_ratio(\"le chien est beau\", stop_words=stopwords[\"fr\"], dry_run=True)\n", 558 | " == 0.5\n", 559 | ")\n", 560 | "assert (\n", 561 | " check_stop_word_ratio(\n", 562 | " \"le chien est beau\", stop_words=stopwords[\"fr\"], stop_word_threshold=0.3\n", 563 | " )\n", 564 | " == False\n", 565 | ")" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": null, 571 | "metadata": {}, 572 | "outputs": [], 573 | "source": [ 574 | "# | export\n", 575 | "def check_code_parsability(\n", 576 | " document, # document to be analyzed\n", 577 | " program_language=\"python\", # programming language to check\n", 578 | ") -> bool: # returns True if the code is parsable\n", 579 | " \"\"\"\n", 580 | " Checks if the document contains parsable code.\n", 581 | " \"\"\"\n", 582 | " import code_tokenize as ctok\n", 583 | "\n", 584 | " try:\n", 585 | " ctok.tokenize(document, lang=program_language, syntax_error=\"raise\")\n", 586 | " return True\n", 587 | " except SyntaxError:\n", 588 | " return False" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": null, 594 | "metadata": {}, 595 | "outputs": [ 596 | { 597 | "name": "stderr", 598 | "output_type": "stream", 599 | "text": [ 600 | "WARNING:root:Autoloading AST parser for javascript: Start download from Github.\n", 601 | "WARNING:root:Start cloning the parser definition from Github.\n", 602 | "WARNING:root:Compiling language for javascript\n" 603 | ] 604 | } 605 | ], 606 | "source": [ 607 | "# Test python code\n", 608 | "assert check_code_parsability(\"print('hello world')\", program_language=\"python\") == True\n", 609 | "assert check_code_parsability(\"print('hello world'\", program_language=\"python\") == False\n", 610 | "# Test javascript 
code\n", 611 | "assert (\n", 612 | " check_code_parsability(\"console.log('hello world')\", program_language=\"javascript\")\n", 613 | " == True\n", 614 | ")\n", 615 | "assert (\n", 616 | " check_code_parsability(\"console.log('hello world'\", program_language=\"javascript\")\n", 617 | " == False\n", 618 | ")" 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": null, 624 | "metadata": {}, 625 | "outputs": [], 626 | "source": [ 627 | "# | export\n", 628 | "def check_labels(\n", 629 | " document, # document to be analyzed\n", 630 | " labels: list, # list of labels to check the document against\n", 631 | " model=None, # model to check label\n", 632 | " dry_run=False, # if True, returns the tags of the document\n", 633 | ") -> bool: # returns True if document relates to any of the labels\n", 634 | " \"\"\"\n", 635 | " Checks if the document relates to any of the labels.\n", 636 | " \"\"\"\n", 637 | " pred_labels = model(document)\n", 638 | " if dry_run:\n", 639 | " return pred_labels\n", 640 | " else:\n", 641 | " return any([label in pred_labels for label in labels])" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": null, 647 | "metadata": {}, 648 | "outputs": [], 649 | "source": [ 650 | "# |eval: false\n", 651 | "from squeakily.helpers import LLMLabeler\n", 652 | "\n", 653 | "instruction = \"Please classify the following text into one of the following categories:\"\n", 654 | "labels = [\"positive\", \"negative\", \"neutral\"]\n", 655 | "api_key = \"\"\n", 656 | "llm_labeler = LLMLabeler(instruction, labels, api_key=api_key)\n", 657 | "\n", 658 | "pos_text = \"I love this product!\"\n", 659 | "neg_text = \"I hate this product!\"\n", 660 | "neutral_text = \"This product is okay.\"\n", 661 | "assert check_labels(pos_text, labels, model=llm_labeler) == True\n", 662 | "assert check_labels(neg_text, labels, model=llm_labeler) == True\n", 663 | "assert check_labels(neutral_text, labels, model=llm_labeler) == True\n", 664 | "assert check_labels(pos_text, [\"negative\", \"neutral\"], model=llm_labeler) == False\n", 665 | "\n", 666 | "mixed_text = \"I love this product! I hate this product! This product is okay.\"\n", 667 | "pred_labels = check_labels(mixed_text, labels, model=llm_labeler, dry_run=True)\n", 668 | "assert set(pred_labels) == set(labels)" 669 | ] 670 | }, 671 | { 672 | "attachments": {}, 673 | "cell_type": "markdown", 674 | "metadata": {}, 675 | "source": [ 676 | "## Whole Dataset Filtering" 677 | ] 678 | }, 679 | { 680 | "attachments": {}, 681 | "cell_type": "markdown", 682 | "metadata": {}, 683 | "source": [ 684 | "### MinHash Deduplication\n", 685 | "The following code has all been adapted from the awesome [Chenghao Mou](https://github.com/ChenghaoMou) and their work on the [BigCode repository](https://github.com/bigcode-project/bigcode-analysis/blob/main/data_analysis/near-deduplication)!" 
686 | ] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "execution_count": null, 691 | "metadata": {}, 692 | "outputs": [], 693 | "source": [ 694 | "# | export\n", 695 | "MINHASH_SEED = 115\n", 696 | "NON_ALPHA = re.compile(\"[^A-Za-z_0-9]\")\n", 697 | "\n", 698 | "random.seed(MINHASH_SEED)\n", 699 | "\n", 700 | "lsh: MinHashLSH = None\n", 701 | "dup_ids: Set = None" 702 | ] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "execution_count": null, 707 | "metadata": {}, 708 | "outputs": [], 709 | "source": [ 710 | "# | export\n", 711 | "def _hash_func(\n", 712 | " idx: int, # The index of the record.\n", 713 | " content: str, # The content to be hashed.\n", 714 | " *,\n", 715 | " num_perm: int # The number of permutations to use in the MinHash object.\n", 716 | ") -> dict[str, any]: # The MinHash signature and the index of the record.\n", 717 | " \"\"\"\n", 718 | " Embed the content of a record into a MinHash object. This function should be\n", 719 | " used with multiprocessing and it scales well with the number of cores.\n", 720 | " >>> result = _hash_func(0, \"Hello world!\", num_perm=128)\n", 721 | " >>> result[\"__id__\"]\n", 722 | " 0\n", 723 | " >>> result[\"__signature__\"].shape\n", 724 | " (128,)\n", 725 | " >>> result[\"__signature__\"].dtype\n", 726 | " dtype('uint64')\n", 727 | " \"\"\"\n", 728 | " m = MinHash(num_perm=num_perm, seed=MINHASH_SEED)\n", 729 | " m.update_batch(\n", 730 | " [token.encode(\"utf-8\") for token in {t for t in NON_ALPHA.split(content) if t}]\n", 731 | " )\n", 732 | " return {\"__signature__\": m.hashvalues, \"__id__\": idx}" 733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "execution_count": null, 738 | "metadata": {}, 739 | "outputs": [], 740 | "source": [ 741 | "result = _hash_func(0, \"Hello world!\", num_perm=128)\n", 742 | "assert result[\"__id__\"] == 0\n", 743 | "assert result[\"__signature__\"].shape == (128,)\n", 744 | "assert result[\"__signature__\"].dtype == np.uint64" 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "execution_count": null, 750 | "metadata": {}, 751 | "outputs": [], 752 | "source": [ 753 | "# | export\n", 754 | "def _query_content(\n", 755 | " idx: int, # The index of the record.\n", 756 | " signature: np.ndarray, # The MinHash signature of the record to be queried.\n", 757 | " *,\n", 758 | " index: MinHashLSH # The MinHashLSH index. It is shared across all processes when using multiprocessing with fork without copy.\n", 759 | ") -> dict[str, any]: # The query result.\n", 760 | " \"\"\"\n", 761 | " Query the MinHashLSH index for the record. 
This function can be used with multiprocessing\n", 762 | " as long as the index is shared across processes.\n", 763 | " \"\"\"\n", 764 | " return {\n", 765 | " \"__neighbors__\": [\n", 766 | " dup_idx\n", 767 | " for dup_idx in index.query(\n", 768 | " LeanMinHash(seed=MINHASH_SEED, hashvalues=signature),\n", 769 | " )\n", 770 | " if dup_idx != idx # exclude itself\n", 771 | " ],\n", 772 | " \"__id__\": idx,\n", 773 | " }" 774 | ] 775 | }, 776 | { 777 | "cell_type": "code", 778 | "execution_count": null, 779 | "metadata": {}, 780 | "outputs": [], 781 | "source": [ 782 | "data = [\"Hello world!\", \"Hello world\"]\n", 783 | "signatures = [_hash_func(i, content, num_perm=128) for i, content in enumerate(data)]\n", 784 | "index = MinHashLSH(threshold=0.5, num_perm=128)\n", 785 | "for signature in signatures:\n", 786 | " index.insert(\n", 787 | " signature[\"__id__\"],\n", 788 | " MinHash(num_perm=128, hashvalues=signature[\"__signature__\"], seed=MINHASH_SEED),\n", 789 | " )\n", 790 | "assert _query_content(0, signatures[0][\"__signature__\"], index=index) == {\n", 791 | " \"__neighbors__\": [1],\n", 792 | " \"__id__\": 0,\n", 793 | "}\n", 794 | "assert _query_content(1, signatures[1][\"__signature__\"], index=index) == {\n", 795 | " \"__neighbors__\": [0],\n", 796 | " \"__id__\": 1,\n", 797 | "}" 798 | ] 799 | }, 800 | { 801 | "cell_type": "code", 802 | "execution_count": null, 803 | "metadata": {}, 804 | "outputs": [], 805 | "source": [ 806 | "# | export\n", 807 | "def _jaccard_similarity(\n", 808 | " s1: str, s2: str # The first string to compare. # The second string to compare.\n", 809 | ") -> float: # The Jaccard similarity between the two strings.\n", 810 | " \"\"\"\n", 811 | " Calculate the jaccard similarity between two code snippets.\n", 812 | " \"\"\"\n", 813 | " tokens1 = set([t for t in NON_ALPHA.split(s1) if t.strip()])\n", 814 | " tokens2 = set([t for t in NON_ALPHA.split(s2) if t.strip()])\n", 815 | " return len(tokens1 & tokens2) / max(1, len(tokens1 | tokens2))" 816 | ] 817 | }, 818 | { 819 | "cell_type": "code", 820 | "execution_count": null, 821 | "metadata": {}, 822 | "outputs": [], 823 | "source": [ 824 | "assert _jaccard_similarity(\"a = 1\", \"a = 2\") == 0.3333333333333333\n", 825 | "assert _jaccard_similarity(\"a = 1\", \"a = 1\") == 1.0" 826 | ] 827 | }, 828 | { 829 | "cell_type": "code", 830 | "execution_count": null, 831 | "metadata": {}, 832 | "outputs": [], 833 | "source": [ 834 | "# | export\n", 835 | "def _calculate_average_false_positive_rate(\n", 836 | " clusters: list[list[int]], # The clusters of duplicate records.\n", 837 | " reference_records: Dataset, # The reference records.\n", 838 | " threshold: float, # The threshold to use for calculating the false positive rate.\n", 839 | " column: str, # The column to use for calculating the false positive rate.\n", 840 | ") -> None:\n", 841 | " \"\"\"\n", 842 | " Calculate the average false positive rate within each cluster. The false positives are defined as\n", 843 | " number of examples that have a maximum jaccard similarity with any example in the cluster that is\n", 844 | " less than the threshold. The false positive rate is defined as the number of false positives divided\n", 845 | " by the number of examples in the cluster. 
The average false positive rate is defined as the average\n", 846 | " of the false positive rate across all clusters given.\n", 847 | " \"\"\"\n", 848 | " cluster_false_positive_rates: list[float] = []\n", 849 | " deltas: list[float] = []\n", 850 | "\n", 851 | " for cluster in tqdm(clusters, desc=\"Calculating sampling false positive rate...\"):\n", 852 | " num_false_positives = 0\n", 853 | " ids = sorted(cluster)\n", 854 | " for i, x in enumerate(ids):\n", 855 | " is_false_positive = True\n", 856 | " max_similarity = -float(\"inf\")\n", 857 | " for j, y in enumerate(ids):\n", 858 | " if i == j:\n", 859 | " continue\n", 860 | " # TODO This can be redundant but we only calculate this for a small sample\n", 861 | " similarity = _jaccard_similarity(\n", 862 | " reference_records[x][column], reference_records[y][column]\n", 863 | " )\n", 864 | " max_similarity = max(max_similarity, similarity)\n", 865 | " if max_similarity >= threshold:\n", 866 | " is_false_positive = False\n", 867 | " break\n", 868 | " if is_false_positive:\n", 869 | " num_false_positives += 1\n", 870 | " deltas.append(threshold - max_similarity)\n", 871 | " cluster_false_positive_rates.append(num_false_positives / len(ids))\n", 872 | "\n", 873 | " logger.info(\n", 874 | " f\"Average false positive rate from {len(clusters)} clusters: {np.mean(cluster_false_positive_rates):.2f}\"\n", 875 | " )\n", 876 | " logger.info(f\"Similarity delta stats from threshold:\")\n", 877 | " logger.info(f\"- Max : {np.max(deltas):0.2f}\")\n", 878 | " logger.info(f\"- Min : {np.min(deltas):0.2f}\")\n", 879 | " logger.info(f\"- Mean: {np.mean(deltas):0.2f}\")\n", 880 | " logger.info(f\"- Std : {np.std(deltas):0.2f}\")" 881 | ] 882 | }, 883 | { 884 | "cell_type": "code", 885 | "execution_count": null, 886 | "metadata": {}, 887 | "outputs": [], 888 | "source": [ 889 | "# | export\n", 890 | "def _find_duplicate_communities(\n", 891 | " records: Dataset, # The dataset that contains both `__id__` and `__neighbors__`.\n", 892 | " community_detection: bool, # Whether to use community detection to find the duplicate communities, or to use the connected components.\n", 893 | " report_false_positive_rate: bool = False, # Whether to report the false positive rate.\n", 894 | " reference_records: Dataset = None, # The reference records. It can be an iterable or a Dataset. 
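To make the false-positive check concrete, here is a tiny self-contained example that mirrors `_jaccard_similarity` and flags a cluster member as a false positive when its best similarity to the rest of the cluster stays under the threshold (a toy threshold of 0.7 is used here; the function above defaults to 0.85).

```python
import re

NON_ALPHA = re.compile("[^A-Za-z_0-9]")


def jaccard(s1: str, s2: str) -> float:
    # Mirrors _jaccard_similarity above.
    t1 = {t for t in NON_ALPHA.split(s1) if t.strip()}
    t2 = {t for t in NON_ALPHA.split(s2) if t.strip()}
    return len(t1 & t2) / max(1, len(t1 | t2))


cluster = ["x = compute(1, 2)", "x = compute(1, 2)  # cached", "y = other()"]
threshold = 0.7
for i, member in enumerate(cluster):
    best = max(jaccard(member, other) for j, other in enumerate(cluster) if i != j)
    label = "false positive" if best < threshold else "true duplicate"
    print(f"{member!r}: best similarity {best:.2f} -> {label}")
```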
It is only used when `report_false_positive_rate` is True.\n", 895 | " threshold: float = 0.85, # The threshold to use for calculating the false positive rate.\n", 896 | " column: str = \"content\", # The column to use for calculating the false positive rate.\n", 897 | " verbose: bool = False,\n", 898 | ") -> (\n", 899 | " Set\n", 900 | "): # The set of duplicate ids that should be removed, leaving only one id in each community.\n", 901 | " \"\"\"\n", 902 | " Find the duplicate communities from the queried dataset.\n", 903 | " \"\"\"\n", 904 | " SAMPLE_MIN_SIZE = 10\n", 905 | " SAMPLE_MAX_SIZE = 100\n", 906 | " SAMPLE_SIZE = 10\n", 907 | " g = nk.graph.Graph()\n", 908 | " for record in tqdm(records, desc=\"Constructing graph...\"):\n", 909 | " for y in record[\"__neighbors__\"]:\n", 910 | " g.addEdge(record[\"__id__\"], y, addMissing=True)\n", 911 | "\n", 912 | " to_remove: Set = set()\n", 913 | " samples: list[list[int]] = []\n", 914 | " if not community_detection:\n", 915 | " cc = nk.components.ConnectedComponents(g)\n", 916 | " cc.run()\n", 917 | " partition = cc.getPartition()\n", 918 | " components = list(cc.getComponents())\n", 919 | " random.shuffle(components)\n", 920 | " for component in tqdm(components, desc=\"Iterating over components...\"):\n", 921 | " component = sorted(component)\n", 922 | " to_remove.update(component[1:])\n", 923 | " if (\n", 924 | " len(samples) < SAMPLE_SIZE\n", 925 | " and SAMPLE_MAX_SIZE > len(component) >= SAMPLE_MIN_SIZE\n", 926 | " ):\n", 927 | " samples.append(component[:])\n", 928 | " else:\n", 929 | " algo = nk.community.PLM(g, refine=False)\n", 930 | " algo.run()\n", 931 | " partition = algo.getPartition()\n", 932 | " communities = list(partition.getSubsetIds())\n", 933 | " random.shuffle(communities)\n", 934 | " # This can be slow if there are many communities\n", 935 | " for i in tqdm(communities, desc=\"Iterating over communities...\"):\n", 936 | " ids = partition.getMembers(i)\n", 937 | " to_remove.update(sorted(ids)[1:])\n", 938 | " if (\n", 939 | " len(samples) < SAMPLE_SIZE\n", 940 | " and SAMPLE_MAX_SIZE > len(ids) >= SAMPLE_MIN_SIZE\n", 941 | " ):\n", 942 | " samples.append(ids)\n", 943 | "\n", 944 | " if report_false_positive_rate and verbose:\n", 945 | " _calculate_average_false_positive_rate(\n", 946 | " samples,\n", 947 | " reference_records,\n", 948 | " threshold,\n", 949 | " column,\n", 950 | " )\n", 951 | "\n", 952 | " return to_remove" 953 | ] 954 | }, 955 | { 956 | "cell_type": "code", 957 | "execution_count": null, 958 | "metadata": {}, 959 | "outputs": [], 960 | "source": [ 961 | "# | export\n", 962 | "def minhash_dedup(\n", 963 | " ds, # The dataset to deduplicate.\n", 964 | " column, # The column to use for deduplication.\n", 965 | " community_detection: bool = False, # Whether to use community detection to find the duplicate communities, or to use the connected components.\n", 966 | " report_false_positive_rate: bool = False, # Whether to report the false positive rate.\n", 967 | " threshold: float = 0.85, # The threshold to use for deduplication.\n", 968 | " num_perm: int = 128, # The number of permutations to use for minhashing.\n", 969 | " dry_run: bool = False, # Whether to run the deduplication in dry run mode.\n", 970 | ") -> Dataset:\n", 971 | " \"\"\"\n", 972 | " Deduplicate the dataset using minhashing as described in the paper \"Deduplicating Training Data Makes Language Models Better\".\n", 973 | " \"\"\"\n", 974 | " global lsh\n", 975 | " global dup_ids\n", 976 | "\n", 977 | " lsh = MinHashLSH(\n", 978 | " 
threshold=threshold,\n", 979 | " num_perm=num_perm,\n", 980 | " )\n", 981 | " column_names = ds.column_names\n", 982 | " ds = ds.map(\n", 983 | " lambda _, idx: {\"__id__\": idx},\n", 984 | " with_indices=True,\n", 985 | " num_proc=os.cpu_count(),\n", 986 | " desc=\"Adding index...\",\n", 987 | " )\n", 988 | " hashed_ds = ds.map(\n", 989 | " function=_hash_func,\n", 990 | " fn_kwargs={\"num_perm\": num_perm},\n", 991 | " input_columns=[\"__id__\", column],\n", 992 | " remove_columns=column_names,\n", 993 | " num_proc=os.cpu_count(),\n", 994 | " desc=f\"Fingerprinting...\",\n", 995 | " )\n", 996 | " with lsh.insertion_session() as session:\n", 997 | " for data in tqdm(hashed_ds, desc=\"Indexing signatures...\"):\n", 998 | " if data[\"__id__\"] in lsh:\n", 999 | " continue\n", 1000 | " session.insert(\n", 1001 | " data[\"__id__\"],\n", 1002 | " LeanMinHash(seed=MINHASH_SEED, hashvalues=data[\"__signature__\"]),\n", 1003 | " check_duplication=False,\n", 1004 | " )\n", 1005 | "\n", 1006 | " gc.disable()\n", 1007 | " gc.freeze()\n", 1008 | "\n", 1009 | " conf = {\n", 1010 | " \"threshold\": threshold,\n", 1011 | " \"community_detection\": community_detection,\n", 1012 | " \"report_false_positive_rate\": report_false_positive_rate,\n", 1013 | " \"num_perm\": num_perm,\n", 1014 | " \"name\": ds.builder_name,\n", 1015 | " \"column\": column,\n", 1016 | " }\n", 1017 | " queried = hashed_ds.map(\n", 1018 | " lambda x, y: _query_content(x, y, index=lsh),\n", 1019 | " num_proc=os.cpu_count(),\n", 1020 | " features=Features(\n", 1021 | " {\n", 1022 | " \"__id__\": Value(dtype=\"int64\", id=None),\n", 1023 | " \"__neighbors__\": Sequence(\n", 1024 | " feature=Value(dtype=\"int64\", id=None), length=-1, id=None\n", 1025 | " ),\n", 1026 | " }\n", 1027 | " ),\n", 1028 | " input_columns=[\"__id__\", \"__signature__\"],\n", 1029 | " remove_columns=[\"__signature__\"],\n", 1030 | " desc=f\"Querying...\",\n", 1031 | " )\n", 1032 | "\n", 1033 | " del lsh\n", 1034 | " gc.collect()\n", 1035 | "\n", 1036 | " queried = queried.filter(\n", 1037 | " lambda x: len(x[\"__neighbors__\"]) > 0,\n", 1038 | " num_proc=os.cpu_count(),\n", 1039 | " desc=\"Finding duplicates...\",\n", 1040 | " )\n", 1041 | " dup_ids = _find_duplicate_communities(\n", 1042 | " records=queried,\n", 1043 | " community_detection=conf[\"community_detection\"],\n", 1044 | " report_false_positive_rate=conf[\"report_false_positive_rate\"],\n", 1045 | " reference_records=ds,\n", 1046 | " threshold=conf[\"threshold\"],\n", 1047 | " column=conf[\"column\"],\n", 1048 | " )\n", 1049 | "\n", 1050 | " del queried\n", 1051 | " gc.collect()\n", 1052 | "\n", 1053 | " if dry_run:\n", 1054 | " final_data = ds.map(\n", 1055 | " lambda idx: {\"duplicate\": idx in dup_ids},\n", 1056 | " input_columns=[\"__id__\"],\n", 1057 | " num_proc=os.cpu_count(),\n", 1058 | " desc=\"Labeling duplicates...\",\n", 1059 | " )\n", 1060 | " else:\n", 1061 | " final_data = ds.filter(\n", 1062 | " lambda idx: idx not in dup_ids,\n", 1063 | " input_columns=[\"__id__\"],\n", 1064 | " num_proc=os.cpu_count(),\n", 1065 | " desc=\"Filtering duplicates...\",\n", 1066 | " )\n", 1067 | " return final_data" 1068 | ] 1069 | }, 1070 | { 1071 | "cell_type": "code", 1072 | "execution_count": null, 1073 | "metadata": {}, 1074 | "outputs": [ 1075 | { 1076 | "name": "stdout", 1077 | "output_type": "stream", 1078 | "text": [ 1079 | " " 1080 | ] 1081 | }, 1082 | { 1083 | "data": { 1084 | "application/vnd.jupyter.widget-view+json": { 1085 | "model_id": "8405a6a9d16447d89efd387db70f5c91", 1086 | 
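Putting the pieces together, a hedged usage sketch for `minhash_dedup` (assuming squeakily is installed and the function is exported from `squeakily.filter`, as the module index later in this repo indicates). The toy data below is purely illustrative: 100 distinct snippets plus 50 whitespace-only variants that the deduplication is expected to drop.

```python
from datasets import Dataset

from squeakily.filter import minhash_dedup

texts = [f"def f{i}(): return {i}" for i in range(100)]
texts += [f"def f{i}():  return  {i}" for i in range(0, 100, 2)]  # near-duplicates
ds = Dataset.from_dict({"text": texts})

deduped = minhash_dedup(ds, column="text", threshold=0.8, num_perm=128)
print(len(ds), "->", len(deduped))  # expected to drop the 50 near-duplicate variants
```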
"version_major": 2, 1087 | "version_minor": 0 1088 | }, 1089 | "text/plain": [ 1090 | "Indexing signatures...: 0%| | 0/1000 [00:00 This module contains all the various cleaning options supported." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# | default_exp clean" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# | export\n", 28 | "import re\n", 29 | "from faker import Faker\n", 30 | "import ftfy\n", 31 | "\n", 32 | "fake = Faker()" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# | hide\n", 42 | "from nbdev.showdoc import *" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# | export\n", 52 | "# From: https://github.com/bigscience-workshop/data-preparation/blob/main/preprocessing/training/01b_oscar_cleaning_and_filtering/filtering.py#L95\n", 53 | "whitespace = {\n", 54 | " \" \",\n", 55 | " \" \",\n", 56 | " \" \",\n", 57 | " \" \",\n", 58 | " \" \",\n", 59 | " \" \",\n", 60 | " \" \",\n", 61 | " \" \",\n", 62 | " \" \",\n", 63 | " \" \",\n", 64 | " \"\",\n", 65 | " \"„\",\n", 66 | "}\n", 67 | "\n", 68 | "\n", 69 | "def normalize_whitespace(\n", 70 | " text: str, # The text to normalize\n", 71 | ") -> str: # The normalized text\n", 72 | " \"\"\"\n", 73 | " Replace the various whitespace characters with the standard one.\n", 74 | " \"\"\"\n", 75 | " text = \"\".join([char if char not in whitespace else \" \" for char in text])\n", 76 | " return text" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# test the normalize_whitespace function\n", 86 | "assert normalize_whitespace(\"a b c d e f g h ij„k\") == \"a b c d e f g h i j k\"" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "# | export\n", 96 | "unicode_punctuation = {\n", 97 | " \",\": \",\",\n", 98 | " \"。\": \".\",\n", 99 | " \"、\": \",\",\n", 100 | " \"„\": '\"',\n", 101 | " \"”\": '\"',\n", 102 | " \"“\": '\"',\n", 103 | " \"«\": '\"',\n", 104 | " \"»\": '\"',\n", 105 | " \"1\": '\"',\n", 106 | " \"」\": '\"',\n", 107 | " \"「\": '\"',\n", 108 | " \"《\": '\"',\n", 109 | " \"》\": '\"',\n", 110 | " \"´\": \"'\",\n", 111 | " \"∶\": \":\",\n", 112 | " \":\": \":\",\n", 113 | " \"?\": \"?\",\n", 114 | " \"!\": \"!\",\n", 115 | " \"(\": \"(\",\n", 116 | " \")\": \")\",\n", 117 | " \";\": \";\",\n", 118 | " \"–\": \"-\",\n", 119 | " \"—\": \" - \",\n", 120 | " \".\": \". 
\",\n", 121 | " \"~\": \"~\",\n", 122 | " \"’\": \"'\",\n", 123 | " \"…\": \"...\",\n", 124 | " \"━\": \"-\",\n", 125 | " \"〈\": \"<\",\n", 126 | " \"〉\": \">\",\n", 127 | " \"【\": \"[\",\n", 128 | " \"】\": \"]\",\n", 129 | " \"%\": \"%\",\n", 130 | " \"►\": \"-\",\n", 131 | "}\n", 132 | "\n", 133 | "\n", 134 | "def normalize_punctuation(\n", 135 | " text: str, # The text to normalize\n", 136 | ") -> str: # The normalized text\n", 137 | " \"\"\"\n", 138 | " Replace the various unicode punctuation characters with the standard ones.\n", 139 | " \"\"\"\n", 140 | " text = \"\".join([unicode_punctuation.get(char, char) for char in text])\n", 141 | " return text" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "# test the normalize_punctuation function\n", 151 | "text = \",。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►\"\n", 152 | "\n", 153 | "assert normalize_punctuation(text) == ',.,\"\"\"\"\"\"\"\"\"\"\\'::?!();- - . ~\\'...-<>[]%-'" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "# | export\n", 163 | "def remove_empty_lines(\n", 164 | " text: str, # The text to remove empty lines from\n", 165 | ") -> str: # The text with empty lines removed\n", 166 | " \"\"\"\n", 167 | " Remove empty lines from the text.\n", 168 | " Solution from https://stackoverflow.com/a/3711884/5768407\n", 169 | " \"\"\"\n", 170 | " lines = text.splitlines()\n", 171 | " filtered = filter(lambda x: not re.match(r\"^\\s*$\", x), lines)\n", 172 | " return \"\\n\".join(filtered)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "# test the remove_empty_lines function\n", 182 | "starts_with_newline = \"\\nfoo\\nbar\"\n", 183 | "multiple_newlines = \"foo\\n\\nbar\"\n", 184 | "ends_with_newline = \"foo\\nbar\\n\"\n", 185 | "\n", 186 | "assert remove_empty_lines(starts_with_newline) == \"foo\\nbar\"\n", 187 | "assert remove_empty_lines(multiple_newlines) == \"foo\\nbar\"\n", 188 | "assert remove_empty_lines(ends_with_newline) == \"foo\\nbar\"" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "# | export\n", 198 | "def replace_urls(\n", 199 | " text: str, # The text to replace URLs in\n", 200 | " dummy: str = \"https://example.com/\", # The dummy text to replace URLs with\n", 201 | ") -> str: # The text with URLs replaced\n", 202 | " \"\"\"Replace urls from text with a dummy.\"\"\"\n", 203 | " return re.sub(r\"http\\S+\", dummy, text)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "# test the replace_urls function\n", 213 | "url_after_space = \"foo http://bar.com\"\n", 214 | "url_before_space = \"http://foo.com bar\"\n", 215 | "assert replace_urls(url_after_space) == \"foo https://example.com/\"\n", 216 | "assert replace_urls(url_before_space) == \"https://example.com/ bar\"" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "# | export\n", 226 | "def replace_dates(\n", 227 | " text: str, # The text to remove dates from\n", 228 | " dummy: str = fake.date(), # The dummy text to replace dates with\n", 229 | ") -> str: # The text with dates 
replaced\n", 230 | " \"\"\"Replace dates from text with a dummy.\"\"\"\n", 231 | " return re.sub(r\"\\d{1,2}/\\d{1,2}/\\d{4}\", dummy, text)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "# test the replace_dates function\n", 241 | "date_after_space = \"foo 1/1/2020\"\n", 242 | "date_before_space = \"1/1/2020 bar\"\n", 243 | "assert replace_dates(date_after_space, \"1/1/1970\") == \"foo 1/1/1970\"\n", 244 | "assert replace_dates(date_before_space, \"1/1/1970\") == \"1/1/1970 bar\"" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": {}, 250 | "source": [ 251 | "## PII Removal\n", 252 | "\n", 253 | "Currently, we support the following PII removal options:\n", 254 | "\n", 255 | " * `replace_email`\n", 256 | " * `replace_phone`\n", 257 | " * `replace_ip`\n", 258 | " * `replace_credit_card`\n", 259 | " * `replace_ssn`\n", 260 | "\n", 261 | "However, for emails, phone numbers, credit cards, and SSNs, we recommend you to use the [scrubadub](https://scrubadub.readthedocs.io/en/stable/index.html) library." 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "# | export\n", 271 | "def replace_email(\n", 272 | " text: str, # The text to replace email addresses in\n", 273 | " dummy: str = fake.email(), # The dummy text to replace email addresses with\n", 274 | ") -> str: # The text with email addresses replaced\n", 275 | " \"\"\"Replace email addresses from text with a dummy.\"\"\"\n", 276 | " return re.sub(r\"[\\w\\.-]+@[\\w\\.-]+\", dummy, text)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "# test the replace_email function\n", 286 | "email_after_space = \"foo fake@email.com\"\n", 287 | "email_before_space = \"fake@email.com bar\"\n", 288 | "email_with_forward_periods = \"foo.bar@email.com\"\n", 289 | "email_with_backward_periods = \"foo@bar.email.com\"\n", 290 | "\n", 291 | "assert replace_email(email_after_space, \"example@email.com\") == \"foo example@email.com\"\n", 292 | "assert replace_email(email_before_space, \"example@email.com\") == \"example@email.com bar\"\n", 293 | "assert (\n", 294 | " replace_email(email_with_forward_periods, \"example@email.com\")\n", 295 | " == \"example@email.com\"\n", 296 | ")\n", 297 | "assert (\n", 298 | " replace_email(email_with_backward_periods, \"example@email.com\")\n", 299 | " == \"example@email.com\"\n", 300 | ")" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [ 309 | "# | export\n", 310 | "def replace_phone(\n", 311 | " text: str, # The text to replace phone numbers in\n", 312 | " dummy: str = fake.phone_number(), # The dummy text to replace phone numbers with\n", 313 | ") -> str: # The text with phone numbers replaced\n", 314 | " \"\"\"Replace phone numbers from text with a dummy.\"\"\"\n", 315 | " return re.sub(r\"\\(?\\d{3}\\)?-? *\\d{3}-? 
*-?\\d{4}\", dummy, text)" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "# test the replace_phone function\n", 325 | "phone_after_space = \"foo 111-222-3333\"\n", 326 | "phone_before_space = \"111-222-3333 bar\"\n", 327 | "phone_with_parens = \"(111) 222-3333\"\n", 328 | "phone_with_spaces = \"111 222 3333\"\n", 329 | "phone_with_dashes = \"111-222-3333\"\n", 330 | "\n", 331 | "assert replace_phone(phone_after_space, \"123-456-7890\") == \"foo 123-456-7890\"\n", 332 | "assert replace_phone(phone_before_space, \"123-456-7890\") == \"123-456-7890 bar\"\n", 333 | "assert replace_phone(phone_with_parens, \"123-456-7890\") == \"123-456-7890\"\n", 334 | "assert replace_phone(phone_with_spaces, \"123-456-7890\") == \"123-456-7890\"\n", 335 | "assert replace_phone(phone_with_dashes, \"123-456-7890\") == \"123-456-7890\"" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "# | export\n", 345 | "def replace_ip(\n", 346 | " text, # The text to replace ip addresses in\n", 347 | " dummy1: str = fake.ipv4(), # The dummy text to replace ipv4 addresses with\n", 348 | " dummy2: str = fake.ipv6(), # The dummy text to replace ipv6 addresses with\n", 349 | ") -> str: # The text with ip addresses replaced\n", 350 | " \"\"\"\n", 351 | " Replace ip addresses from text with a dummy.\n", 352 | " Solution from https://github.com/bigcode-project/bigcode-analysis/blob/main/data_analysis/pii/utils/emails_ip_addresses_detection.py#L48\n", 353 | " \"\"\"\n", 354 | " ipv4_pattern = r\"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(?:\\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3}\"\n", 355 | " text = re.sub(ipv4_pattern, dummy1, text)\n", 356 | " ipv6_pattern = r\"(?:[0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:){1,7}:|(?:[0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:){1,5}(?::[0-9a-fA-F]{1,4}){1,2}|(?:[0-9a-fA-F]{1,4}:){1,4}(?::[0-9a-fA-F]{1,4}){1,3}|(?:[0-9a-fA-F]{1,4}:){1,3}(?::[0-9a-fA-F]{1,4}){1,4}|(?:[0-9a-fA-F]{1,4}:){1,2}(?::[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:(?:(?::[0-9a-fA-F]{1,4}){1,6})|:(?:(?::[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(?::[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(?:ffff(?::0{1,4}){0,1}:){0,1}(?:(?:25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])\\.){3,3}(?:25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])|(?:[0-9a-fA-F]{1,4}:){1,4}:(?:(?:25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])\\.){3,3}(25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])\"\n", 357 | " text = re.sub(ipv6_pattern, dummy2, text)\n", 358 | " return text" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "# test the replace_ip function\n", 368 | "ip4_after_space = \"foo 111.222.3.4\"\n", 369 | "ip4_before_space = \"111.222.3.4 bar\"\n", 370 | "ip6_with_colons = \"2001:0db8:0000:0000:0000:8a2e:0370:7334\"\n", 371 | "\n", 372 | "assert replace_ip(ip4_after_space, \"127.0.0.1\") == \"foo 127.0.0.1\"\n", 373 | "assert replace_ip(ip4_before_space, \"127.0.0.1\") == \"127.0.0.1 bar\"\n", 374 | "assert replace_ip(ip6_with_colons, \"127.0.0.1\", \"0:0:0:0:0:0:0:1\") == \"0:0:0:0:0:0:0:1\"" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "# | export\n", 384 | "def replace_credit_card(\n", 385 | " text: str, # The text to replace credit card 
numbers in\n", 386 | " dummy: str = fake.credit_card_number(), # The dummy text to replace credit card numbers with\n", 387 | ") -> str: # The text with credit card numbers replaced\n", 388 | " \"\"\"Replace credit card numbers from text with a dummy.\"\"\"\n", 389 | " return re.sub(r\"\\d{4}-\\d{4}-\\d{4}-\\d{4}\", dummy, text)" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": null, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "# test the replace_credit_card function\n", 399 | "credit_card_after_space = \"foo 1111-2222-3333-4444\"\n", 400 | "credit_card_before_space = \"1111-2222-3333-4444 bar\"\n", 401 | "\n", 402 | "assert (\n", 403 | " replace_credit_card(credit_card_after_space, \"1234-5678-9012-3456\")\n", 404 | " == \"foo 1234-5678-9012-3456\"\n", 405 | ")\n", 406 | "assert (\n", 407 | " replace_credit_card(credit_card_before_space, \"1234-5678-9012-3456\")\n", 408 | " == \"1234-5678-9012-3456 bar\"\n", 409 | ")" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": null, 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "# | export\n", 419 | "def replace_ssn(\n", 420 | " text: str, # The text to replace social security numbers in\n", 421 | " dummy: str = fake.ssn(), # The dummy text to replace social security numbers with\n", 422 | ") -> str: # The text with social security numbers replaced\n", 423 | " \"\"\"Replace social security numbers from text with a dummy.\"\"\"\n", 424 | " return re.sub(r\"\\d{3}-\\d{2}-\\d{4}\", dummy, text)" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": null, 430 | "metadata": {}, 431 | "outputs": [], 432 | "source": [ 433 | "# test the replace_ssn function\n", 434 | "ssn_after_space = \"foo 111-22-3333\"\n", 435 | "ssn_before_space = \"111-22-3333 bar\"\n", 436 | "\n", 437 | "assert replace_ssn(ssn_after_space, \"123-45-6789\") == \"foo 123-45-6789\"\n", 438 | "assert replace_ssn(ssn_before_space, \"123-45-6789\") == \"123-45-6789 bar\"" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": null, 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [ 447 | "# | export\n", 448 | "def fix_utf8_encoding(\n", 449 | " text: str, # The text to fix\n", 450 | ") -> str: # The fixed text\n", 451 | " \"\"\"Fix utf8 text using ftfy.\"\"\"\n", 452 | " return ftfy.fix_text(text)" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": null, 458 | "metadata": {}, 459 | "outputs": [], 460 | "source": [ 461 | "# test the fix_utf8_encoding function\n", 462 | "bad_text = \"✔ No problems\"\n", 463 | "assert fix_utf8_encoding(bad_text) == \"✔ No problems\"\n", 464 | "bad_text = \"déjà vu\"\n", 465 | "assert fix_utf8_encoding(bad_text) == \"déjà vu\"\n", 466 | "bad_text = \"é\"\n", 467 | "assert fix_utf8_encoding(bad_text) == \"é\"\n", 468 | "bad_text = \"P&EACUTE;REZ\"\n", 469 | "assert fix_utf8_encoding(bad_text) == \"PÉREZ\"" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "metadata": {}, 476 | "outputs": [], 477 | "source": [ 478 | "# | export\n", 479 | "def clean_code_license(\n", 480 | " code: str, # The code to clean\n", 481 | " language: str = \"python\", # The language of the code\n", 482 | " min_lines: int = 3, # The minimum number of lines that need to be removed\n", 483 | "):\n", 484 | " import code_ast\n", 485 | " from code_ast import ASTVisitor\n", 486 | " from code_ast.ast import LEAVE_WHITELIST\n", 487 | "\n", 488 | " class 
FirstNonCommentVisitor(ASTVisitor):\n", 489 | " def __init__(self):\n", 490 | " self.passed_global_node = False\n", 491 | " self.first_node = None\n", 492 | "\n", 493 | " def visit(self, node):\n", 494 | " if not self.passed_global_node:\n", 495 | " self.passed_global_node = True\n", 496 | " return\n", 497 | " if self.first_node is None:\n", 498 | " if node.child_count > 0 or node.type in LEAVE_WHITELIST:\n", 499 | " self.first_node = node\n", 500 | "\n", 501 | " \"\"\"Remove the license or other boilerplate comments from the code.\"\"\"\n", 502 | " ast = code_ast.ast(code, lang=language)\n", 503 | " visitor = FirstNonCommentVisitor()\n", 504 | " ast.visit(visitor)\n", 505 | " start_line = visitor.first_node.start_point[0]\n", 506 | " if start_line < min_lines:\n", 507 | " return code\n", 508 | " else:\n", 509 | " return \"\\n\".join(code.splitlines()[start_line:])" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": null, 515 | "metadata": {}, 516 | "outputs": [], 517 | "source": [ 518 | "# |eval: false\n", 519 | "# Test the cleaning of code licenses or similar boilerplate comments from code\n", 520 | "code_python = \"\"\"# -*- coding: utf-8 -*-\n", 521 | "\n", 522 | "# Copyright 2018 Spanish National Research Council (CSIC)\n", 523 | "#\n", 524 | "# Licensed under the Apache License, Version 2.0 (the \"License\"); you may\n", 525 | "# not use this file except in compliance with the License. You may obtain\n", 526 | "# a copy of the License at\n", 527 | "#\n", 528 | "# http://www.apache.org/licenses/LICENSE-2.0\n", 529 | "#\n", 530 | "# Unless required by applicable law or agreed to in writing, software\n", 531 | "# distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT\n", 532 | "# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the\n", 533 | "# License for the specific language governing permissions and limitations\n", 534 | "# under the License.\n", 535 | "\n", 536 | "\\\"\\\"\\\"\n", 537 | "Given two dates and region, download N Sentinel Collections scenes from ESA\n", 538 | "Sentinel dataHUB.\n", 539 | "The downloaded Sentinel collection scenes are compatible with:\n", 540 | "S2MSI1C: Top-of-atmosphere reflectances in cartographic geometry\n", 541 | "or S2MSI2A: Bottom-of-atmosphere reflectance in cartographic geometry\n", 542 | "Parameters\n", 543 | "----------\n", 544 | "inidate: datetime.strptime(\"YYYY-MM-dd\", \"%Y-%m-%d\")\n", 545 | "enddate: datetime.strptime(\"YYYY-MM-dd\", \"%Y-%m-%d\")\n", 546 | "region: name of one reservoir saved in the \"coord_reservoirs.json\" file\n", 547 | "coordinates : dict. Coordinates of the region to search.\n", 548 | "Example: {\"W\": -2.830, \"S\": 41.820, \"E\": -2.690, \"N\": 41.910}}\n", 549 | "platform : str. Satellite to use from the Sentinel family\n", 550 | "producttype : str. 
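For reference, a brief sketch of `clean_code_license` in use, assuming the optional `code-ast` dependency from the dev requirements is installed; the snippet below is made up for illustration. The header is stripped only when the first real code node starts at or beyond `min_lines`, otherwise the input is returned unchanged.

```python
# clean_code_license is defined in the cell above.
licensed = """# Copyright 2023 Some Org
# Licensed under the Apache License, Version 2.0 (the "License")
# See http://www.apache.org/licenses/LICENSE-2.0 for details

import os
print(os.getcwd())
"""

print(clean_code_license(licensed, language="python", min_lines=3))
# -> only the code after the header: "import os\nprint(os.getcwd())"
```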
Dataset type.\n", 551 | "cloud: int\n", 552 | "path : path\n", 553 | "Author: Daniel García Díaz\n", 554 | "Email: garciad@ifca.unican.es\n", 555 | "Institute of Physics of Cantabria (IFCA)\n", 556 | "Advanced Computing and e-Science\n", 557 | "Date: Sep 2018\n", 558 | "\\\"\\\"\\\"\n", 559 | "#imports apis\n", 560 | "import requests\n", 561 | "import os\n", 562 | "\n", 563 | "# Subfunctions\n", 564 | "from wq_sat.utils import config\n", 565 | "\"\"\"\n", 566 | "\n", 567 | "code_go = \"\"\"// +build go1.9\n", 568 | "\n", 569 | "// Copyright 2019 Microsoft Corporation\n", 570 | "//\n", 571 | "// Licensed under the Apache License, Version 2.0 (the \"License\");\n", 572 | "// you may not use this file except in compliance with the License.\n", 573 | "// You may obtain a copy of the License at\n", 574 | "//\n", 575 | "// http://www.apache.org/licenses/LICENSE-2.0\n", 576 | "//\n", 577 | "// Unless required by applicable law or agreed to in writing, software\n", 578 | "// distributed under the License is distributed on an \"AS IS\" BASIS,\n", 579 | "// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 580 | "// See the License for the specific language governing permissions and\n", 581 | "// limitations under the License.\n", 582 | "\n", 583 | "// This code was auto-generated by:\n", 584 | "// github.com/Azure/azure-sdk-for-go/tools/profileBuilder\n", 585 | "\n", 586 | "package policyinsights\n", 587 | "\n", 588 | "import (\n", 589 | "\t\"context\"\n", 590 | "\n", 591 | "\toriginal \"github.com/Azure/azure-sdk-for-go/services/policyinsights/mgmt/2019-10-01/policyinsights\"\n", 592 | ")\n", 593 | "\"\"\"\n", 594 | "\n", 595 | "code_c = \"\"\"/*\n", 596 | " * copyright (c) 2008 - 2011 Espressif System\n", 597 | " *\n", 598 | " * Define user specified Event signals and Task priorities here\n", 599 | " *\n", 600 | " */\n", 601 | "\n", 602 | "#ifndef _ETS_SYS_H\n", 603 | "#define _ETS_SYS_H\n", 604 | "\n", 605 | "#include \"c_types.h\"\n", 606 | "#include \"eagle_soc.h\"\n", 607 | "\n", 608 | "typedef uint32_t ETSSignal;\n", 609 | "\"\"\"\n", 610 | "\n", 611 | "code_cpp = \"\"\"/* Pokemon Automation Bot Base - Client Example\n", 612 | "\n", 613 | " * \n", 614 | "\n", 615 | " * From: https://github.com/PokemonAutomation/Arduino-Source\n", 616 | "\n", 617 | " * \n", 618 | "\n", 619 | " */\n", 620 | "\n", 621 | "\n", 622 | "\n", 623 | "#include \"Common/CRC32.h\"\n", 624 | "\n", 625 | "#include \"Common/Microcontroller/MessageProtocol.h\"\n", 626 | "\n", 627 | "#include \"ClientSource/Libraries/Logging.h\"\n", 628 | "\n", 629 | "#include \"ClientSource/Libraries/MessageConverter.h\"\n", 630 | "\n", 631 | "#include \"BotBaseMessage.h\"\n", 632 | "\n", 633 | "#include \"PABotBaseConnection.h\"\n", 634 | "\n", 635 | "\n", 636 | "\n", 637 | "#include \n", 638 | "\n", 639 | "using std::cout;\n", 640 | "\"\"\"\n", 641 | "\n", 642 | "code_java = \"\"\"/*\n", 643 | " * Copyright (C) 2012-2021 DuyHai DOAN\n", 644 | " *\n", 645 | " * Licensed under the Apache License, Version 2.0 (the \"License\");\n", 646 | " * you may not use this file except in compliance with the License.\n", 647 | " * You may obtain a copy of the License at\n", 648 | " *\n", 649 | " * http://www.apache.org/licenses/LICENSE-2.0\n", 650 | " *\n", 651 | " * Unless required by applicable law or agreed to in writing, software\n", 652 | " * distributed under the License is distributed on an \"AS IS\" BASIS,\n", 653 | " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 654 | " * See the 
License for the specific language governing permissions and\n", 655 | " * limitations under the License.\n", 656 | " */\n", 657 | "\n", 658 | "package info.archinnov.achilles.internals.sample_classes.parser.entity;\n", 659 | "\n", 660 | "import info.archinnov.achilles.annotations.Column;\n", 661 | "\"\"\"\n", 662 | "\n", 663 | "code_javascript = \"\"\"/*\n", 664 | "** Copyright (c) 2016-2019, Thomas Farr\n", 665 | "**\n", 666 | "** This Source Code Form is subject to the terms of the Mozilla Public\n", 667 | "** License, v. 2.0. If a copy of the MPL was not distributed with this\n", 668 | "** file, You can obtain one at https://mozilla.org/MPL/2.0/.\n", 669 | "*/\n", 670 | "\n", 671 | "// TODO: Implement testing of option handling, and filename arrays\n", 672 | "\n", 673 | "const anitomy = require('../anitomy');\n", 674 | "const async = require('async');\n", 675 | "\"\"\"\n", 676 | "\n", 677 | "cleaned_code_python = clean_code_license(code_python, language=\"python\")\n", 678 | "cleaned_code_go = clean_code_license(code_go, language=\"go\")\n", 679 | "cleaned_code_c = clean_code_license(code_c, language=\"c\")\n", 680 | "cleaned_code_cpp = clean_code_license(code_cpp, language=\"cpp\")\n", 681 | "cleaned_code_java = clean_code_license(code_java, language=\"java\")\n", 682 | "cleaned_code_javascript = clean_code_license(code_javascript, language=\"javascript\")\n", 683 | "\n", 684 | "assert (\n", 685 | " cleaned_code_python\n", 686 | " == \"\"\"\\\"\\\"\\\"\n", 687 | "Given two dates and region, download N Sentinel Collections scenes from ESA\n", 688 | "Sentinel dataHUB.\n", 689 | "The downloaded Sentinel collection scenes are compatible with:\n", 690 | "S2MSI1C: Top-of-atmosphere reflectances in cartographic geometry\n", 691 | "or S2MSI2A: Bottom-of-atmosphere reflectance in cartographic geometry\n", 692 | "Parameters\n", 693 | "----------\n", 694 | "inidate: datetime.strptime(\"YYYY-MM-dd\", \"%Y-%m-%d\")\n", 695 | "enddate: datetime.strptime(\"YYYY-MM-dd\", \"%Y-%m-%d\")\n", 696 | "region: name of one reservoir saved in the \"coord_reservoirs.json\" file\n", 697 | "coordinates : dict. Coordinates of the region to search.\n", 698 | "Example: {\"W\": -2.830, \"S\": 41.820, \"E\": -2.690, \"N\": 41.910}}\n", 699 | "platform : str. Satellite to use from the Sentinel family\n", 700 | "producttype : str. 
Dataset type.\n", 701 | "cloud: int\n", 702 | "path : path\n", 703 | "Author: Daniel García Díaz\n", 704 | "Email: garciad@ifca.unican.es\n", 705 | "Institute of Physics of Cantabria (IFCA)\n", 706 | "Advanced Computing and e-Science\n", 707 | "Date: Sep 2018\n", 708 | "\\\"\\\"\\\"\n", 709 | "#imports apis\n", 710 | "import requests\n", 711 | "import os\n", 712 | "\n", 713 | "# Subfunctions\n", 714 | "from wq_sat.utils import config\"\"\"\n", 715 | ")\n", 716 | "assert (\n", 717 | " cleaned_code_go\n", 718 | " == \"\"\"package policyinsights\n", 719 | "\n", 720 | "import (\n", 721 | "\t\"context\"\n", 722 | "\n", 723 | "\toriginal \"github.com/Azure/azure-sdk-for-go/services/policyinsights/mgmt/2019-10-01/policyinsights\"\n", 724 | ")\"\"\"\n", 725 | ")\n", 726 | "assert (\n", 727 | " cleaned_code_c\n", 728 | " == \"\"\"#ifndef _ETS_SYS_H\n", 729 | "#define _ETS_SYS_H\n", 730 | "\n", 731 | "#include \"c_types.h\"\n", 732 | "#include \"eagle_soc.h\"\n", 733 | "\n", 734 | "typedef uint32_t ETSSignal;\"\"\"\n", 735 | ")\n", 736 | "assert (\n", 737 | " cleaned_code_cpp\n", 738 | " == \"\"\"#include \"Common/CRC32.h\"\n", 739 | "\n", 740 | "#include \"Common/Microcontroller/MessageProtocol.h\"\n", 741 | "\n", 742 | "#include \"ClientSource/Libraries/Logging.h\"\n", 743 | "\n", 744 | "#include \"ClientSource/Libraries/MessageConverter.h\"\n", 745 | "\n", 746 | "#include \"BotBaseMessage.h\"\n", 747 | "\n", 748 | "#include \"PABotBaseConnection.h\"\n", 749 | "\n", 750 | "\n", 751 | "\n", 752 | "#include \n", 753 | "\n", 754 | "using std::cout;\"\"\"\n", 755 | ")\n", 756 | "assert (\n", 757 | " cleaned_code_java\n", 758 | " == \"\"\"package info.archinnov.achilles.internals.sample_classes.parser.entity;\n", 759 | "\n", 760 | "import info.archinnov.achilles.annotations.Column;\"\"\"\n", 761 | ")\n", 762 | "assert (\n", 763 | " cleaned_code_javascript\n", 764 | " == \"\"\"const anitomy = require('../anitomy');\n", 765 | "const async = require('async');\"\"\"\n", 766 | ")" 767 | ] 768 | }, 769 | { 770 | "cell_type": "code", 771 | "execution_count": null, 772 | "metadata": {}, 773 | "outputs": [], 774 | "source": [ 775 | "# | hide\n", 776 | "import nbdev\n", 777 | "\n", 778 | "nbdev.nbdev_export()" 779 | ] 780 | } 781 | ], 782 | "metadata": { 783 | "kernelspec": { 784 | "display_name": "python3", 785 | "language": "python", 786 | "name": "python3" 787 | } 788 | }, 789 | "nbformat": 4, 790 | "nbformat_minor": 4 791 | } 792 | -------------------------------------------------------------------------------- /nbs/04_tutorials.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/fsx/nathan/miniconda3/envs/squeakily/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
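Before moving on to the tutorial, a short hedged sketch of chaining several of the cleaners defined in 02_clean.ipynb by hand (assuming they are importable from `squeakily.clean`, which the exported clean.py later in this repo confirms); in practice the `Pipeline` introduced below does this wiring for you.

```python
from squeakily.clean import (
    fix_utf8_encoding,
    normalize_punctuation,
    normalize_whitespace,
    remove_empty_lines,
    replace_email,
)

text = "Contact me at foo.bar@email.com for details…\n\nSee you on 1/1/2020。"
for cleaner in (
    fix_utf8_encoding,
    normalize_whitespace,
    normalize_punctuation,
    remove_empty_lines,
    replace_email,
):
    text = cleaner(text)
print(text)  # email replaced with a random dummy, punctuation and blank lines normalized
```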
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "# | hide\n", 19 | "import logging\n", 20 | "\n", 21 | "from squeakily.core import *\n", 22 | "\n", 23 | "# Turn off logging for datasets\n", 24 | "logging.getLogger(\"datasets\").setLevel(logging.ERROR)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "# Tutorial: Using another library\n", 32 | "\n", 33 | "> This tutorial shows how to use another library in a notebook. We will use the [scrubadub](https://scrubadub.readthedocs.io/en/stable/index.html) library to remove personal information from text." 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "First off, we need to install the library.\n", 41 | "\n", 42 | "```bash\n", 43 | "pip install scrubadub\n", 44 | "```\n", 45 | "\n", 46 | "Now we will use the same (wikitext) dataset as in the previous tutorial." 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "from datasets import load_dataset\n", 56 | "\n", 57 | "ds = load_dataset(\"wikitext\", \"wikitext-103-v1\", split=\"train[:1%]\")" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "We will use the `scrubadub` library to remove personal information from the text. `scrubadub` usually defaults to removing the following types:\n", 65 | "* [credential](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_detectors.html#scrubadub-detectors-credentialdetector) - username and password combinations\n", 66 | "* [credit_card](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_detectors.html#scrubadub-detectors-creditcarddetector) - credit card numbers\n", 67 | "* [drivers_license](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_detectors.html#scrubadub-detectors-driverslicencedetector) - drivers license numbers\n", 68 | "* [email](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_detectors.html#scrubadub-detectors-emaildetector) - email addresses\n", 69 | "* [national_insurance_number](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_detectors.html#scrubadub-detectors-en-gb-nationalinsurancenumberdetector) - GB National Insurance numbers (NINOs)\n", 70 | "* [phone](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_detectors.html#scrubadub-detectors-phonedetector) - phone numbers\n", 71 | "* [postalcode](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_detectors.html#scrubadub-detectors-postalcodedetector) - british postal codes\n", 72 | "* [social_security_number](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_detectors.html#scrubadub-detectors-en-us-socialsecuritynumberdetector) - US Social Security numbers (SSNs)\n", 73 | "* [tax_reference_number](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_detectors.html#scrubadub-detectors-en-gb-taxreferencenumberdetector) - UK PAYE temporary reference number (TRN)\n", 74 | "* [twitter](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_detectors.html#scrubadub-detectors-twitterdetector) - twitter handles\n", 75 | "* [url](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_detectors.html#scrubadub-detectors-urldetector) - URLs\n", 76 | "* [vehicle_license_plate](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_detectors.html#scrubadub-detectors-vehiclelicenceplatedetector) 
- british vehicle license plates\n", 77 | "\n", 78 | "However, while experimenting with the library it seems some of these are not on by default. Either way, we are only going to focus on the `credit_card`, `drivers_license`, `email`, `phone`, and `social_security_number` detectors. Therefore, we must turn the others off:" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "from scrubadub import Scrubber\n", 88 | "from scrubadub.detectors import CredentialDetector, TwitterDetector, UrlDetector\n", 89 | "\n", 90 | "scrubber = Scrubber()\n", 91 | "scrubber.remove_detector(CredentialDetector)\n", 92 | "scrubber.remove_detector(TwitterDetector)\n", 93 | "scrubber.remove_detector(UrlDetector)\n", 94 | "\n", 95 | "datasources = [\n", 96 | " {\n", 97 | " \"dataset\": ds,\n", 98 | " \"name\": \"wikitext\",\n", 99 | " \"columns\": [\"text\"],\n", 100 | " \"filters\": [],\n", 101 | " \"cleaners\": [scrubber.clean],\n", 102 | " },\n", 103 | " # ...\n", 104 | "]" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "Essentially, any function that takes in a string and returns a string will work out of the box with `squeakily`. Luckily for us, `scrubadub` has a `clean` function that does just that. We can use this function to remove personal information from the text!\n", 112 | "\n", 113 | "A similar process can be used for filters, except the return type is a `bool` instead of a `str` denoting whether or not the text should be kept.\n", 114 | "\n", 115 | ":::{.callout-note}\n", 116 | "Note: If you want to mix and match, it is super easy!\n", 117 | "\n", 118 | "```python\n", 119 | "from squeakily.clean import remove_empty_lines, remove_ip\n", 120 | "datasources = [\n", 121 | " {\n", 122 | " \"dataset\": ds,\n", 123 | " \"name\": \"wikitext\",\n", 124 | " \"columns\": [\"text\"],\n", 125 | " \"filters\": [],\n", 126 | " \"cleaners\": [scrubber.clean, remove_empty_lines, remove_ip],\n", 127 | " },\n", 128 | " # ...\n", 129 | "]\n", 130 | "```\n", 131 | ":::\n", 132 | "\n", 133 | "Now we can process the `datasources` as before with a `Pipeline` object." 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "data": { 143 | "text/html": [ 144 | "
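To round out the mix-and-match point above, here is a hedged sketch of adding a custom filter (any callable returning a `bool`, per the note above) alongside the scrubadub cleaner and then running the pipeline. The `Pipeline(datasources)` / `pipeline.run()` calls are an assumption based on `squeakily/_modidx.py`, and `long_enough` is a made-up example filter.

```python
from datasets import load_dataset
from scrubadub import Scrubber

from squeakily.core import Pipeline

ds = load_dataset("wikitext", "wikitext-103-v1", split="train[:1%]")
scrubber = Scrubber()


def long_enough(text: str, min_words: int = 5) -> bool:
    # Keep only rows with at least `min_words` whitespace-separated words.
    return len(text.split()) >= min_words


datasources = [
    {
        "dataset": ds,
        "name": "wikitext",
        "columns": ["text"],
        "filters": [long_enough],
        "cleaners": [scrubber.clean],
    },
]
pipeline = Pipeline(datasources)
pipeline.run()
```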
[11/16/22 04:50:08] INFO     Running datasource: wikitext                                                core.py:41\n",
145 |        "
\n" 146 | ], 147 | "text/plain": [ 148 | "\u001b[2;36m[11/16/22 04:50:08]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running datasource: wikitext \u001b]8;id=538643;file:///fsx/home-nathan/work/squeakily/squeakily/core.py\u001b\\\u001b[2mcore.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=230902;file:///fsx/home-nathan/work/squeakily/squeakily/core.py#41\u001b\\\u001b[2m41\u001b[0m\u001b]8;;\u001b\\\n" 149 | ] 150 | }, 151 | "metadata": {}, 152 | "output_type": "display_data" 153 | }, 154 | { 155 | "data": { 156 | "text/html": [ 157 | "
                    INFO     Running cleaner: clean on text                                              core.py:57\n",
158 |        "
\n" 159 | ], 160 | "text/plain": [ 161 | "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running cleaner: clean on text \u001b]8;id=441718;file:///fsx/home-nathan/work/squeakily/squeakily/core.py\u001b\\\u001b[2mcore.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=808891;file:///fsx/home-nathan/work/squeakily/squeakily/core.py#57\u001b\\\u001b[2m57\u001b[0m\u001b]8;;\u001b\\\n" 162 | ] 163 | }, 164 | "metadata": {}, 165 | "output_type": "display_data" 166 | }, 167 | { 168 | "name": "stderr", 169 | "output_type": "stream", 170 | "text": [ 171 | "#0: 0%| | 0/251 [00:00 .sourceCode { 6 | margin-bottom: 0; 7 | } 8 | 9 | .cell-output > pre { 10 | margin-bottom: 0; 11 | } 12 | 13 | .cell-output > pre, .cell-output > .sourceCode > pre, .cell-output-stdout > pre { 14 | margin-left: 0.8rem; 15 | margin-top: 0; 16 | background: none; 17 | border-left: 2px solid lightsalmon; 18 | border-top-left-radius: 0; 19 | border-top-right-radius: 0; 20 | } 21 | 22 | .cell-output > .sourceCode { 23 | border: none; 24 | } 25 | 26 | .cell-output > .sourceCode { 27 | background: none; 28 | margin-top: 0; 29 | } 30 | 31 | div.description { 32 | padding-left: 2px; 33 | padding-top: 5px; 34 | font-style: italic; 35 | font-size: 135%; 36 | opacity: 70%; 37 | } 38 | -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | repo = squeakily 3 | lib_name = squeakily 4 | version = 0.0.3 5 | min_python = 3.7 6 | license = apache2 7 | doc_path = _docs 8 | lib_path = squeakily 9 | nbs_path = nbs 10 | recursive = True 11 | tst_flags = notest 12 | put_version_in_init = True 13 | branch = main 14 | custom_sidebar = False 15 | doc_host = https://CarperAI.github.io 16 | doc_baseurl = /squeakily 17 | git_url = https://github.com/CarperAI/squeakily 18 | title = squeakily 19 | audience = Developers 20 | author = ncoop57 21 | author_email = nathan.cooper@stability.ai 22 | copyright = 2022 onwards, ncoop57 23 | description = A library for squeakily cleaning and filtering language datasets. 
24 | keywords = nbdev jupyter notebook python 25 | language = English 26 | status = 3 27 | user = CarperAI 28 | requirements = datasketch==1.5.8 datasets==2.7.1 Faker==15.3.3 fastcore huggingface-hub networkit pydantic rich ftfy scikit-learn 29 | dev_requirements = BeautifulSoup4 fasttext nbdev scrubadub twine sentencepiece code-tokenize langchain==0.0.212 openai code-ast 30 | black_formatting = False 31 | readme_nb = index.ipynb 32 | allowed_metadata_keys = 33 | allowed_cell_metadata_keys = 34 | jupyter_hooks = True 35 | clean_ids = True 36 | clear_all = False 37 | 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import parse_version 2 | from configparser import ConfigParser 3 | import setuptools 4 | assert parse_version(setuptools.__version__)>=parse_version('36.2') 5 | 6 | # note: all settings are in settings.ini; edit there, not here 7 | config = ConfigParser(delimiters=['=']) 8 | config.read('settings.ini') 9 | cfg = config['DEFAULT'] 10 | 11 | cfg_keys = 'version description keywords author author_email'.split() 12 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split() 13 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o) 14 | setup_cfg = {o:cfg[o] for o in cfg_keys} 15 | 16 | licenses = { 17 | 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'), 18 | 'mit': ('MIT License', 'OSI Approved :: MIT License'), 19 | 'gpl2': ('GNU General Public License v2', 'OSI Approved :: GNU General Public License v2 (GPLv2)'), 20 | 'gpl3': ('GNU General Public License v3', 'OSI Approved :: GNU General Public License v3 (GPLv3)'), 21 | 'bsd3': ('BSD License', 'OSI Approved :: BSD License'), 22 | } 23 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha', 24 | '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ] 25 | py_versions = '3.6 3.7 3.8 3.9 3.10'.split() 26 | 27 | requirements = cfg.get('requirements','').split() 28 | if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split() 29 | min_python = cfg['min_python'] 30 | lic = licenses.get(cfg['license'].lower(), (cfg['license'], None)) 31 | dev_requirements = (cfg.get('dev_requirements') or '').split() 32 | 33 | setuptools.setup( 34 | name = cfg['lib_name'], 35 | license = lic[0], 36 | classifiers = [ 37 | 'Development Status :: ' + statuses[int(cfg['status'])], 38 | 'Intended Audience :: ' + cfg['audience'].title(), 39 | 'Natural Language :: ' + cfg['language'].title(), 40 | ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]] + (['License :: ' + lic[1] ] if lic[1] else []), 41 | url = cfg['git_url'], 42 | packages = setuptools.find_packages(), 43 | include_package_data = True, 44 | install_requires = requirements, 45 | extras_require={ 'dev': dev_requirements }, 46 | dependency_links = cfg.get('dep_links','').split(), 47 | python_requires = '>=' + cfg['min_python'], 48 | long_description = open('README.md').read(), 49 | long_description_content_type = 'text/markdown', 50 | zip_safe = False, 51 | entry_points = { 52 | 'console_scripts': cfg.get('console_scripts','').split(), 53 | 'nbdev': [f'{cfg.get("lib_path")}={cfg.get("lib_path")}._modidx:d'] 54 | }, 55 | **setup_cfg) 56 | 57 | 58 | -------------------------------------------------------------------------------- /squeakily/__init__.py: 
-------------------------------------------------------------------------------- 1 | __version__ = "0.0.3" 2 | -------------------------------------------------------------------------------- /squeakily/_modidx.py: -------------------------------------------------------------------------------- 1 | # Autogenerated by nbdev 2 | 3 | d = { 'settings': { 'branch': 'main', 4 | 'doc_baseurl': '/squeakily', 5 | 'doc_host': 'https://CarperAI.github.io', 6 | 'git_url': 'https://github.com/CarperAI/squeakily', 7 | 'lib_path': 'squeakily'}, 8 | 'syms': { 'squeakily.clean': { 'squeakily.clean.clean_code_license': ('clean.html#clean_code_license', 'squeakily/clean.py'), 9 | 'squeakily.clean.fix_utf8_encoding': ('clean.html#fix_utf8_encoding', 'squeakily/clean.py'), 10 | 'squeakily.clean.normalize_punctuation': ('clean.html#normalize_punctuation', 'squeakily/clean.py'), 11 | 'squeakily.clean.normalize_whitespace': ('clean.html#normalize_whitespace', 'squeakily/clean.py'), 12 | 'squeakily.clean.remove_empty_lines': ('clean.html#remove_empty_lines', 'squeakily/clean.py'), 13 | 'squeakily.clean.replace_credit_card': ('clean.html#replace_credit_card', 'squeakily/clean.py'), 14 | 'squeakily.clean.replace_dates': ('clean.html#replace_dates', 'squeakily/clean.py'), 15 | 'squeakily.clean.replace_email': ('clean.html#replace_email', 'squeakily/clean.py'), 16 | 'squeakily.clean.replace_ip': ('clean.html#replace_ip', 'squeakily/clean.py'), 17 | 'squeakily.clean.replace_phone': ('clean.html#replace_phone', 'squeakily/clean.py'), 18 | 'squeakily.clean.replace_ssn': ('clean.html#replace_ssn', 'squeakily/clean.py'), 19 | 'squeakily.clean.replace_urls': ('clean.html#replace_urls', 'squeakily/clean.py')}, 20 | 'squeakily.core': { 'squeakily.core.Pipeline': ('core.html#pipeline', 'squeakily/core.py'), 21 | 'squeakily.core.Pipeline.__init__': ('core.html#pipeline.__init__', 'squeakily/core.py'), 22 | 'squeakily.core.Pipeline.__run_filter': ('core.html#pipeline.__run_filter', 'squeakily/core.py'), 23 | 'squeakily.core.Pipeline.export_to_path': ('core.html#pipeline.export_to_path', 'squeakily/core.py'), 24 | 'squeakily.core.Pipeline.run': ('core.html#pipeline.run', 'squeakily/core.py')}, 25 | 'squeakily.filter': { 'squeakily.filter._calculate_average_false_positive_rate': ( 'filter.html#_calculate_average_false_positive_rate', 26 | 'squeakily/filter.py'), 27 | 'squeakily.filter._char_rep_ratio': ('filter.html#_char_rep_ratio', 'squeakily/filter.py'), 28 | 'squeakily.filter._compress_ratio': ('filter.html#_compress_ratio', 'squeakily/filter.py'), 29 | 'squeakily.filter._find_duplicate_communities': ( 'filter.html#_find_duplicate_communities', 30 | 'squeakily/filter.py'), 31 | 'squeakily.filter._flag_word_ratio': ('filter.html#_flag_word_ratio', 'squeakily/filter.py'), 32 | 'squeakily.filter._hash_func': ('filter.html#_hash_func', 'squeakily/filter.py'), 33 | 'squeakily.filter._jaccard_similarity': ('filter.html#_jaccard_similarity', 'squeakily/filter.py'), 34 | 'squeakily.filter._query_content': ('filter.html#_query_content', 'squeakily/filter.py'), 35 | 'squeakily.filter.check_char_repetition': ('filter.html#check_char_repetition', 'squeakily/filter.py'), 36 | 'squeakily.filter.check_code_parsability': ('filter.html#check_code_parsability', 'squeakily/filter.py'), 37 | 'squeakily.filter.check_compression_ratio': ( 'filter.html#check_compression_ratio', 38 | 'squeakily/filter.py'), 39 | 'squeakily.filter.check_flagged_words': ('filter.html#check_flagged_words', 'squeakily/filter.py'), 40 | 'squeakily.filter.check_labels': 
('filter.html#check_labels', 'squeakily/filter.py'), 41 | 'squeakily.filter.check_language': ('filter.html#check_language', 'squeakily/filter.py'), 42 | 'squeakily.filter.check_perplexity': ('filter.html#check_perplexity', 'squeakily/filter.py'), 43 | 'squeakily.filter.check_stop_word_ratio': ('filter.html#check_stop_word_ratio', 'squeakily/filter.py'), 44 | 'squeakily.filter.check_word_number': ('filter.html#check_word_number', 'squeakily/filter.py'), 45 | 'squeakily.filter.minhash_dedup': ('filter.html#minhash_dedup', 'squeakily/filter.py')}, 46 | 'squeakily.helpers': { 'squeakily.helpers.FastTextLanguageDetector': ( 'helpers.html#fasttextlanguagedetector', 47 | 'squeakily/helpers.py'), 48 | 'squeakily.helpers.FastTextLanguageDetector.__eq__': ( 'helpers.html#fasttextlanguagedetector.__eq__', 49 | 'squeakily/helpers.py'), 50 | 'squeakily.helpers.FastTextLanguageDetector.__init__': ( 'helpers.html#fasttextlanguagedetector.__init__', 51 | 'squeakily/helpers.py'), 52 | 'squeakily.helpers.FastTextLanguageDetector.__reduce__': ( 'helpers.html#fasttextlanguagedetector.__reduce__', 53 | 'squeakily/helpers.py'), 54 | 'squeakily.helpers.FastTextLanguageDetector.from_pretrained': ( 'helpers.html#fasttextlanguagedetector.from_pretrained', 55 | 'squeakily/helpers.py'), 56 | 'squeakily.helpers.FastTextLanguageDetector.get_language': ( 'helpers.html#fasttextlanguagedetector.get_language', 57 | 'squeakily/helpers.py'), 58 | 'squeakily.helpers.KenlmModel': ('helpers.html#kenlmmodel', 'squeakily/helpers.py'), 59 | 'squeakily.helpers.KenlmModel.__init__': ('helpers.html#kenlmmodel.__init__', 'squeakily/helpers.py'), 60 | 'squeakily.helpers.KenlmModel.download_kenlm_model': ( 'helpers.html#kenlmmodel.download_kenlm_model', 61 | 'squeakily/helpers.py'), 62 | 'squeakily.helpers.KenlmModel.from_pretrained': ( 'helpers.html#kenlmmodel.from_pretrained', 63 | 'squeakily/helpers.py'), 64 | 'squeakily.helpers.KenlmModel.get_perplexity': ( 'helpers.html#kenlmmodel.get_perplexity', 65 | 'squeakily/helpers.py'), 66 | 'squeakily.helpers.KenlmModel.normalize': ('helpers.html#kenlmmodel.normalize', 'squeakily/helpers.py'), 67 | 'squeakily.helpers.KenlmModel.pp': ('helpers.html#kenlmmodel.pp', 'squeakily/helpers.py'), 68 | 'squeakily.helpers.KenlmModel.remove_non_printing_char': ( 'helpers.html#kenlmmodel.remove_non_printing_char', 69 | 'squeakily/helpers.py'), 70 | 'squeakily.helpers.KenlmModel.remove_unicode_punct': ( 'helpers.html#kenlmmodel.remove_unicode_punct', 71 | 'squeakily/helpers.py'), 72 | 'squeakily.helpers.KenlmModel.replace_unicode_punct': ( 'helpers.html#kenlmmodel.replace_unicode_punct', 73 | 'squeakily/helpers.py'), 74 | 'squeakily.helpers.KenlmModel.strip_accents': ( 'helpers.html#kenlmmodel.strip_accents', 75 | 'squeakily/helpers.py'), 76 | 'squeakily.helpers.LLMLabeler': ('helpers.html#llmlabeler', 'squeakily/helpers.py'), 77 | 'squeakily.helpers.LLMLabeler.__call__': ('helpers.html#llmlabeler.__call__', 'squeakily/helpers.py'), 78 | 'squeakily.helpers.LLMLabeler.__init__': ('helpers.html#llmlabeler.__init__', 'squeakily/helpers.py'), 79 | 'squeakily.helpers.LLMLabelerParser': ('helpers.html#llmlabelerparser', 'squeakily/helpers.py'), 80 | 'squeakily.helpers.SentencePiece': ('helpers.html#sentencepiece', 'squeakily/helpers.py'), 81 | 'squeakily.helpers.SentencePiece.__init__': ( 'helpers.html#sentencepiece.__init__', 82 | 'squeakily/helpers.py'), 83 | 'squeakily.helpers.SentencePiece.do': ('helpers.html#sentencepiece.do', 'squeakily/helpers.py'), 84 | 'squeakily.helpers.get_words': 
('helpers.html#get_words', 'squeakily/helpers.py')}}} -------------------------------------------------------------------------------- /squeakily/clean.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_clean.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['fake', 'whitespace', 'unicode_punctuation', 'normalize_whitespace', 'normalize_punctuation', 'remove_empty_lines', 5 | 'replace_urls', 'replace_dates', 'replace_email', 'replace_phone', 'replace_ip', 'replace_credit_card', 6 | 'replace_ssn', 'fix_utf8_encoding', 'clean_code_license'] 7 | 8 | # %% ../nbs/02_clean.ipynb 2 9 | import re 10 | from faker import Faker 11 | import ftfy 12 | 13 | fake = Faker() 14 | 15 | # %% ../nbs/02_clean.ipynb 4 16 | # From: https://github.com/bigscience-workshop/data-preparation/blob/main/preprocessing/training/01b_oscar_cleaning_and_filtering/filtering.py#L95 17 | whitespace = { 18 | " ", 19 | " ", 20 | " ", 21 | " ", 22 | " ", 23 | " ", 24 | " ", 25 | " ", 26 | " ", 27 | " ", 28 | "", 29 | "„", 30 | } 31 | 32 | 33 | def normalize_whitespace( 34 | text: str, # The text to normalize 35 | ) -> str: # The normalized text 36 | """ 37 | Replace the various whitespace characters with the standard one. 38 | """ 39 | text = "".join([char if char not in whitespace else " " for char in text]) 40 | return text 41 | 42 | # %% ../nbs/02_clean.ipynb 6 43 | unicode_punctuation = { 44 | ",": ",", 45 | "。": ".", 46 | "、": ",", 47 | "„": '"', 48 | "”": '"', 49 | "“": '"', 50 | "«": '"', 51 | "»": '"', 52 | "1": '"', 53 | "」": '"', 54 | "「": '"', 55 | "《": '"', 56 | "》": '"', 57 | "´": "'", 58 | "∶": ":", 59 | ":": ":", 60 | "?": "?", 61 | "!": "!", 62 | "(": "(", 63 | ")": ")", 64 | ";": ";", 65 | "–": "-", 66 | "—": " - ", 67 | ".": ". ", 68 | "~": "~", 69 | "’": "'", 70 | "…": "...", 71 | "━": "-", 72 | "〈": "<", 73 | "〉": ">", 74 | "【": "[", 75 | "】": "]", 76 | "%": "%", 77 | "►": "-", 78 | } 79 | 80 | 81 | def normalize_punctuation( 82 | text: str, # The text to normalize 83 | ) -> str: # The normalized text 84 | """ 85 | Replace the various unicode punctuation characters with the standard ones. 86 | """ 87 | text = "".join([unicode_punctuation.get(char, char) for char in text]) 88 | return text 89 | 90 | # %% ../nbs/02_clean.ipynb 8 91 | def remove_empty_lines( 92 | text: str, # The text to remove empty lines from 93 | ) -> str: # The text with empty lines removed 94 | """ 95 | Remove empty lines from the text. 
96 | Solution from https://stackoverflow.com/a/3711884/5768407 97 | """ 98 | lines = text.splitlines() 99 | filtered = filter(lambda x: not re.match(r"^\s*$", x), lines) 100 | return "\n".join(filtered) 101 | 102 | # %% ../nbs/02_clean.ipynb 10 103 | def replace_urls( 104 | text: str, # The text to replace URLs in 105 | dummy: str = "https://example.com/", # The dummy text to replace URLs with 106 | ) -> str: # The text with URLs replaced 107 | """Replace urls from text with a dummy.""" 108 | return re.sub(r"http\S+", dummy, text) 109 | 110 | # %% ../nbs/02_clean.ipynb 12 111 | def replace_dates( 112 | text: str, # The text to remove dates from 113 | dummy: str = fake.date(), # The dummy text to replace dates with 114 | ) -> str: # The text with dates replaced 115 | """Replace dates from text with a dummy.""" 116 | return re.sub(r"\d{1,2}/\d{1,2}/\d{4}", dummy, text) 117 | 118 | # %% ../nbs/02_clean.ipynb 15 119 | def replace_email( 120 | text: str, # The text to replace email addresses in 121 | dummy: str = fake.email(), # The dummy text to replace email addresses with 122 | ) -> str: # The text with email addresses replaced 123 | """Replace email addresses from text with a dummy.""" 124 | return re.sub(r"[\w\.-]+@[\w\.-]+", dummy, text) 125 | 126 | # %% ../nbs/02_clean.ipynb 17 127 | def replace_phone( 128 | text: str, # The text to replace phone numbers in 129 | dummy: str = fake.phone_number(), # The dummy text to replace phone numbers with 130 | ) -> str: # The text with phone numbers replaced 131 | """Replace phone numbers from text with a dummy.""" 132 | return re.sub(r"\(?\d{3}\)?-? *\d{3}-? *-?\d{4}", dummy, text) 133 | 134 | # %% ../nbs/02_clean.ipynb 19 135 | def replace_ip( 136 | text, # The text to replace ip addresses in 137 | dummy1: str = fake.ipv4(), # The dummy text to replace ipv4 addresses with 138 | dummy2: str = fake.ipv6(), # The dummy text to replace ipv6 addresses with 139 | ) -> str: # The text with ip addresses replaced 140 | """ 141 | Replace ip addresses from text with a dummy. 
142 | Solution from https://github.com/bigcode-project/bigcode-analysis/blob/main/data_analysis/pii/utils/emails_ip_addresses_detection.py#L48 143 | """ 144 | ipv4_pattern = r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(?:\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3}" 145 | text = re.sub(ipv4_pattern, dummy1, text) 146 | ipv6_pattern = r"(?:[0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:){1,7}:|(?:[0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:){1,5}(?::[0-9a-fA-F]{1,4}){1,2}|(?:[0-9a-fA-F]{1,4}:){1,4}(?::[0-9a-fA-F]{1,4}){1,3}|(?:[0-9a-fA-F]{1,4}:){1,3}(?::[0-9a-fA-F]{1,4}){1,4}|(?:[0-9a-fA-F]{1,4}:){1,2}(?::[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:(?:(?::[0-9a-fA-F]{1,4}){1,6})|:(?:(?::[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(?::[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(?:ffff(?::0{1,4}){0,1}:){0,1}(?:(?:25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(?:25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])|(?:[0-9a-fA-F]{1,4}:){1,4}:(?:(?:25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])" 147 | text = re.sub(ipv6_pattern, dummy2, text) 148 | return text 149 | 150 | # %% ../nbs/02_clean.ipynb 21 151 | def replace_credit_card( 152 | text: str, # The text to replace credit card numbers in 153 | dummy: str = fake.credit_card_number(), # The dummy text to replace credit card numbers with 154 | ) -> str: # The text with credit card numbers replaced 155 | """Replace credit card numbers from text with a dummy.""" 156 | return re.sub(r"\d{4}-\d{4}-\d{4}-\d{4}", dummy, text) 157 | 158 | # %% ../nbs/02_clean.ipynb 23 159 | def replace_ssn( 160 | text: str, # The text to replace social security numbers in 161 | dummy: str = fake.ssn(), # The dummy text to replace social security numbers with 162 | ) -> str: # The text with social security numbers replaced 163 | """Replace social security numbers from text with a dummy.""" 164 | return re.sub(r"\d{3}-\d{2}-\d{4}", dummy, text) 165 | 166 | # %% ../nbs/02_clean.ipynb 25 167 | def fix_utf8_encoding( 168 | text: str, # The text to fix 169 | ) -> str: # The fixed text 170 | """Fix utf8 text using ftfy.""" 171 | return ftfy.fix_text(text) 172 | 173 | # %% ../nbs/02_clean.ipynb 27 174 | def clean_code_license( 175 | code: str, # The code to clean 176 | language: str = "python", # The language of the code 177 | min_lines: int = 3, # The minimum number of lines that need to be removed 178 | ): 179 | import code_ast 180 | from code_ast import ASTVisitor 181 | from code_ast.ast import LEAVE_WHITELIST 182 | 183 | class FirstNonCommentVisitor(ASTVisitor): 184 | def __init__(self): 185 | self.passed_global_node = False 186 | self.first_node = None 187 | 188 | def visit(self, node): 189 | if not self.passed_global_node: 190 | self.passed_global_node = True 191 | return 192 | if self.first_node is None: 193 | if node.child_count > 0 or node.type in LEAVE_WHITELIST: 194 | self.first_node = node 195 | 196 | """Remove the license or other boilerplate comments from the code.""" 197 | ast = code_ast.ast(code, lang=language) 198 | visitor = FirstNonCommentVisitor() 199 | ast.visit(visitor) 200 | start_line = visitor.first_node.start_point[0] 201 | if start_line < min_lines: 202 | return code 203 | else: 204 | return "\n".join(code.splitlines()[start_line:]) 205 | -------------------------------------------------------------------------------- /squeakily/core.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_core.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['logger', 'Pipeline'] 5 | 6 | # %% ../nbs/00_core.ipynb 2 7 | import logging 8 | import os 9 | 10 | from datasets import concatenate_datasets, Dataset 11 | from rich.logging import RichHandler 12 | 13 | logger = logging.getLogger(__name__) 14 | logger.setLevel(logging.INFO) 15 | logger.addHandler(RichHandler(rich_tracebacks=True)) 16 | # Turn off logging for datasets 17 | logging.getLogger("datasets").setLevel(logging.ERROR) 18 | 19 | # %% ../nbs/00_core.ipynb 5 20 | class Pipeline: 21 | """ 22 | A pipeline is a collection of datasources and their associated transformations to be run. 23 | """ 24 | 25 | def __init__(self, datasources): # The datasources to be run 26 | self.datasources = datasources 27 | 28 | def __run_filter(self, dataset, column, filter_fn, dry_run, num_proc): 29 | """ 30 | Run a filter on a dataset. 31 | """ 32 | name = filter_fn.__name__ 33 | logger.info(f"Running filter: {name} on {column}") 34 | if dry_run: 35 | logger.info(f"Running in dry-run mode") 36 | return dataset.map( 37 | lambda x: {f"{name}_criteria": filter_fn(x[column], dry_run=True)}, 38 | num_proc=num_proc, 39 | ) 40 | else: 41 | return dataset.filter( 42 | lambda x: filter_fn(x[column]), 43 | num_proc=num_proc, 44 | ) 45 | 46 | def run( 47 | self, 48 | global_filters=[], # Filters to be run at the dataset level rather than the example level 49 | global_cleaners=[], # Cleaners to be run at the dataset level rather than the example level 50 | cleaning_first=False, # Whether to run the cleaning transformations first 51 | globals_first=False, # Whether to run the global transformations first 52 | dry_run=False, # Whether to run the pipeline or only calculate the various criteria and add as a column 53 | num_proc=os.cpu_count(), # Number of processes to use 54 | ): 55 | """ 56 | Run the pipeline. 
57 | """ 58 | for i in range(len(self.datasources)): 59 | column = self.datasources[i]["columns"][0] 60 | logger.info(f"Running datasource: {self.datasources[i]['name']}") 61 | if cleaning_first: 62 | for c in self.datasources[i]["cleaners"]: 63 | name = c.__name__ 64 | logger.info(f"Running cleaner: {name} on {column}") 65 | self.datasources[i]["dataset"] = self.datasources[i]["dataset"].map( 66 | lambda x: {column: c(x[column])}, 67 | num_proc=num_proc, 68 | ) 69 | for f in self.datasources[i]["filters"]: 70 | self.datasources[i]["dataset"] = self.__run_filter( 71 | self.datasources[i]["dataset"], column, f, dry_run, num_proc 72 | ) 73 | else: 74 | for f in self.datasources[i]["filters"]: 75 | self.datasources[i]["dataset"] = self.__run_filter( 76 | self.datasources[i]["dataset"], column, f, dry_run, num_proc 77 | ) 78 | for c in self.datasources[i]["cleaners"]: 79 | name = c.__name__ 80 | logger.info(f"Running cleaner: {name} on {column}") 81 | self.datasources[i]["dataset"] = self.datasources[i]["dataset"].map( 82 | lambda x: {column: c(x[column])}, 83 | num_proc=num_proc, 84 | ) 85 | 86 | if len(global_filters) > 0: 87 | # concatenate all datasets 88 | datasets = [ 89 | d["dataset"] 90 | for d in self.datasources 91 | if not d.get("skip_global", False) 92 | ] 93 | global_column = self.datasources[0]["columns"][0] 94 | global_dataset = concatenate_datasets(datasets) 95 | 96 | # Add a column representing the original dataset name 97 | md = [] 98 | for d in self.datasources: 99 | if not d.get("skip_global", False): 100 | md.extend([d["name"]] * len(d["dataset"])) 101 | meta_data = Dataset.from_dict({"meta_data": md}) 102 | global_dataset_with_meta = concatenate_datasets( 103 | [global_dataset, meta_data], axis=1 104 | ) 105 | 106 | # Run the global filters 107 | for f in global_filters: 108 | logger.info(f"Running global filter: {f.__name__}") 109 | global_dataset_with_meta = f( 110 | global_dataset_with_meta, global_column, dry_run=dry_run 111 | ) 112 | 113 | # Split the dataset back up 114 | for i, d in enumerate(self.datasources): 115 | if not d.get("skip_global", False): 116 | self.datasources[i]["dataset"] = global_dataset_with_meta.filter( 117 | lambda x: x["meta_data"] == d["name"], 118 | num_proc=num_proc, 119 | ) 120 | 121 | def export_to_path(self, export_path, output_type="csv"): 122 | """ 123 | Export the cleaned & filtered dataset to a desired export path 124 | 125 | Args: 126 | export_path(str): Path to directory 127 | output_type(str, optional param): Output type of the file to export as 128 | """ 129 | try: 130 | os.makedirs(export_path, exist_ok=True) 131 | except OSError as e: 132 | logger.error(f"Failed to create directory: {export_path}. Error: {str(e)}") 133 | return 134 | 135 | for i, datasource in enumerate(self.datasources): 136 | name = datasource["name"] 137 | filename = f"{name}.csv" 138 | filepath = os.path.join(export_path, filename) 139 | try: 140 | if output_type == "csv": 141 | datasource["dataset"].to_csv(filepath, index=False) 142 | elif output_type == "json": 143 | datasource["dataset"].to_json(filepath, index=False) 144 | else: 145 | logger.error( 146 | f"Invalid output_type: {output_type}. Skipping export for {name} dataset." 147 | ) 148 | logger.info(f"Exported {name} dataset to {filepath}") 149 | except Exception as e: 150 | logger.error( 151 | f"Failed to export {name} dataset to {filepath}. 
Error: {str(e)}" 152 | ) 153 | -------------------------------------------------------------------------------- /squeakily/filter.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_filter.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['logger', 'zstd_cntxt', 'MINHASH_SEED', 'NON_ALPHA', 'lsh', 'dup_ids', 'check_compression_ratio', 5 | 'check_char_repetition', 'check_flagged_words', 'check_perplexity', 'check_language', 'check_word_number', 6 | 'check_stop_word_ratio', 'check_code_parsability', 'check_labels', 'minhash_dedup'] 7 | 8 | # %% ../nbs/01_filter.ipynb 2 9 | import datasets 10 | import gc 11 | import logging 12 | import multiprocessing 13 | import os 14 | import random 15 | import re 16 | 17 | import networkit as nk 18 | import numpy as np 19 | 20 | from collections import Counter 21 | from datasets import Dataset, Features, Value, Sequence 22 | from datasketch import LeanMinHash, MinHash, MinHashLSH 23 | from rich.logging import RichHandler 24 | from .helpers import flagged_words, get_words 25 | from .helpers import stopwords, stopword_ratios 26 | from tqdm.auto import tqdm 27 | from typing import Set 28 | 29 | # %% ../nbs/01_filter.ipynb 3 30 | logger = logging.getLogger(__name__) 31 | logger.setLevel(logging.INFO) 32 | logger.addHandler(RichHandler(rich_tracebacks=True)) 33 | logger.propagate = False 34 | datasets.logging.set_verbosity_error() 35 | # Turn off logging for datasets 36 | logging.getLogger("datasets").setLevel(logging.ERROR) 37 | 38 | # %% ../nbs/01_filter.ipynb 5 39 | multiprocessing.set_start_method("fork", force=True) 40 | 41 | zstd_cntxt = None 42 | 43 | # %% ../nbs/01_filter.ipynb 6 44 | def _compress_ratio( 45 | doc: str, # document to be analyzed 46 | compression_level: int = 3, # compression level to use 47 | ) -> float: 48 | """ 49 | Returns the ratio of the compressed document to the original document. 50 | """ 51 | global zstd_cntxt 52 | if zstd_cntxt is None: 53 | import zstandard as zstd 54 | 55 | zstd_cntxt = zstd.ZstdCompressor(level=compression_level) 56 | bts = doc.encode("utf-8") 57 | compressed_bts = zstd_cntxt.compress(bts) 58 | try: 59 | ratio = len(compressed_bts) / len(bts) 60 | except ZeroDivisionError: 61 | ratio = 0 62 | return ratio 63 | 64 | # %% ../nbs/01_filter.ipynb 7 65 | def check_compression_ratio( 66 | document, # document to be analyzed 67 | compression_threshold: float = 0.5, # threshold for compression ratio 68 | compression_level: int = 3, # compression level to use 69 | dry_run=False, # if True, returns the ratio of character repetition 70 | ) -> bool: # returns True if document is below threshold 71 | """ 72 | Checks if the document is below the character repetition threshold. 73 | """ 74 | compress_ratio = _compress_ratio(document, compression_level=compression_level) 75 | if dry_run: 76 | return compress_ratio 77 | else: 78 | return compress_ratio > compression_threshold 79 | 80 | # %% ../nbs/01_filter.ipynb 9 81 | def _char_rep_ratio( 82 | doc: str, # document to be analyzed 83 | char_rep_len: int, # length of character repetition 84 | ) -> float: 85 | """ 86 | Returns the ratio of character repetitions in a document. 
87 | """ 88 | 89 | def calc_ngrams(doc, n): 90 | char_ngrams = [doc[i : i + n] for i in range(len(doc) - n + 1)] 91 | freq_char_ngrams = Counter(char_ngrams) 92 | return freq_char_ngrams 93 | 94 | freq_char_ngrams = calc_ngrams(doc, char_rep_len) 95 | if len(freq_char_ngrams) == 0: 96 | return 0 97 | freq_char_ngrams = list(freq_char_ngrams.values()) 98 | freq_char_ngrams = sorted(freq_char_ngrams, reverse=True) 99 | val_one = len([el for el in freq_char_ngrams if el == 1]) 100 | num_rep_char_ngrams = min( 101 | int(np.sqrt(len(freq_char_ngrams))), 102 | len(freq_char_ngrams) - val_one, 103 | ) 104 | char_rep_ratio = sum(freq_char_ngrams[:num_rep_char_ngrams]) / sum(freq_char_ngrams) 105 | return char_rep_ratio 106 | 107 | # %% ../nbs/01_filter.ipynb 10 108 | def check_char_repetition( 109 | document, # document to be analyzed 110 | char_repetition_len=10, # length of character repetition 111 | char_repetition_threshold=0.2, # threshold for character repetition 112 | dry_run=False, # if True, returns the ratio of character repetition 113 | ) -> bool: # returns True if document is below threshold 114 | """ 115 | Checks if the document is below the character repetition threshold. 116 | """ 117 | char_rep_ratio = _char_rep_ratio(document, char_repetition_len) 118 | if dry_run: 119 | return char_rep_ratio 120 | else: 121 | return char_rep_ratio <= char_repetition_threshold 122 | 123 | # %% ../nbs/01_filter.ipynb 12 124 | def _flag_word_ratio( 125 | doc: str, # document to be analyzed 126 | flagged_words: list, # list of flagged words 127 | get_words_func: callable, # function to get words from document 128 | ) -> float: # returns ratio of flagged words in document 129 | """ 130 | Returns the ratio of flagged words in a document. 131 | """ 132 | words = get_words_func(doc) 133 | if not words: 134 | return 0.0 135 | flagged_words_ratio = len([word for word in words if word in flagged_words]) / len( 136 | words 137 | ) 138 | if flagged_words_ratio > 1.0: 139 | flagged_words_ratio = 1.0 140 | return flagged_words_ratio 141 | 142 | # %% ../nbs/01_filter.ipynb 13 143 | def check_flagged_words( 144 | document: str, # document to be analyzed 145 | flagged_words: list = flagged_words["en"], # list of flagged words 146 | flagged_words_threshold: float = 0.1, # threshold for flagged words 147 | get_words_func: callable = get_words, # function to get words from document 148 | dry_run: bool = False, # if True, returns the ratio of flagged words 149 | ) -> bool: # returns True if document is below threshold unless dry_run is True 150 | """ 151 | Checks if a document contains a high percentage of flagged words. 152 | """ 153 | cond = True 154 | if flagged_words: 155 | flagged_words_ratio = _flag_word_ratio( 156 | document, 157 | flagged_words, 158 | get_words_func, 159 | ) 160 | if dry_run: 161 | return flagged_words_ratio 162 | 163 | cond = flagged_words_ratio <= flagged_words_threshold 164 | return cond 165 | 166 | # %% ../nbs/01_filter.ipynb 16 167 | def check_perplexity( 168 | document, # document to be analyzed 169 | perplexity_threshold=10_000, # threshold for perplexity 170 | model=None, # model to calculate perplexity 171 | dry_run=False, # if True, returns the perplexity of the document 172 | ) -> bool: # returns True if document is below threshold 173 | """ 174 | Checks if the document is below the perplexity threshold. 
175 | """ 176 | perplexity = model.get_perplexity(document) 177 | if dry_run: 178 | return perplexity 179 | else: 180 | return perplexity <= perplexity_threshold 181 | 182 | # %% ../nbs/01_filter.ipynb 19 183 | def check_language( 184 | document, # document to be analyzed 185 | language="en", # language to check 186 | language_threshold=0.9, # threshold for language 187 | model=None, # model to check language 188 | dry_run=False, # if True, returns the language of the document 189 | ) -> bool: # returns True if document is below threshold 190 | """ 191 | Checks if the document is below the language threshold. 192 | """ 193 | lang, prob = model.get_language(document) 194 | if dry_run: 195 | if lang == language: 196 | return prob 197 | else: 198 | return -1.0 199 | else: 200 | return language == lang and prob > language_threshold 201 | 202 | # %% ../nbs/01_filter.ipynb 21 203 | def check_word_number( 204 | document, # document to be analyzed 205 | min_word_threshold=5, # minimum number of words 206 | max_word_threshold=100, # maximum number of words 207 | get_words_func=get_words, # function to get words from document 208 | dry_run=False, # if True, returns the number of words in the document 209 | ) -> bool: # returns True if document is between the minimum and maximum thresholds 210 | """ 211 | Checks if the document is between the minimum and maximum word thresholds. 212 | """ 213 | words = get_words_func(document) 214 | if dry_run: 215 | return len(words) 216 | else: 217 | return len(words) >= min_word_threshold and len(words) <= max_word_threshold 218 | 219 | # %% ../nbs/01_filter.ipynb 23 220 | def check_stop_word_ratio( 221 | document, # document to be analyzed 222 | stop_word_threshold=stopword_ratios["en"], # threshold for stop words 223 | stop_words=stopwords["en"], # list of stop words 224 | get_words_func=get_words, # function to get words from document 225 | dry_run=False, # if True, returns the ratio of stop words in the document 226 | ) -> bool: # returns True if document is below the threshold 227 | """ 228 | Checks if the document contains a high percentage of stop words. 229 | """ 230 | cond = True 231 | if stop_words: 232 | stop_word_ratio = _flag_word_ratio( 233 | document, 234 | stop_words, 235 | get_words_func, 236 | ) 237 | if dry_run: 238 | return stop_word_ratio 239 | else: 240 | cond = stop_word_ratio <= stop_word_threshold 241 | return cond 242 | 243 | # %% ../nbs/01_filter.ipynb 25 244 | def check_code_parsability( 245 | document, # document to be analyzed 246 | program_language="python", # programming language to check 247 | ) -> bool: # returns True if the code is parsable 248 | """ 249 | Checks if the document contains parsable code. 250 | """ 251 | import code_tokenize as ctok 252 | 253 | try: 254 | ctok.tokenize(document, lang=program_language, syntax_error="raise") 255 | return True 256 | except SyntaxError: 257 | return False 258 | 259 | # %% ../nbs/01_filter.ipynb 27 260 | def check_labels( 261 | document, # document to be analyzed 262 | labels: list, # list of labels to check the document against 263 | model=None, # model to check label 264 | dry_run=False, # if True, returns the tags of the document 265 | ) -> bool: # returns True if document relates to any of the labels 266 | """ 267 | Checks if the document relates to any of the labels. 
268 | """ 269 | pred_labels = model(document) 270 | if dry_run: 271 | return pred_labels 272 | else: 273 | return any([label in pred_labels for label in labels]) 274 | 275 | # %% ../nbs/01_filter.ipynb 31 276 | MINHASH_SEED = 115 277 | NON_ALPHA = re.compile("[^A-Za-z_0-9]") 278 | 279 | random.seed(MINHASH_SEED) 280 | 281 | lsh: MinHashLSH = None 282 | dup_ids: Set = None 283 | 284 | # %% ../nbs/01_filter.ipynb 32 285 | def _hash_func( 286 | idx: int, # The index of the record. 287 | content: str, # The content to be hashed. 288 | *, 289 | num_perm: int # The number of permutations to use in the MinHash object. 290 | ) -> dict[str, any]: # The MinHash signature and the index of the record. 291 | """ 292 | Embed the content of a record into a MinHash object. This function should be 293 | used with multiprocessing and it scales well with the number of cores. 294 | >>> result = _hash_func(0, "Hello world!", num_perm=128) 295 | >>> result["__id__"] 296 | 0 297 | >>> result["__signature__"].shape 298 | (128,) 299 | >>> result["__signature__"].dtype 300 | dtype('uint64') 301 | """ 302 | m = MinHash(num_perm=num_perm, seed=MINHASH_SEED) 303 | m.update_batch( 304 | [token.encode("utf-8") for token in {t for t in NON_ALPHA.split(content) if t}] 305 | ) 306 | return {"__signature__": m.hashvalues, "__id__": idx} 307 | 308 | # %% ../nbs/01_filter.ipynb 34 309 | def _query_content( 310 | idx: int, # The index of the record. 311 | signature: np.ndarray, # The MinHash signature of the record to be queried. 312 | *, 313 | index: MinHashLSH # The MinHashLSH index. It is shared across all processes when using multiprocessing with fork without copy. 314 | ) -> dict[str, any]: # The query result. 315 | """ 316 | Query the MinHashLSH index for the record. This function can be used with multiprocessing 317 | as long as the index is shared across processes. 318 | """ 319 | return { 320 | "__neighbors__": [ 321 | dup_idx 322 | for dup_idx in index.query( 323 | LeanMinHash(seed=MINHASH_SEED, hashvalues=signature), 324 | ) 325 | if dup_idx != idx # exclude itself 326 | ], 327 | "__id__": idx, 328 | } 329 | 330 | # %% ../nbs/01_filter.ipynb 36 331 | def _jaccard_similarity( 332 | s1: str, s2: str # The first string to compare. # The second string to compare. 333 | ) -> float: # The Jaccard similarity between the two strings. 334 | """ 335 | Calculate the jaccard similarity between two code snippets. 336 | """ 337 | tokens1 = set([t for t in NON_ALPHA.split(s1) if t.strip()]) 338 | tokens2 = set([t for t in NON_ALPHA.split(s2) if t.strip()]) 339 | return len(tokens1 & tokens2) / max(1, len(tokens1 | tokens2)) 340 | 341 | # %% ../nbs/01_filter.ipynb 38 342 | def _calculate_average_false_positive_rate( 343 | clusters: list[list[int]], # The clusters of duplicate records. 344 | reference_records: Dataset, # The reference records. 345 | threshold: float, # The threshold to use for calculating the false positive rate. 346 | column: str, # The column to use for calculating the false positive rate. 347 | ) -> None: 348 | """ 349 | Calculate the average false positive rate within each cluster. The false positives are defined as 350 | number of examples that have a maximum jaccard similarity with any example in the cluster that is 351 | less than the threshold. The false positive rate is defined as the number of false positives divided 352 | by the number of examples in the cluster. The average false positive rate is defined as the average 353 | of the false positive rate across all clusters given. 
354 | """ 355 | cluster_false_positive_rates: list[float] = [] 356 | deltas: list[float] = [] 357 | 358 | for cluster in tqdm(clusters, desc="Calculating sampling false positive rate..."): 359 | num_false_positives = 0 360 | ids = sorted(cluster) 361 | for i, x in enumerate(ids): 362 | is_false_positive = True 363 | max_similarity = -float("inf") 364 | for j, y in enumerate(ids): 365 | if i == j: 366 | continue 367 | # TODO This can be redundant but we only calculate this for a small sample 368 | similarity = _jaccard_similarity( 369 | reference_records[x][column], reference_records[y][column] 370 | ) 371 | max_similarity = max(max_similarity, similarity) 372 | if max_similarity >= threshold: 373 | is_false_positive = False 374 | break 375 | if is_false_positive: 376 | num_false_positives += 1 377 | deltas.append(threshold - max_similarity) 378 | cluster_false_positive_rates.append(num_false_positives / len(ids)) 379 | 380 | logger.info( 381 | f"Average false positive rate from {len(clusters)} clusters: {np.mean(cluster_false_positive_rates):.2f}" 382 | ) 383 | logger.info(f"Similarity delta stats from threshold:") 384 | logger.info(f"- Max : {np.max(deltas):0.2f}") 385 | logger.info(f"- Min : {np.min(deltas):0.2f}") 386 | logger.info(f"- Mean: {np.mean(deltas):0.2f}") 387 | logger.info(f"- Std : {np.std(deltas):0.2f}") 388 | 389 | # %% ../nbs/01_filter.ipynb 39 390 | def _find_duplicate_communities( 391 | records: Dataset, # The dataset that contains both `__id__` and `__neighbors__`. 392 | community_detection: bool, # Whether to use community detection to find the duplicate communities, or to use the connected components. 393 | report_false_positive_rate: bool = False, # Whether to report the false positive rate. 394 | reference_records: Dataset = None, # The reference records. It can be an iterable or a Dataset. It is only used when `report_false_positive_rate` is True. 395 | threshold: float = 0.85, # The threshold to use for calculating the false positive rate. 396 | column: str = "content", # The column to use for calculating the false positive rate. 397 | verbose: bool = False, 398 | ) -> ( 399 | Set 400 | ): # The set of duplicate ids that should be removed, leaving only one id in each community. 401 | """ 402 | Find the duplicate communities from the queried dataset. 
403 | """ 404 | SAMPLE_MIN_SIZE = 10 405 | SAMPLE_MAX_SIZE = 100 406 | SAMPLE_SIZE = 10 407 | g = nk.graph.Graph() 408 | for record in tqdm(records, desc="Constructing graph..."): 409 | for y in record["__neighbors__"]: 410 | g.addEdge(record["__id__"], y, addMissing=True) 411 | 412 | to_remove: Set = set() 413 | samples: list[list[int]] = [] 414 | if not community_detection: 415 | cc = nk.components.ConnectedComponents(g) 416 | cc.run() 417 | partition = cc.getPartition() 418 | components = list(cc.getComponents()) 419 | random.shuffle(components) 420 | for component in tqdm(components, desc="Iterating over components..."): 421 | component = sorted(component) 422 | to_remove.update(component[1:]) 423 | if ( 424 | len(samples) < SAMPLE_SIZE 425 | and SAMPLE_MAX_SIZE > len(component) >= SAMPLE_MIN_SIZE 426 | ): 427 | samples.append(component[:]) 428 | else: 429 | algo = nk.community.PLM(g, refine=False) 430 | algo.run() 431 | partition = algo.getPartition() 432 | communities = list(partition.getSubsetIds()) 433 | random.shuffle(communities) 434 | # This can be slow if there are many communities 435 | for i in tqdm(communities, desc="Iterating over communities..."): 436 | ids = partition.getMembers(i) 437 | to_remove.update(sorted(ids)[1:]) 438 | if ( 439 | len(samples) < SAMPLE_SIZE 440 | and SAMPLE_MAX_SIZE > len(ids) >= SAMPLE_MIN_SIZE 441 | ): 442 | samples.append(ids) 443 | 444 | if report_false_positive_rate and verbose: 445 | _calculate_average_false_positive_rate( 446 | samples, 447 | reference_records, 448 | threshold, 449 | column, 450 | ) 451 | 452 | return to_remove 453 | 454 | # %% ../nbs/01_filter.ipynb 40 455 | def minhash_dedup( 456 | ds, # The dataset to deduplicate. 457 | column, # The column to use for deduplication. 458 | community_detection: bool = False, # Whether to use community detection to find the duplicate communities, or to use the connected components. 459 | report_false_positive_rate: bool = False, # Whether to report the false positive rate. 460 | threshold: float = 0.85, # The threshold to use for deduplication. 461 | num_perm: int = 128, # The number of permutations to use for minhashing. 462 | dry_run: bool = False, # Whether to run the deduplication in dry run mode. 463 | ) -> Dataset: 464 | """ 465 | Deduplicate the dataset using minhashing as described in the paper "Deduplicating Training Data Makes Language Models Better". 
466 | """ 467 | global lsh 468 | global dup_ids 469 | 470 | lsh = MinHashLSH( 471 | threshold=threshold, 472 | num_perm=num_perm, 473 | ) 474 | column_names = ds.column_names 475 | ds = ds.map( 476 | lambda _, idx: {"__id__": idx}, 477 | with_indices=True, 478 | num_proc=os.cpu_count(), 479 | desc="Adding index...", 480 | ) 481 | hashed_ds = ds.map( 482 | function=_hash_func, 483 | fn_kwargs={"num_perm": num_perm}, 484 | input_columns=["__id__", column], 485 | remove_columns=column_names, 486 | num_proc=os.cpu_count(), 487 | desc=f"Fingerprinting...", 488 | ) 489 | with lsh.insertion_session() as session: 490 | for data in tqdm(hashed_ds, desc="Indexing signatures..."): 491 | if data["__id__"] in lsh: 492 | continue 493 | session.insert( 494 | data["__id__"], 495 | LeanMinHash(seed=MINHASH_SEED, hashvalues=data["__signature__"]), 496 | check_duplication=False, 497 | ) 498 | 499 | gc.disable() 500 | gc.freeze() 501 | 502 | conf = { 503 | "threshold": threshold, 504 | "community_detection": community_detection, 505 | "report_false_positive_rate": report_false_positive_rate, 506 | "num_perm": num_perm, 507 | "name": ds.builder_name, 508 | "column": column, 509 | } 510 | queried = hashed_ds.map( 511 | lambda x, y: _query_content(x, y, index=lsh), 512 | num_proc=os.cpu_count(), 513 | features=Features( 514 | { 515 | "__id__": Value(dtype="int64", id=None), 516 | "__neighbors__": Sequence( 517 | feature=Value(dtype="int64", id=None), length=-1, id=None 518 | ), 519 | } 520 | ), 521 | input_columns=["__id__", "__signature__"], 522 | remove_columns=["__signature__"], 523 | desc=f"Querying...", 524 | ) 525 | 526 | del lsh 527 | gc.collect() 528 | 529 | queried = queried.filter( 530 | lambda x: len(x["__neighbors__"]) > 0, 531 | num_proc=os.cpu_count(), 532 | desc="Finding duplicates...", 533 | ) 534 | dup_ids = _find_duplicate_communities( 535 | records=queried, 536 | community_detection=conf["community_detection"], 537 | report_false_positive_rate=conf["report_false_positive_rate"], 538 | reference_records=ds, 539 | threshold=conf["threshold"], 540 | column=conf["column"], 541 | ) 542 | 543 | del queried 544 | gc.collect() 545 | 546 | if dry_run: 547 | final_data = ds.map( 548 | lambda idx: {"duplicate": idx in dup_ids}, 549 | input_columns=["__id__"], 550 | num_proc=os.cpu_count(), 551 | desc="Labeling duplicates...", 552 | ) 553 | else: 554 | final_data = ds.filter( 555 | lambda idx: idx not in dup_ids, 556 | input_columns=["__id__"], 557 | num_proc=os.cpu_count(), 558 | desc="Filtering duplicates...", 559 | ) 560 | return final_data 561 | --------------------------------------------------------------------------------
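
As a rough usage sketch (not taken from the repository itself): the cleaners in squeakily.clean are plain str -> str helpers, so they can be called directly on raw text outside of a Pipeline. The sample string below is invented for illustration, and the replacement dummies come from the Faker defaults defined in clean.py.

# Illustrative only: 'sample' and its contents are made up for this sketch.
from squeakily.clean import normalize_punctuation, replace_email, replace_urls

sample = "Write to someone@example.org or visit http://example.org/page 。"
sample = replace_urls(sample)           # URLs -> "https://example.com/" by default
sample = replace_email(sample)          # email addresses -> a Faker-generated address
sample = normalize_punctuation(sample)  # map unicode punctuation to ASCII equivalents
print(sample)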
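
And a similarly hedged end-to-end sketch of squeakily.core.Pipeline, wiring a toy Hugging Face dataset through a couple of the example-level filters and cleaners shown above and applying minhash_dedup as a global filter. It assumes squeakily and its dependencies (datasets, datasketch, networkit, faker, ftfy) are installed; the toy corpus, the "text" column name, and the cleaned_data output directory are assumptions made for the example, not anything defined in the repository.

# Illustrative sketch; the toy corpus, column name, and output path are hypothetical.
from datasets import Dataset

from squeakily.clean import normalize_whitespace, remove_empty_lines
from squeakily.core import Pipeline
from squeakily.filter import check_char_repetition, check_word_number, minhash_dedup

# A tiny in-memory dataset standing in for a real corpus.
toy_ds = Dataset.from_dict(
    {
        "text": [
            "The quick brown fox jumps over the lazy dog near the river bank.",
            "The quick brown fox jumps over the lazy dog near the river bank.",
            "A completely different sentence about cleaning large text corpora.",
            "Another unrelated sentence that easily passes the word count filter.",
            "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz",
        ]
    }
)

datasources = [
    {
        "dataset": toy_ds,
        "name": "toy_corpus",  # used for logging, the meta_data column, and export file names
        "columns": ["text"],   # Pipeline.run only processes the first listed column
        "filters": [check_word_number, check_char_repetition],
        "cleaners": [normalize_whitespace, remove_empty_lines],
    }
]

pipeline = Pipeline(datasources)

# cleaning_first=True runs the cleaners before the per-example filters;
# minhash_dedup is passed as a global filter, so it runs on the concatenated
# datasets after the per-datasource steps. num_proc=1 is enough for a dataset
# this small.
pipeline.run(global_filters=[minhash_dedup], cleaning_first=True, num_proc=1)

# The processed datasets stay on pipeline.datasources; export writes one file
# per datasource (toy_corpus.csv here) into the given directory.
pipeline.export_to_path("cleaned_data", output_type="csv")

Passing dry_run=True to run() would instead add a <filter_name>_criteria column for each filter (and, for minhash_dedup, a duplicate flag) rather than dropping rows, which is useful for inspecting scores before committing to thresholds.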